Unverified Commit fb4ce17d authored by strgrb's avatar strgrb Committed by GitHub
Browse files

Fix per_token_group_quant_8bit when hidden_dim // group_size is not divided by 4. (#8449)


Co-authored-by: default avatarZhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
parent 25f73c6c
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h> #include <c10/util/Float8_e4m3fn.h>
#include <cuda_fp8.h>
#include <cmath> #include <cmath>
#include <flashinfer/vec_dtypes.cuh> #include <flashinfer/vec_dtypes.cuh>
...@@ -32,7 +33,7 @@ __global__ void per_token_group_quant_8bit_kernel( ...@@ -32,7 +33,7 @@ __global__ void per_token_group_quant_8bit_kernel(
const float eps, const float eps,
const float min_8bit, const float min_8bit,
const float max_8bit, const float max_8bit,
const int scale_num_rows = 0, const int num_groups_per_row = 0,
const int scale_stride = 0) { const int scale_stride = 0) {
const int threads_per_group = 16; const int threads_per_group = 16;
const int64_t local_group_id = threadIdx.x / threads_per_group; const int64_t local_group_id = threadIdx.x / threads_per_group;
...@@ -53,11 +54,10 @@ __global__ void per_token_group_quant_8bit_kernel( ...@@ -53,11 +54,10 @@ __global__ void per_token_group_quant_8bit_kernel(
if constexpr (IS_COLUMN_MAJOR) { if constexpr (IS_COLUMN_MAJOR) {
const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t)); const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
const int scale_num_rows_element = scale_num_rows * num_elems_per_pack; const int row_idx = global_group_id / num_groups_per_row;
const int row_idx = global_group_id / scale_num_rows_element; const int col_idx_unpacked = global_group_id % num_groups_per_row;
const int col_idx_raw = global_group_id % scale_num_rows_element; const int col_idx = col_idx_unpacked / num_elems_per_pack;
const int col_idx = col_idx_raw / num_elems_per_pack; const int pack_idx = col_idx_unpacked % num_elems_per_pack;
const int pack_idx = col_idx_raw % num_elems_per_pack;
scale_output = reinterpret_cast<scale_element_t*>(output_s) + scale_output = reinterpret_cast<scale_element_t*>(output_s) +
(col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx); (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
} else { } else {
...@@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel( ...@@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel(
float y_s = local_absmax / max_8bit; float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) { if constexpr (SCALE_UE8M0) {
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f))));
} }
// TODO can optimize // TODO can optimize
...@@ -152,7 +152,8 @@ void sgl_per_token_group_quant_8bit( ...@@ -152,7 +152,8 @@ void sgl_per_token_group_quant_8bit(
const int num_threads = groups_per_block * THREADS_PER_GROUP; const int num_threads = groups_per_block * THREADS_PER_GROUP;
const bool is_column_major = output_s.stride(0) < output_s.stride(1); const bool is_column_major = output_s.stride(0) < output_s.stride(1);
const int scale_num_rows = output_s.size(1); const int hidden_dim = input.size(input.dim() - 1);
const int num_groups_per_row = hidden_dim / group_size;
const int scale_stride = output_s.stride(1); const int scale_stride = output_s.stride(1);
#define LAUNCH_KERNEL(T, DST_DTYPE) \ #define LAUNCH_KERNEL(T, DST_DTYPE) \
...@@ -171,7 +172,7 @@ void sgl_per_token_group_quant_8bit( ...@@ -171,7 +172,7 @@ void sgl_per_token_group_quant_8bit(
(float)eps, \ (float)eps, \
(float)min_8bit, \ (float)min_8bit, \
(float)max_8bit, \ (float)max_8bit, \
scale_num_rows, \ num_groups_per_row, \
scale_stride); \ scale_stride); \
} else { \ } else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \ per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \
...@@ -184,7 +185,7 @@ void sgl_per_token_group_quant_8bit( ...@@ -184,7 +185,7 @@ void sgl_per_token_group_quant_8bit(
(float)eps, \ (float)eps, \
(float)min_8bit, \ (float)min_8bit, \
(float)max_8bit, \ (float)max_8bit, \
scale_num_rows, \ num_groups_per_row, \
scale_stride); \ scale_stride); \
} \ } \
} else { \ } else { \
...@@ -207,7 +208,7 @@ void sgl_per_token_group_quant_8bit( ...@@ -207,7 +208,7 @@ void sgl_per_token_group_quant_8bit(
LAUNCH_KERNEL(scalar_t, int8_t); LAUNCH_KERNEL(scalar_t, int8_t);
return true; return true;
} else if (dst_type == at::ScalarType::Float8_e4m3fn) { } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn); LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
return true; return true;
} }
return false; return false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment