Commit d0fcf024 authored by lishen's avatar lishen
Browse files

Merge branch 'quant_master' into 'main'

Add quant.

See merge request dcutoolkit/deeplearing/DeepEP!19
parents 81e56124 ace6e18e
...@@ -136,7 +136,7 @@ struct LowLatencyLayout { ...@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) { int num_ranks, int num_experts) {
const int num_scales = hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL; const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
......
...@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int ...@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8, int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook) { bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
...@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait(launch_stream, compute_stream); stream_wait(launch_stream, compute_stream);
// Allocate packed tensors // Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, auto packed_recv_x_dtype = torch::kBFloat16;
x.options().dtype(use_int8 ? torch::kInt8 : use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16)); switch (quant_type) {
case 1: packed_recv_x_dtype = torch::kInt8; break;
case 2: packed_recv_x_dtype = torch::kFloat8_e4m3fn; break;
case 3: packed_recv_x_dtype = torch::kFloat8_e4m3fn; break;
case 4: packed_recv_x_dtype = torch::kFloat8_e5m2; break;
}
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
...@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales // Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>(); auto packed_recv_x_scales = std::optional<torch::Tensor>();
void* packed_recv_x_scales_ptr = nullptr; void* packed_recv_x_scales_ptr = nullptr;
if (use_fp8) { if (quant_type > 0) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
// TODO: support unaligned cases // TODO: support unaligned cases
EP_HOST_ASSERT(hidden % (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4) == 0); EP_HOST_ASSERT(hidden % (QUANTIZATION_GROUPSIZE * 4) == 0);
EP_HOST_ASSERT(!(use_ue8m0 && use_int8));
// 计算scale_col的大小
int scales_col_size = 1; // 默认为per-channel
if (quant_group_size > 0) {
if (quant_type == 3) { // FP8_UE8M0比较特殊
scales_col_size = hidden / (QUANTIZATION_GROUPSIZE * 4);
} else {
scales_col_size = hidden / QUANTIZATION_GROUPSIZE;
}
}
if (use_ue8m0) { // 设置packed_recv_x_scales
EP_HOST_ASSERT(round_scale); if (quant_type == 3) { // FP8_UE8M0比较特殊,需要单独处理
packed_recv_x_scales = torch::empty({num_local_experts, hidden / (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4), num_ranks * num_max_dispatch_tokens_per_rank}, EP_HOST_ASSERT(fp8_round_scale && quant_group_size == 128);
packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA)); torch::dtype(torch::kInt).device(torch::kCUDA));
} else if (use_int8) {
packed_recv_x_scales = torch::empty({num_local_experts, 1, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else { } else {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL, num_ranks * num_max_dispatch_tokens_per_rank}, packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA)); torch::dtype(torch::kFloat32).device(torch::kCUDA));
} }
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
...@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank, num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0, use_int8, quant_type, quant_group_size, fp8_round_scale,
workspace, num_device_sms, launch_stream, phases); workspace, num_device_sms, launch_stream, phases);
}; };
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
......
...@@ -177,7 +177,7 @@ public: ...@@ -177,7 +177,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8, int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook); bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
......
...@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8, int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases); void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x, void combine(void* combined_x,
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2 #define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3 #define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128 #define QUANTIZATION_GROUPSIZE 128
#define DEFAULT_NUM_CU 20 #define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6 #define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
......
This diff is collapsed.
...@@ -62,7 +62,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a ...@@ -62,7 +62,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 8: \ case 8: \
case_macro(8); \ case_macro(8); \
default: \ default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \ EP_HOST_ASSERT(false and "Unsupported ranks"); \
} \ } \
while (false) while (false)
...@@ -83,7 +83,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a ...@@ -83,7 +83,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 20: \ case 20: \
case_macro(20); \ case_macro(20); \
default: \ default: \
EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \ EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
} \ } \
while (false) while (false)
...@@ -96,7 +96,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a ...@@ -96,7 +96,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 8: \ case 8: \
case_macro(dtype, 8); \ case_macro(dtype, 8); \
default: \ default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \ EP_HOST_ASSERT(false and "Unsupported ranks"); \
} \ } \
while (false) while (false)
...@@ -107,7 +107,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a ...@@ -107,7 +107,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case HIP_R_32F: \ case HIP_R_32F: \
case_macro(float); \ case_macro(float); \
default: \ default: \
EP_HOST_ASSERT(false and "Unsupported type"); \ EP_HOST_ASSERT(false and "Unsupported type"); \
} \ } \
while (false) while (false)
...@@ -121,7 +121,9 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a ...@@ -121,7 +121,9 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case_macro(4096); \ case_macro(4096); \
case 7168: \ case 7168: \
case_macro(7168); \ case_macro(7168); \
case 8192: \
case_macro(8192); \
default: \ default: \
EP_HOST_ASSERT(false and "Unsupported hidden"); \ EP_HOST_ASSERT(false and "Unsupported hidden"); \
} \ } \
while (false) while (false)
...@@ -341,10 +341,13 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) { ...@@ -341,10 +341,13 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values); return *reinterpret_cast<dtype_t *>(recv_int_values);
} }
constexpr float kFP8Margin = 1e-4; // 设置不同的量化方式的最大值及其倒数
constexpr float kFinfoAmaxE4M3 = 240.0f; constexpr float kFinfoAmaxE4M3 = 448.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3; constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
constexpr float kInt8Amax = 127.0f; constexpr float kFinfoAmaxE5M2 = 57344.0f;
constexpr float kFinfoAmaxInvE5M2 = 1.0f / kFinfoAmaxE5M2;
constexpr float kFinfoAmaxInt8 = 127.0f;
constexpr float kFinfoAmaxInvInt8 = 1.0f / 127.0f;
__forceinline__ __device__ float fast_pow2(int x) { __forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127` // We can ensure `-126 <= x and x <= 127`
...@@ -359,22 +362,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) { ...@@ -359,22 +362,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return exp_x - 127 + (man_bits != 0); return exp_x - 127 + (man_bits != 0);
} }
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) { template <int kQuantType>
if (round_scale) { __forceinline__ __device__ void calculate_quant8bit_scales(float amax, float& scale, float& scale_inv, bool round_scale=0) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3); amax = fmaxf(amax, 1e-6f);
scale = fast_pow2(-exp_scale_inv); if constexpr(kQuantType == 1) { // 使用 INT8 对称量化
scale_inv = fast_pow2(exp_scale_inv); scale_inv = kFinfoAmaxInvInt8 * amax;
} else { scale = kFinfoAmaxInt8 / amax;
scale_inv = amax * kFinfoAmaxInvE4M3; } else if constexpr(kQuantType == 2 || kQuantType == 3) { // 使用 FP8_E4M3 或 FP8_UE8M0 非对称量化
scale = kFinfoAmaxE4M3 / amax; if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
}
} else if constexpr(kQuantType == 4) { // 使用 FP8_E5M2 对称量化
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE5M2);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE5M2;
scale = kFinfoAmaxE5M2 / amax;
}
} }
} }
__forceinline__ __device__ void calculate_int8_scales(float amax, float& scale, float& scale_inv) {
scale = kInt8Amax / amax;
scale_inv = amax / kInt8Amax;
}
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>> template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) { __forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
if constexpr (kIsUE8M0) { if constexpr (kIsUE8M0) {
......
...@@ -841,7 +841,7 @@ class Buffer: ...@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int, num_max_dispatch_tokens_per_rank: int, num_experts: int,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, use_int8: bool = False, quant_type: int = 1, quant_group_size: int = 0, fp8_round_scale: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \ async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
""" """
...@@ -858,10 +858,20 @@ class Buffer: ...@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported. only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts. num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. 量化配置
round_scale: whether round the scaling factors into power of 2. quant_type: int 量化类型枚举
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`). 0 -> None 不量化,保持原始精度
use_int8: whether to enable INT8 casting. 1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持 fp8_round_scale=True 且 quant_group_size=128)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
quant_group_size: int 量化分组大小
0 -> 逐token量化 (per-channel)
128-> 每 128 元素一组 (per-group) 量化
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...@@ -869,15 +879,25 @@ class Buffer: ...@@ -869,15 +879,25 @@ class Buffer:
Returns: Returns:
recv_x: a tensor or tuple with received tokens for each expert. recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as - packed_recv_x:
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`. 存储接收到的 Token 数据,形状为
The second tensor is the corresponding scales for the first element with shape `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`, 数据类型取决于 quant_type:
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as quant_type == 1 -> torch.int8
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`. quant_type == 2 -> torch.float8_e4m3fn
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility. quant_type == 3 -> torch.float8_e4m3fn (UE8M0 使用 E4M3 格式存储)
With `use_fp8=False`, the result would be a tensor shaped as quant_type == 4 -> torch.float8_e5m2
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`. 其他 (非量化) -> torch.bfloat16
- packed_recv_x_scales (可选):
仅在 quant_type > 0 时存在,存储量化的 Scale 值。
形状为 `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`。
- 当 quant_type == 3 (UE8M0) 时:
scales_col_size = hidden // 512
数据类型为 torch.int (每个 int32 内部打包存储 4 个 8-bit UE8M0 scale)。
*注意:此模式强制要求 fp8_round_scale=True 且 quant_group_size=128。
- 当 quant_type == 1, 2, 4 时:
scales_col_size = hidden // 128 (若使用 group_size) 或 1 (per-channel)。
数据类型为 torch.float32。
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are, Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced). as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
...@@ -889,14 +909,15 @@ class Buffer: ...@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \ packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx, self.runtime.low_latency_dispatch(x, topk_idx,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0, use_int8, quant_type, quant_group_size, fp8_round_scale,
async_finish, return_recv_hook) async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts) handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx, tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_x, packed_recv_x_scales, packed_recv_count,
packed_recv_src_info, packed_recv_layout_range) packed_recv_src_info, packed_recv_layout_range)
return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
EventOverlap(event, tensors_to_record if async_finish else None), hook recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
......
...@@ -6,7 +6,7 @@ from functools import partial ...@@ -6,7 +6,7 @@ from functools import partial
from typing import Literal, Set from typing import Literal, Set
import deep_ep import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back, per_token_cast_pc_back
def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]): def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
...@@ -58,7 +58,8 @@ def test_main(num_tokens: int, ...@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list = [x] x_list = [x]
# # NOTES: the last one is for performance testing # # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels # # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1) # x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
...@@ -80,22 +81,35 @@ def test_main(num_tokens: int, ...@@ -80,22 +81,35 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for current_x in x_list: for current_x in x_list:
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True): for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
for round_scale in (False, True) if dispatch_use_fp8 else (False,): dispatch_use_quant = quant_type > 0
for use_ue8m0 in (False, True) if round_scale else (False,): for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ):
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue
num_times += 1 num_times += 1
for _ in range((num_times % 2) + 1): for _ in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \ packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, quant_type=quant_type, fp8_round_scale=fp8_round_scale, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_quant else packed_recv_x
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \ if not dispatch_use_quant:
if dispatch_use_fp8 else packed_recv_x.clone() simulated_gemm_x = packed_recv_x.clone()
elif quant_group_size == 0:
simulated_gemm_x = per_token_cast_pc_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].reshape(-1)).view(packed_recv_x[0].shape)
elif quant_group_size == 128:
simulated_gemm_x = per_token_cast_pg_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
for i in range(num_local_experts if do_check else 0): for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] if not dispatch_use_quant:
recv_x = packed_recv_x[i]
elif quant_group_size == 0:
recv_x = per_token_cast_pc_back(packed_recv_x[0][i], packed_recv_x[1][i])
elif quant_group_size == 128:
recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i])
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices # Check expert indices
...@@ -113,18 +127,25 @@ def test_main(num_tokens: int, ...@@ -113,18 +127,25 @@ def test_main(num_tokens: int,
if current_x is x: if current_x is x:
recv_x = recv_x[:num_valid_tokens] recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1) recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens] recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) assert torch.equal(recv_x_amin, recv_x_amax)
if round_scale:
if dispatch_use_quant:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
for j in range(num_ranks): if quant_group_size != 0:
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() if fp8_round_scale:
if not round_scale: assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() else:
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0 assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
if dispatch_use_fp8: for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else: else:
...@@ -147,8 +168,8 @@ def test_main(num_tokens: int, ...@@ -147,8 +168,8 @@ def test_main(num_tokens: int,
if do_check: if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0 assert torch.isnan(combined_x).sum().item() == 0
# if not round_scale: # if not fp8_round_scale:
assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: diff={diff}, dispatch_use_fp8={dispatch_use_fp8}, zero_copy={zero_copy}' assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x) hash_value ^= hash_tensor(combined_x)
# noinspection PyShadowingNames # noinspection PyShadowingNames
...@@ -162,7 +183,8 @@ def test_main(num_tokens: int, ...@@ -162,7 +183,8 @@ def test_main(num_tokens: int,
def test_func(return_recv_hook: bool): def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \ recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook) quant_type=2, quant_group_size=0,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx, topk_idx,
......
import argparse
import random
import os
import torch
import torch.distributed as dist
from functools import partial
from typing import Literal, Set
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back_int8
def test_main(num_tokens: int,
              hidden: int,
              num_experts: int,
              num_topk: int,
              rank: int,
              num_ranks: int,
              group: dist.ProcessGroup,
              buffer: deep_ep.Buffer,
              seed: int = 0):
    # Runs one low-latency dispatch with the int8 flag enabled, verifies what each
    # local expert received, then profiles a dispatch + combine round trip.
    #
    # Parameters (as used below):
    #   num_tokens  - number of rows in the generated input; also forwarded to
    #                 `low_latency_dispatch`.
    #   hidden      - number of columns in the generated input.
    #   num_experts - total expert count; must be divisible by `num_ranks`.
    #   num_topk    - number of expert indices selected per token.
    #   rank        - index of this process; used for seeding, for the fill value of
    #                 `x` and to derive the global ids of the local experts.
    #   num_ranks   - number of participating processes.
    #   group       - process group used for the all-gather and the barriers.
    #   buffer      - object providing `low_latency_dispatch` / `low_latency_combine`.
    #   seed        - base RNG seed; `rank` is added so every rank draws different data.
    #
    # Returns the XOR of `hash_tensor` over the valid received payload and scales.
    #
    # NOTE(review): the indentation of this function was lost in the source this was
    # recovered from. The nesting below (extent of the `for _ in range(1)` body, the
    # position of the final `print`) is reconstructed from the control-flow keywords
    # and variable usage - confirm against the repository copy.
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)
    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks
    # NOTES: the integers greater than 256 exceed the BF16 precision limit
    rank_offset = 128
    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
    # Every element is `rank - rank_offset`, so each row identifies the sending rank...
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
    # ...except the last 128 columns, which hold the row's own token index.
    x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
    x_list = [x]
    # # NOTES: the last one is for performance testing
    # # Most of the values in the perf case are lower than the threshold, casting most channels
    # x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
    # Random routing: positive scores, the `num_topk` best expert indices per token,
    # and non-negative random combine weights.
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
    # Randomly mask some positions
    # (10 random (token, slot) entries are overwritten with -1; repeats are possible.)
    for _ in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
    # Gather every rank's routing table so the expected per-expert token counts can be
    # computed locally.
    all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
    dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
    # For failure simulation and shrink testing
    # (all zeros here, so `mask_status == 0` below selects every rank.)
    mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')
    # Check dispatch correctness
    do_check = True
    hash_value, num_times = 0, 0
    # Each of the following loops iterates over a single value, so exactly one
    # configuration is exercised: no recv hook, quantized dispatch, no scale rounding,
    # no UE8M0, int8 enabled.
    for current_x in x_list:
        for return_recv_hook in (False, ):
            for dispatch_use_fp8 in (True, ):
                for round_scale in (False, ):
                    for use_ue8m0 in (False, ):
                        num_times += 1
                        use_int8 = True
                        for _ in range(1):
                            # `async_finish` is the negation of `return_recv_hook`:
                            # completion is awaited either via the hook or via the event.
                            packed_recv_x, packed_recv_count, handle, event, hook = \
                                buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                                            use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, use_int8=use_int8,
                                                            async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
                            hook() if return_recv_hook else event.current_stream_wait()
                        # In the quantized case the result is indexed as a (data, scales)
                        # pair; the scales are made contiguous before being viewed below.
                        packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
                        # Dequantize everything (one scale per row) and restore the packed
                        # shape; this is later fed to `low_latency_combine` in `test_func`.
                        # NOTE(review): unlike the lines around it, this line does not
                        # branch on `dispatch_use_fp8` - only valid while the loop above
                        # is fixed to `(True, )`.
                        simulated_gemm_x = per_token_cast_back_int8(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, 1)).view(packed_recv_x[0].shape)
                        for i in range(num_local_experts if do_check else 0):
                            # Global id of the i-th expert hosted on this rank.
                            expert_id = rank * num_local_experts + i
                            recv_x = per_token_cast_back_int8(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
                            recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
                            # Check expert indices
                            # The low 32 bits of the `recv_layout_range` entries must sum to
                            # the reported count (presumably per-source-rank token counts -
                            # TODO confirm against the kernel's layout encoding).
                            int_mask = (2 ** 32) - 1
                            num_valid_tokens = recv_count.item()
                            assert num_valid_tokens == (
                                recv_layout_range
                                & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
                            # The count must also equal the number of routing entries, over
                            # all unmasked ranks, that selected this expert.
                            assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(
                            ), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()}'
                            if num_valid_tokens == 0:
                                continue
                            # Check received data
                            if current_x is x:
                                recv_x = recv_x[:num_valid_tokens]
                                # Outside the last 128 columns each sent row was constant, so
                                # the per-row min and max of the received row must coincide.
                                recv_x_amin = recv_x[:, :-128].amin(dim=-1)
                                recv_x_amax = recv_x[:, :-128].amax(dim=-1)
                                recv_src_info = recv_src_info[:num_valid_tokens]
                                assert torch.equal(recv_x_amin, recv_x_amax)
                                # The last column carried the sender-side token index (see the
                                # construction of `x`), so it is compared with `recv_src_info`;
                                # the quantized paths only require an approximate match.
                                if round_scale:
                                    assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
                                elif use_int8:
                                    assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.01
                                else:
                                    assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
                                # for j in range(num_ranks):
                                # if (not round_scale):
                                # check_tmp1 = (recv_x_amin == j - rank_offset).sum().item()
                                # check_tmp2 = (all_topk_idx[j] == expert_id).sum().item()
                                # print(f'rank: {rank}, j: {j}, check_tmp1: {check_tmp1}, check_tmp2: {check_tmp2}, diff: {abs(check_tmp1 - check_tmp2)}')
                                # assert abs(check_tmp1 - check_tmp2) < 3
                                # assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
                            # Fold the valid part of the raw received tensors into the hash.
                            if dispatch_use_fp8:
                                hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                                hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
                            else:
                                hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
    print("dispatch int 8 pass")
    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        # Performs a large matmul whose result is discarded, then calls `hook`.
        # No `device` is passed, so the operands live on torch's default device.
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()
    # noinspection PyShadowingNames
    def test_func(return_recv_hook: bool):
        # One profiled iteration: dispatch followed by combine. It closes over
        # `current_x` and `simulated_gemm_x`, i.e. the values left by the check loop.
        recv_x, recv_count, handle, event, hook = \
            buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
                                        use_fp8=True, round_scale=False, use_ue8m0=False, use_int8=True,
                                        async_finish=False, return_recv_hook=return_recv_hook)
        large_gemm_with_hook(hook) if return_recv_hook else None
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
                                                             topk_idx,
                                                             topk_weights,
                                                             handle,
                                                             return_recv_hook=return_recv_hook)
        large_gemm_with_hook(hook) if return_recv_hook else None
    # Calculate bandwidth
    # Dispatch is counted as `hidden` payload bytes plus `scale_size` 4-byte scales plus
    # 16 extra bytes per token copy (presumably per-message metadata - TODO confirm);
    # combine is counted as 2 bytes per element.
    scale_size = 1 # hidden / 128
    num_fp8_bytes, num_bf16_bytes = (hidden + scale_size * 4 + 16), hidden * 2
    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
    # Only unmasked routing entries (index != -1) contribute traffic.
    for i in range(num_tokens):
        num_selections = (topk_idx[i] != -1).sum().item()
        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
        num_combine_comm_bytes += num_bf16_bytes * num_selections
    # Separate profiling
    for return_recv_hook in (True, False):
        group.barrier()
        # With the hook, two kernels per period are requested and the timings are
        # indexed as [0] / [1] below; without it a single time per kernel is used.
        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
                                             kernel_names=('dispatch', 'combine'),
                                             barrier_comm_profiling=True,
                                             suppress_kineto_output=True,
                                             num_kernels_per_period=2 if return_recv_hook else 1)
        if not return_recv_hook:
            print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                  f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us',
                  flush=True)
        else:
            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
                  f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us',
                  flush=True)
    return hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """Per-process entry point: set up communication, run the test, tear down.

    Args:
        local_rank: Index of this process, forwarded to ``init_dist``; the
            process with index 0 prints the buffer size.
        num_local_ranks: Forwarded to ``init_dist``.
        args: Parsed command-line options; reads ``num_tokens``, ``hidden``,
            ``num_topk``, ``num_experts``, ``disable_nvlink`` and ``allow_mnnvl``.
    """
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)

    # Ask the library how much RDMA memory the low-latency layout needs.
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        args.num_tokens, args.hidden, num_ranks, args.num_experts)
    if local_rank == 0:
        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)

    buffer_options = dict(
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        # One queue pair per expert hosted on a rank.
        num_qps_per_rank=args.num_experts // num_ranks,
        allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
        explicitly_destroy=True,
        allow_mnnvl=args.allow_mnnvl,
    )
    buffer = deep_ep.Buffer(group, **buffer_options)

    test_main(args.num_tokens, args.hidden, args.num_experts, args.num_topk,
              rank, num_ranks, group, buffer, seed=1)

    # Tear down in order: buffer runtime first, then synchronize, then the group.
    buffer.destroy()
    dist.barrier()
    dist.destroy_process_group()
if __name__ == '__main__':
    # Command-line driver: parse the options, then start one worker process per
    # requested rank, each running `test_loop`.
    # TODO: you may modify NUMA binding for less CPU overhead
    # TODO: buggy with `num_tokens=512`
    parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
    parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
    parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)')
    parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')
    parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
    # The help text previously advertised 288 while the actual default is 256.
    parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 256)')
    parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication')
    parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing')
    # NOTE: the three flags below are accepted for command-line compatibility but
    # are not read anywhere in this script.
    parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
    parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
    parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
    args = parser.parse_args()
    num_processes = args.num_processes
    # `spawn` prepends the process index, so each worker runs
    # `test_loop(index, num_processes, args)`.
    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
...@@ -57,28 +57,22 @@ def per_token_cast_to_fp8(x: torch.Tensor): ...@@ -57,28 +57,22 @@ def per_token_cast_to_fp8(x: torch.Tensor):
return (x_padded_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, aligned_n)[:, :n].contiguous(), (x_amax / 448.0).view(m, -1) return (x_padded_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, aligned_n)[:, :n].contiguous(), (x_amax / 448.0).view(m, -1)
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor):
if x_fp8.numel() == 0: if x.numel() == 0:
return x_fp8.to(torch.bfloat16) return x.to(torch.bfloat16)
assert x_fp8.dim() == 2 assert x.dim() == 2
m, n = x_fp8.shape m, n = x.shape
aligned_n = align_up(n, 128) aligned_n = align_up(n, 128)
x_fp8_padded = torch.nn.functional.pad(x_fp8, (0, aligned_n - n), mode='constant', value=0) x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0)
if x_scales.dtype == torch.int: if x_scales.dtype == torch.int:
x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23 x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23
x_scales = x_scales.view(dtype=torch.float) x_scales = x_scales.view(dtype=torch.float)
x_fp32_padded = x_fp8_padded.to(torch.float32).view(x_fp8.size(0), -1, 128) x_fp32_padded = x_padded.to(torch.float32).view(x.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1) x_scales = x_scales.view(x.size(0), -1, 1)
return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:,:n].contiguous() return (x_fp32_padded * x_scales).view(x_padded.shape).to(torch.bfloat16)[:,:n].contiguous()
def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor): def per_token_cast_pc_back(x_int8: torch.Tensor, x_scales: torch.Tensor):
"""
x_int8: [m, n] int8 tensor
x_scales: [m, n] 或 [m, 1] 或 [m, n/128] 量化 scale float
return: [m, n] bf16 tensor
"""
if x_int8.numel() == 0: if x_int8.numel() == 0:
return x_int8.to(torch.bfloat16) return x_int8.to(torch.bfloat16)
...@@ -86,12 +80,9 @@ def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor): ...@@ -86,12 +80,9 @@ def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor):
m, n = x_int8.shape m, n = x_int8.shape
aligned_n = align_up(n, 128) aligned_n = align_up(n, 128)
x_int8_padded = torch.nn.functional.pad( x_int8_padded = torch.nn.functional.pad(x_int8, (0, aligned_n - n), mode='constant', value=0)
x_int8, (0, aligned_n - n), mode='constant', value=0
)
x_fp32_padded = x_int8_padded.to(torch.float32).view(m, -1, 1) x_fp32_padded = x_int8_padded.to(torch.float32).view(m, -1, 1)
x_scales = x_scales.view(m, -1, 1).to(torch.float32) x_scales = x_scales.view(m, -1, 1).to(torch.float32)
# print(f'x_int8.shape: {x_int8.shape}, x_fp32_padded: {x_fp32_padded.shape}, x_scales: {x_scales.shape}')
x_deq = (x_fp32_padded * x_scales).view(m, aligned_n) x_deq = (x_fp32_padded * x_scales).view(m, aligned_n)
return x_deq[:, :n].to(torch.bfloat16).contiguous() return x_deq[:, :n].to(torch.bfloat16).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment