Unverified Commit 5c66c442 authored by fzyzcjy, committed by GitHub

Support new DeepGEMM format in per token group quant (#7146)

parent aa46ed34
@@ -116,7 +116,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max) -> ()");
+      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
   m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
   m.def(
...
@@ -16,11 +16,16 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   return val;
 }
 
-template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false>
+template <
+    typename T,
+    typename DST_DTYPE,
+    bool IS_COLUMN_MAJOR = false,
+    bool SCALE_UE8M0 = false,
+    typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
 __global__ void per_token_group_quant_8bit_kernel(
     const T* __restrict__ input,
     void* __restrict__ output_q,
-    float* __restrict__ output_s,
+    scale_packed_t* __restrict__ output_s,
     const int group_size,
     const int num_groups,
     const int groups_per_block,
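With `SCALE_UE8M0` set, each group scale is a single UE8M0 byte (just a biased power-of-two exponent), and `scale_packed_t` widens to `uint32_t` so four scale bytes travel per 32-bit word. A small illustrative sketch of that packing, not part of the commit and assuming little-endian byte order:

```python
import torch

# Four UE8M0 exponent bytes (biased by 127, see the kernel below) packed
# into one 32-bit word, mirroring scale_packed_t = uint32_t above.
exponents = torch.tensor([127, 130, 125, 128], dtype=torch.uint8)
packed = exponents.view(torch.int32)  # reinterpret 4 bytes as one word
print(hex(packed.item() & 0xFFFFFFFF))  # 0x807d827f on little-endian
```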
@@ -39,15 +44,24 @@ __global__ void per_token_group_quant_8bit_kernel(
   float local_absmax = eps;
 
+  using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
+  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
+
   const T* group_input = input + block_group_offset;
   DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
-  float* scale_output;
+  scale_element_t* scale_output;
 
   if constexpr (IS_COLUMN_MAJOR) {
-    const int row_idx = global_group_id / scale_num_rows;
-    const int col_idx = global_group_id % scale_num_rows;
-    scale_output = output_s + (col_idx * scale_stride + row_idx);
+    const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
+    const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
+    const int row_idx = global_group_id / scale_num_rows_element;
+    const int col_idx_raw = global_group_id % scale_num_rows_element;
+    const int col_idx = col_idx_raw / num_elems_per_pack;
+    const int pack_idx = col_idx_raw % num_elems_per_pack;
+    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
+                   (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
   } else {
+    static_assert(!SCALE_UE8M0);
     scale_output = output_s + global_group_id;
   }
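The column-major branch above addresses a single exponent byte inside the packed words. The same index arithmetic transcribed into plain Python for reference (illustrative only; `num_elems_per_pack` is 4 for the `uint32_t`/`uint8_t` pair):

```python
def packed_scale_offset(global_group_id: int, scale_num_rows: int,
                        scale_stride: int, num_elems_per_pack: int = 4) -> int:
    """Element offset of one group's scale byte, as in the kernel above."""
    scale_num_rows_element = scale_num_rows * num_elems_per_pack
    row_idx = global_group_id // scale_num_rows_element
    col_idx_raw = global_group_id % scale_num_rows_element
    col_idx = col_idx_raw // num_elems_per_pack
    pack_idx = col_idx_raw % num_elems_per_pack
    return (col_idx * scale_stride * num_elems_per_pack
            + row_idx * num_elems_per_pack + pack_idx)
```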
@@ -70,10 +84,21 @@ __global__ void per_token_group_quant_8bit_kernel(
   local_absmax = GroupReduceMax(local_absmax, lane_id);
 
-  const float y_s = local_absmax / max_8bit;
+  float y_s = local_absmax / max_8bit;
+  if constexpr (SCALE_UE8M0) {
+    y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
+  }
+
+  // TODO can optimize
+  scale_element_t y_s_quant;
+  if constexpr (SCALE_UE8M0) {
+    y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127);
+  } else {
+    y_s_quant = y_s;
+  }
 
   if (lane_id == 0) {
-    *scale_output = y_s;
+    *scale_output = y_s_quant;
   }
 
   for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
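In the UE8M0 path the scale is first rounded up to a power of two, then only its biased exponent is written out. A sketch of the encode/decode round trip implied by the kernel math above (illustrative; `fp8_max = 448.0` assumes `float8_e4m3fn`):

```python
import math

def encode_ue8m0(absmax: float, fp8_max: float = 448.0) -> int:
    y_s = absmax / fp8_max
    y_s = 2.0 ** math.ceil(math.log2(max(abs(y_s), 1e-10)))  # round up to a power of two
    return int(math.log2(y_s)) + 127  # keep only the biased exponent

def decode_ue8m0(byte: int) -> float:
    return 2.0 ** (byte - 127)

e = encode_ue8m0(absmax=300.0)           # -> 127, i.e. scale 1.0
assert decode_ue8m0(e) >= 300.0 / 448.0  # rounding up never shrinks the scale
```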
@@ -96,7 +121,8 @@ void sgl_per_token_group_quant_8bit(
     int64_t group_size,
     double eps,
     double min_8bit,
-    double max_8bit) {
+    double max_8bit,
+    bool scale_ue8m0 = false) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
@@ -129,35 +155,51 @@ void sgl_per_token_group_quant_8bit(
   const int scale_num_rows = output_s.size(1);
   const int scale_stride = output_s.stride(1);
 
-#define LAUNCH_KERNEL(T, DST_DTYPE)                                                        \
-  do {                                                                                     \
-    dim3 grid(num_blocks);                                                                 \
-    dim3 block(num_threads);                                                               \
-    if (is_column_major) {                                                                 \
-      per_token_group_quant_8bit_kernel<T, DST_DTYPE, true><<<grid, block, 0, stream>>>(   \
-          static_cast<T*>(input.data_ptr()),                                               \
-          output_q.data_ptr(),                                                             \
-          static_cast<float*>(output_s.data_ptr()),                                        \
-          group_size,                                                                      \
-          num_groups,                                                                      \
-          groups_per_block,                                                                \
-          (float)eps,                                                                      \
-          (float)min_8bit,                                                                 \
-          (float)max_8bit,                                                                 \
-          scale_num_rows,                                                                  \
-          scale_stride);                                                                   \
-    } else {                                                                               \
-      per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>(  \
-          static_cast<T*>(input.data_ptr()),                                               \
-          output_q.data_ptr(),                                                             \
-          static_cast<float*>(output_s.data_ptr()),                                        \
-          group_size,                                                                      \
-          num_groups,                                                                      \
-          groups_per_block,                                                                \
-          (float)eps,                                                                      \
-          (float)min_8bit,                                                                 \
-          (float)max_8bit);                                                                \
-    }                                                                                      \
-  } while (0)
+#define LAUNCH_KERNEL(T, DST_DTYPE)                                                                \
+  do {                                                                                             \
+    dim3 grid(num_blocks);                                                                         \
+    dim3 block(num_threads);                                                                       \
+    if (is_column_major) {                                                                         \
+      if (scale_ue8m0) {                                                                           \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true><<<grid, block, 0, stream>>>(   \
+            static_cast<T*>(input.data_ptr()),                                                     \
+            output_q.data_ptr(),                                                                   \
+            static_cast<uint32_t*>(output_s.data_ptr()),                                           \
+            group_size,                                                                            \
+            num_groups,                                                                            \
+            groups_per_block,                                                                      \
+            (float)eps,                                                                            \
+            (float)min_8bit,                                                                       \
+            (float)max_8bit,                                                                       \
+            scale_num_rows,                                                                        \
+            scale_stride);                                                                         \
+      } else {                                                                                     \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>(  \
+            static_cast<T*>(input.data_ptr()),                                                     \
+            output_q.data_ptr(),                                                                   \
+            static_cast<float*>(output_s.data_ptr()),                                              \
+            group_size,                                                                            \
+            num_groups,                                                                            \
+            groups_per_block,                                                                      \
+            (float)eps,                                                                            \
+            (float)min_8bit,                                                                       \
+            (float)max_8bit,                                                                       \
+            scale_num_rows,                                                                        \
+            scale_stride);                                                                         \
+      }                                                                                            \
+    } else {                                                                                       \
+      assert(!scale_ue8m0);                                                                        \
+      per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>(          \
+          static_cast<T*>(input.data_ptr()),                                                       \
+          output_q.data_ptr(),                                                                     \
+          static_cast<float*>(output_s.data_ptr()),                                                \
+          group_size,                                                                              \
+          num_groups,                                                                              \
+          groups_per_block,                                                                        \
+          (float)eps,                                                                              \
+          (float)min_8bit,                                                                         \
+          (float)max_8bit);                                                                        \
+    }                                                                                              \
+  } while (0)
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
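`LAUNCH_KERNEL` now selects the template instantiation from the two runtime flags, and UE8M0 is only reachable through the column-major layout; the row-major branch asserts the flag is off. A compact sketch of the resulting dispatch matrix:

```python
# (is_column_major, scale_ue8m0) -> kernel template arguments (sketch).
DISPATCH = {
    (True, True): "<T, DST_DTYPE, true, true>",    # packed uint32 UE8M0 scales
    (True, False): "<T, DST_DTYPE, true, false>",  # float scales, column-major
    (False, False): "<T, DST_DTYPE, false>",       # float scales, row-major
    # (False, True) is unsupported: assert(!scale_ue8m0) fires.
}
```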
@@ -192,6 +234,7 @@ void sgl_per_token_group_quant_fp8(
     int64_t group_size,
     double eps,
     double fp8_min,
-    double fp8_max) {
-  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max);
+    double fp8_max,
+    bool scale_ue8m0) {
+  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0);
 }
@@ -175,7 +175,8 @@ void sgl_per_token_group_quant_fp8(
     int64_t group_size,
     double eps,
     double fp8_min,
-    double fp8_max);
+    double fp8_max,
+    bool scale_ue8m0);
 
 void sgl_per_token_group_quant_int8(
     at::Tensor input,
     at::Tensor output_q,
...
@@ -90,9 +90,10 @@ def sgl_per_token_group_quant_fp8(
     eps: float,
     fp8_min: float,
     fp8_max: float,
+    scale_ue8m0: bool,
 ) -> None:
     torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
-        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
+        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
     )
...
@@ -255,7 +255,10 @@ def sglang_per_token_group_quant_8bit(
     f8_info = torch.finfo(dtype)
     fp8_max = f8_info.max
     fp8_min = f8_info.min
-    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+    scale_ue8m0 = False  # TODO also test true
+    sgl_per_token_group_quant_fp8(
+        x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+    )
     return x_q, x_s
...
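The benchmark keeps `scale_ue8m0 = False` for now (see the TODO). One way the `True` path could eventually be checked is against a reference that reproduces the kernel's exponent math; the sketch below is hypothetical (`ref_ue8m0_scale` is not in the tree, and `fp8_max = 448.0` assumes `float8_e4m3fn`):

```python
import torch

def ref_ue8m0_scale(x: torch.Tensor, group_size: int,
                    eps: float = 1e-10, fp8_max: float = 448.0) -> torch.Tensor:
    """Reference UE8M0 scales: one biased-exponent byte per group (hypothetical)."""
    absmax = x.float().abs().reshape(-1, group_size).amax(dim=1).clamp_min(eps)
    exp = torch.ceil(torch.log2((absmax / fp8_max).clamp_min(1e-10)))
    return (exp.to(torch.int32) + 127).to(torch.uint8)

# Comparing these bytes with output_s viewed as uint8 would exercise the new path.
```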