Unverified Commit 90bb2be2 authored by Rex, committed by GitHub
Browse files

Minor improvement to per_tensor_quant_fp8 (#4197)

parent b93ef5e5
...@@ -57,13 +57,9 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output ...@@ -57,13 +57,9 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
template <typename T> template <typename T>
__global__ void per_tensor_quant_fp8_kernel( __global__ void per_tensor_quant_fp8_kernel(
const T* __restrict__ input, const T* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t num_elements) {
FP8_TYPE* __restrict__ output,
const float* __restrict__ scale,
const int64_t num_elements) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x; const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int grid_size = blockDim.x * gridDim.x; const int grid_size = blockDim.x * gridDim.x;
const float scale_val = 1.0f / (*scale);
constexpr uint32_t vec_size = 16 / sizeof(T); constexpr uint32_t vec_size = 16 / sizeof(T);
using vec_t = flashinfer::vec_t<T, vec_size>; using vec_t = flashinfer::vec_t<T, vec_size>;
...@@ -125,12 +121,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch ...@@ -125,12 +121,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>( per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements); static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
} }
float scale_val = 1.0f / (*static_cast<float*>(output_s.data_ptr()));
per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>( per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()), scale_val, num_elements);
static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
num_elements);
return true; return true;
}); });
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment