Unverified Commit 7a4309cc authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

[sgl-kernel performace] fix fp8 quant kernels dispatch __nv_fp8_e4m3 bug to...


[sgl-kernel performace] fix fp8 quant kernels dispatch __nv_fp8_e4m3 bug to improve performance 10%-20% (#8499)
Co-authored-by: default avatarKe Bao <ispobaoke@gmail.com>
parent 81367066
...@@ -44,10 +44,10 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output ...@@ -44,10 +44,10 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
} }
} }
template <typename T> template <typename T, typename DST_DTYPE>
__global__ void per_tensor_quant_fp8_kernel( __global__ void per_tensor_quant_fp8_kernel(
const T* __restrict__ input, const T* __restrict__ input,
FP8_TYPE* __restrict__ output, DST_DTYPE* __restrict__ output,
const float* __restrict__ scale, const float* __restrict__ scale,
const int64_t num_elements) { const int64_t num_elements) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x; const int gid = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -65,12 +65,12 @@ __global__ void per_tensor_quant_fp8_kernel( ...@@ -65,12 +65,12 @@ __global__ void per_tensor_quant_fp8_kernel(
vec_t input_vec; vec_t input_vec;
input_vec.cast_load(input + i * VEC_SIZE); input_vec.cast_load(input + i * VEC_SIZE);
FP8_TYPE output_arr[VEC_SIZE]; DST_DTYPE output_arr[VEC_SIZE];
#pragma unroll #pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) { for (uint32_t j = 0; j < VEC_SIZE; ++j) {
float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX); float val = fmax(fmin(static_cast<float>(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM #ifndef USE_ROCM
output_arr[j] = static_cast<FP8_TYPE>(val); output_arr[j] = static_cast<DST_DTYPE>(val);
#else #else
output_arr[j] = c10::Float8_e4m3fnuz( output_arr[j] = c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
...@@ -84,7 +84,7 @@ __global__ void per_tensor_quant_fp8_kernel( ...@@ -84,7 +84,7 @@ __global__ void per_tensor_quant_fp8_kernel(
for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) { for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) {
float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX)); float val = fmax(-FP8_E4M3_MAX, fmin(static_cast<float>(input[idx]) * scale_val, FP8_E4M3_MAX));
#ifndef USE_ROCM #ifndef USE_ROCM
output[idx] = static_cast<FP8_TYPE>(val); output[idx] = static_cast<DST_DTYPE>(val);
#else #else
output[idx] = c10::Float8_e4m3fnuz( output[idx] = c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
...@@ -113,9 +113,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch ...@@ -113,9 +113,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements); static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
} }
per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>( per_tensor_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()), static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()), static_cast<float*>(output_s.data_ptr()),
num_elements); num_elements);
return true; return true;
......
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <cmath> #include <cmath>
......
...@@ -12,10 +12,10 @@ static constexpr int kWarpSize = 32; ...@@ -12,10 +12,10 @@ static constexpr int kWarpSize = 32;
// • One warp handles one token. // • One warp handles one token.
// • Eight tokens per 256‑thread CTA. // • Eight tokens per 256‑thread CTA.
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
template <typename T, int kTokensPerCTA = 8, int kVecSize = 16> template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8, int kVecSize = 16>
__global__ void per_token_quant_fp8_kernel( __global__ void per_token_quant_fp8_kernel(
const T* __restrict__ input, const T* __restrict__ input,
FP8_TYPE* __restrict__ output_q, DST_DTYPE* __restrict__ output_q,
float* __restrict__ output_s, float* __restrict__ output_s,
const int64_t hidden_dim, const int64_t hidden_dim,
const int64_t num_tokens) { const int64_t num_tokens) {
...@@ -26,7 +26,7 @@ __global__ void per_token_quant_fp8_kernel( ...@@ -26,7 +26,7 @@ __global__ void per_token_quant_fp8_kernel(
// Global tensors for this token // Global tensors for this token
const T* token_input = input + token_id * hidden_dim; const T* token_input = input + token_id * hidden_dim;
FP8_TYPE* token_output = output_q + token_id * hidden_dim; DST_DTYPE* token_output = output_q + token_id * hidden_dim;
float* token_scale = output_s + token_id; float* token_scale = output_s + token_id;
// //
...@@ -62,14 +62,13 @@ __global__ void per_token_quant_fp8_kernel( ...@@ -62,14 +62,13 @@ __global__ void per_token_quant_fp8_kernel(
for (int i = lane_id; i < num_vec_elems; i += kWarpSize) { for (int i = lane_id; i < num_vec_elems; i += kWarpSize) {
vec_t input_vec; vec_t input_vec;
input_vec.cast_load(token_input + i * kVecSize); input_vec.cast_load(token_input + i * kVecSize);
FP8_TYPE output_arr[kVecSize]; DST_DTYPE output_arr[kVecSize];
#pragma unroll #pragma unroll
for (uint32_t j = 0; j < kVecSize; ++j) { for (uint32_t j = 0; j < kVecSize; ++j) {
float val = static_cast<float>(input_vec[j]) * scale_inv; float val = static_cast<float>(input_vec[j]) * scale_inv;
val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX); val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM #ifndef USE_ROCM
output_arr[j] = static_cast<FP8_TYPE>(val); output_arr[j] = static_cast<DST_DTYPE>(val);
#else #else
output_arr[j] = c10::Float8_e4m3fnuz( output_arr[j] = c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
...@@ -83,10 +82,10 @@ __global__ void per_token_quant_fp8_kernel( ...@@ -83,10 +82,10 @@ __global__ void per_token_quant_fp8_kernel(
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// 2. Baseline kernel (1 token / CTA, CUB block reduce) // 2. Baseline kernel (1 token / CTA, CUB block reduce)
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
template <typename T> template <typename T, typename DST_DTYPE>
__global__ void per_token_quant_fp8_small_batch_kernel( __global__ void per_token_quant_fp8_small_batch_kernel(
const T* __restrict__ input, const T* __restrict__ input,
FP8_TYPE* __restrict__ output_q, DST_DTYPE* __restrict__ output_q,
float* __restrict__ output_s, float* __restrict__ output_s,
const int64_t hidden_dim, const int64_t hidden_dim,
const int64_t num_tokens) { const int64_t num_tokens) {
...@@ -97,7 +96,7 @@ __global__ void per_token_quant_fp8_small_batch_kernel( ...@@ -97,7 +96,7 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
const int block_dim = blockDim.x; const int block_dim = blockDim.x;
const T* token_input = input + token_idx * hidden_dim; const T* token_input = input + token_idx * hidden_dim;
FP8_TYPE* token_output = output_q + token_idx * hidden_dim; DST_DTYPE* token_output = output_q + token_idx * hidden_dim;
float max_value = 0.0f; float max_value = 0.0f;
...@@ -135,12 +134,12 @@ __global__ void per_token_quant_fp8_small_batch_kernel( ...@@ -135,12 +134,12 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
vec_t input_vec; vec_t input_vec;
input_vec.cast_load(token_input + i * VEC_SIZE); input_vec.cast_load(token_input + i * VEC_SIZE);
FP8_TYPE output_arr[VEC_SIZE]; DST_DTYPE output_arr[VEC_SIZE];
#pragma unroll #pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) { for (uint32_t j = 0; j < VEC_SIZE; ++j) {
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX); float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM #ifndef USE_ROCM
output_arr[j] = static_cast<FP8_TYPE>(val); output_arr[j] = static_cast<DST_DTYPE>(val);
#else #else
output_arr[j] = c10::Float8_e4m3fnuz( output_arr[j] = c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
...@@ -173,9 +172,9 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: ...@@ -173,9 +172,9 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
constexpr int THREADS = TOKENS_PER_CTA * kWarpSize; // 256 constexpr int THREADS = TOKENS_PER_CTA * kWarpSize; // 256
dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA); dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
dim3 block(THREADS); dim3 block(THREADS);
per_token_quant_fp8_kernel<scalar_t, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>( per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
static_cast<const scalar_t*>(input.data_ptr()), static_cast<const scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()), static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()), static_cast<float*>(output_s.data_ptr()),
hidden_dim, hidden_dim,
num_tokens); num_tokens);
...@@ -184,9 +183,9 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: ...@@ -184,9 +183,9 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
constexpr int THREADS = 256; constexpr int THREADS = 256;
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(THREADS); dim3 block(THREADS);
per_token_quant_fp8_small_batch_kernel<scalar_t><<<grid, block, 0, stream>>>( per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3><<<grid, block, 0, stream>>>(
static_cast<const scalar_t*>(input.data_ptr()), static_cast<const scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()), static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()), static_cast<float*>(output_s.data_ptr()),
hidden_dim, hidden_dim,
num_tokens); num_tokens);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment