Commit 44ec8bed authored by lishen's avatar lishen
Browse files

支持更复杂的量化,包括fp8/int8/ue8m0,且支持per-group/per-channel

parent 81e56124
...@@ -136,7 +136,7 @@ struct LowLatencyLayout { ...@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) { int num_ranks, int num_experts) {
const int num_scales = hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL; const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
......
...@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int ...@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8, int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook) { bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
...@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait(launch_stream, compute_stream); stream_wait(launch_stream, compute_stream);
// Allocate packed tensors // Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, auto packed_recv_x_dtype = torch::kBFloat16;
x.options().dtype(use_int8 ? torch::kInt8 : use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16)); switch (quant_type) {
case 1: packed_recv_x_dtype = torch::kInt8; break;
case 2: packed_recv_x_dtype = torch::kFloat8_e4m3fnuz; break;
case 3: packed_recv_x_dtype = torch::kFloat8_e4m3fnuz; break;
case 4: packed_recv_x_dtype = torch::kFloat8_e5m2fnuz; break;
}
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
...@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales // Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>(); auto packed_recv_x_scales = std::optional<torch::Tensor>();
void* packed_recv_x_scales_ptr = nullptr; void* packed_recv_x_scales_ptr = nullptr;
if (use_fp8) { if (quant_type > 0) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
// TODO: support unaligned cases // TODO: support unaligned cases
EP_HOST_ASSERT(hidden % (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4) == 0); EP_HOST_ASSERT(hidden % (QUANTIZATION_GROUPSIZE * 4) == 0);
EP_HOST_ASSERT(!(use_ue8m0 && use_int8));
// 计算scale_col的大小
int scales_col_size = 1; // 默认为per-channel
if (quant_group_size > 0) {
if (quant_type == 3) { // FP8_UE8M0比较特殊
scales_col_size = hidden / (QUANTIZATION_GROUPSIZE * 4);
} else {
scales_col_size = hidden / QUANTIZATION_GROUPSIZE;
}
}
if (use_ue8m0) { // 设置packed_recv_x_scales
EP_HOST_ASSERT(round_scale); if (quant_type == 3) { // FP8_UE8M0比较特殊,需要单独处理
packed_recv_x_scales = torch::empty({num_local_experts, hidden / (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4), num_ranks * num_max_dispatch_tokens_per_rank}, EP_HOST_ASSERT(fp8_round_scale && quant_group_size == 128);
packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA)); torch::dtype(torch::kInt).device(torch::kCUDA));
} else if (use_int8) {
packed_recv_x_scales = torch::empty({num_local_experts, 1, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else { } else {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL, num_ranks * num_max_dispatch_tokens_per_rank}, packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA)); torch::dtype(torch::kFloat32).device(torch::kCUDA));
} }
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
...@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank, num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0, use_int8, quant_type, quant_group_size, fp8_round_scale,
workspace, num_device_sms, launch_stream, phases); workspace, num_device_sms, launch_stream, phases);
}; };
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
......
...@@ -177,7 +177,7 @@ public: ...@@ -177,7 +177,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8, int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook); bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
......
...@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8, int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases); void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x, void combine(void* combined_x,
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2 #define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3 #define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128 #define QUANTIZATION_GROUPSIZE 128
#define DEFAULT_NUM_CU 20 #define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6 #define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
......
This diff is collapsed.
...@@ -341,10 +341,14 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) { ...@@ -341,10 +341,14 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values); return *reinterpret_cast<dtype_t *>(recv_int_values);
} }
constexpr float kFP8Margin = 1e-4; // 设置不同的量化方式的最大值与相反数
constexpr float kFP8Margin = 0.0;
constexpr float kFinfoAmaxE4M3 = 240.0f; constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3; constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
constexpr float kInt8Amax = 127.0f; constexpr float kFinfoAmaxE5M2 = 57344.0f;
constexpr float kFinfoAmaxInvE5M2 = 1.0f / kFinfoAmaxE5M2;
constexpr float kFinfoAmaxInt8 = 127.0f;
constexpr float kFinfoAmaxInvInt8 = 1.0f / 127.0f;
__forceinline__ __device__ float fast_pow2(int x) { __forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127` // We can ensure `-126 <= x and x <= 127`
...@@ -359,22 +363,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) { ...@@ -359,22 +363,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return exp_x - 127 + (man_bits != 0); return exp_x - 127 + (man_bits != 0);
} }
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) { template <int kQuantType>
if (round_scale) { __forceinline__ __device__ void calculate_quant8bit_scales(float amax, float& scale, float& scale_inv, bool round_scale=0) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3); amax = fmaxf(amax, 1e-6f);
scale = fast_pow2(-exp_scale_inv); if constexpr(kQuantType == 1) { // 使用 INT8 对称量化
scale_inv = fast_pow2(exp_scale_inv); scale_inv = kFinfoAmaxInvInt8 * amax;
} else { scale = kFinfoAmaxInt8 / amax;
scale_inv = amax * kFinfoAmaxInvE4M3; } else if constexpr(kQuantType == 2 || kQuantType == 3) { // 使用 FP8_E4M3 或 FP8_UE8M0 非对称量化
scale = kFinfoAmaxE4M3 / amax; if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
}
} else if constexpr(kQuantType == 4) { // 使用 FP8_E5M2 对称量化
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE5M2);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE5M2;
scale = kFinfoAmaxE5M2 / amax;
}
} }
} }
__forceinline__ __device__ void calculate_int8_scales(float amax, float& scale, float& scale_inv) {
scale = kInt8Amax / amax;
scale_inv = amax / kInt8Amax;
}
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>> template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) { __forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
if constexpr (kIsUE8M0) { if constexpr (kIsUE8M0) {
......
...@@ -841,7 +841,7 @@ class Buffer: ...@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int, num_max_dispatch_tokens_per_rank: int, num_experts: int,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, use_int8: bool = False, quant_type: int = 1, quant_group_size: int = 0, fp8_round_scale: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \ async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
""" """
...@@ -858,10 +858,20 @@ class Buffer: ...@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported. only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts. num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. 量化配置
round_scale: whether round the scaling factors into power of 2. quant_type: int 量化类型枚举
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`). 0 -> None 不量化,保持原始精度
use_int8: whether to enable INT8 casting. 1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 数据使用 FP8 E4M3 格式,缩放因子使用 DeepSeek-V3.1 采用的 UE8M0 格式 (仅支持 fp8_round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
quant_group_size: int 量化分组大小
0 -> 逐token量化 (per-channel)
128-> 每 128 元素一组 (per-group) 量化
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...@@ -869,15 +879,25 @@ class Buffer: ...@@ -869,15 +879,25 @@ class Buffer:
Returns: Returns:
recv_x: a tensor or tuple with received tokens for each expert. recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as - packed_recv_x:
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`. 存储接收到的 Token 数据,形状为
The second tensor is the corresponding scales for the first element with shape `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`, 数据类型取决于 quant_type:
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as quant_type == 1 -> torch.int8
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`. quant_type == 2 -> torch.float8_e4m3fnuz
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility. quant_type == 3 -> torch.float8_e4m3fnuz (UE8M0 使用 E4M3 格式存储)
With `use_fp8=False`, the result would be a tensor shaped as quant_type == 4 -> torch.float8_e5m2fnuz
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`. 其他 (非量化) -> torch.bfloat16
- packed_recv_x_scales (可选):
仅在 quant_type > 0 时存在,存储量化的 Scale 值。
形状为 `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`。
- 当 quant_type == 3 (UE8M0) 时:
scales_col_size = hidden // 512
数据类型为 torch.int (每个 int32 内打包存储 4 个 8-bit 的 UE8M0 scale)。
*注意:此模式强制要求 fp8_round_scale=True 且 quant_group_size=128。
- 当 quant_type == 1, 2, 4 时:
scales_col_size = hidden // 128 (quant_group_size > 0 时) 或 1 (quant_group_size == 0,即每个 token 一个 scale)。
数据类型为 torch.float32。
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are, Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (synchronizing would also be incompatible with CUDA graph). as we do not synchronize CPU received count with GPU (synchronizing would also be incompatible with CUDA graph).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
...@@ -889,14 +909,15 @@ class Buffer: ...@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \ packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx, self.runtime.low_latency_dispatch(x, topk_idx,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0, use_int8, quant_type, quant_group_size, fp8_round_scale,
async_finish, return_recv_hook) async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts) handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx, tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_x, packed_recv_x_scales, packed_recv_count,
packed_recv_src_info, packed_recv_layout_range) packed_recv_src_info, packed_recv_layout_range)
return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
EventOverlap(event, tensors_to_record if async_finish else None), hook recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
......
...@@ -58,7 +58,8 @@ def test_main(num_tokens: int, ...@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list = [x] x_list = [x]
# # NOTES: the last one is for performance testing # # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels # # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1) # x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
...@@ -80,14 +81,19 @@ def test_main(num_tokens: int, ...@@ -80,14 +81,19 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for current_x in x_list: for current_x in x_list:
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True): for quant_type in (0, 2, 3, ): # 0: 不量化, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True)
for round_scale in (False, True) if dispatch_use_fp8 else (False,): dispatch_use_fp8 = quant_type > 0
for use_ue8m0 in (False, True) if round_scale else (False,): for fp8_round_scale in (False, True) if dispatch_use_fp8 else (False, ):
for quant_group_size in (128, ):
# 跳过不支持的情况
if quant_type == 3 and fp8_round_scale == False:
continue
num_times += 1 num_times += 1
for _ in range((num_times % 2) + 1): for _ in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \ packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, quant_type=quant_type, fp8_round_scale=fp8_round_scale, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
...@@ -115,13 +121,13 @@ def test_main(num_tokens: int, ...@@ -115,13 +121,13 @@ def test_main(num_tokens: int,
recv_x_amin = recv_x[:, :-128].amin(dim=-1) recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens] recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if round_scale: if fp8_round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else: else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks): for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not round_scale: if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0 assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_fp8: if dispatch_use_fp8:
...@@ -147,7 +153,7 @@ def test_main(num_tokens: int, ...@@ -147,7 +153,7 @@ def test_main(num_tokens: int,
if do_check: if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0 assert torch.isnan(combined_x).sum().item() == 0
# if not round_scale: # if not fp8_round_scale:
assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: diff={diff}, dispatch_use_fp8={dispatch_use_fp8}, zero_copy={zero_copy}' assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: diff={diff}, dispatch_use_fp8={dispatch_use_fp8}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x) hash_value ^= hash_tensor(combined_x)
...@@ -162,7 +168,8 @@ def test_main(num_tokens: int, ...@@ -162,7 +168,8 @@ def test_main(num_tokens: int,
def test_func(return_recv_hook: bool): def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \ recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook) quant_type=2, quant_group_size=128,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx, topk_idx,
......
...@@ -54,16 +54,16 @@ def test_main(num_tokens: int, ...@@ -54,16 +54,16 @@ def test_main(num_tokens: int,
do_check = True do_check = True
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for current_x in x_list: for current_x in x_list:
for return_recv_hook in (False, ): for return_recv_hook in (False, True):
for dispatch_use_fp8 in (True, ): for quant_type in (1, ):
for round_scale in (False, ): for fp8_round_scale in (False, ):
for use_ue8m0 in (False, ): for quant_group_size in (0, ):
dispatch_use_fp8 = quant_type > 0
num_times += 1 num_times += 1
use_int8 = True
for _ in range(1): for _ in range(1):
packed_recv_x, packed_recv_count, handle, event, hook = \ packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, use_int8=use_int8, quant_type=quant_type, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
...@@ -97,9 +97,7 @@ def test_main(num_tokens: int, ...@@ -97,9 +97,7 @@ def test_main(num_tokens: int,
assert torch.equal(recv_x_amin, recv_x_amax) assert torch.equal(recv_x_amin, recv_x_amax)
if round_scale: if quant_type == 1:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
elif use_int8:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.01 assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.01
else: else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
...@@ -131,7 +129,7 @@ def test_main(num_tokens: int, ...@@ -131,7 +129,7 @@ def test_main(num_tokens: int,
def test_func(return_recv_hook: bool): def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \ recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, round_scale=False, use_ue8m0=False, use_int8=True, quant_type=1, quant_group_size=0,
async_finish=False, return_recv_hook=return_recv_hook) async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment