"vscode:/vscode.git/clone" did not exist on "4aa68291a9671491521733da647cb7dd2cabb236"
Unverified Commit fb4ce17d authored by strgrb's avatar strgrb Committed by GitHub
Browse files

Fix per_token_group_quant_8bit when hidden_dim // group_size is not divided by 4. (#8449)


Co-authored-by: default avatarZhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
parent 25f73c6c
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cuda_fp8.h>
#include <cmath>
#include <flashinfer/vec_dtypes.cuh>
......@@ -32,7 +33,7 @@ __global__ void per_token_group_quant_8bit_kernel(
const float eps,
const float min_8bit,
const float max_8bit,
const int scale_num_rows = 0,
const int num_groups_per_row = 0,
const int scale_stride = 0) {
const int threads_per_group = 16;
const int64_t local_group_id = threadIdx.x / threads_per_group;
......@@ -53,11 +54,10 @@ __global__ void per_token_group_quant_8bit_kernel(
if constexpr (IS_COLUMN_MAJOR) {
const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
const int row_idx = global_group_id / scale_num_rows_element;
const int col_idx_raw = global_group_id % scale_num_rows_element;
const int col_idx = col_idx_raw / num_elems_per_pack;
const int pack_idx = col_idx_raw % num_elems_per_pack;
const int row_idx = global_group_id / num_groups_per_row;
const int col_idx_unpacked = global_group_id % num_groups_per_row;
const int col_idx = col_idx_unpacked / num_elems_per_pack;
const int pack_idx = col_idx_unpacked % num_elems_per_pack;
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
(col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
} else {
......@@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel(
float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) {
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f))));
}
// TODO can optimize
......@@ -152,7 +152,8 @@ void sgl_per_token_group_quant_8bit(
const int num_threads = groups_per_block * THREADS_PER_GROUP;
const bool is_column_major = output_s.stride(0) < output_s.stride(1);
const int scale_num_rows = output_s.size(1);
const int hidden_dim = input.size(input.dim() - 1);
const int num_groups_per_row = hidden_dim / group_size;
const int scale_stride = output_s.stride(1);
#define LAUNCH_KERNEL(T, DST_DTYPE) \
......@@ -171,7 +172,7 @@ void sgl_per_token_group_quant_8bit(
(float)eps, \
(float)min_8bit, \
(float)max_8bit, \
scale_num_rows, \
num_groups_per_row, \
scale_stride); \
} else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \
......@@ -184,7 +185,7 @@ void sgl_per_token_group_quant_8bit(
(float)eps, \
(float)min_8bit, \
(float)max_8bit, \
scale_num_rows, \
num_groups_per_row, \
scale_stride); \
} \
} else { \
......@@ -207,7 +208,7 @@ void sgl_per_token_group_quant_8bit(
LAUNCH_KERNEL(scalar_t, int8_t);
return true;
} else if (dst_type == at::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
return true;
}
return false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment