Unverified Commit 90bb2be2, authored by Rex, committed by GitHub

Minor improvement to per_tensor_quant_fp8 (#4197)

parent b93ef5e5
```diff
@@ -57,13 +57,9 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
 template <typename T>
 __global__ void per_tensor_quant_fp8_kernel(
-    const T* __restrict__ input,
-    FP8_TYPE* __restrict__ output,
-    const float* __restrict__ scale,
-    const int64_t num_elements) {
+    const T* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t num_elements) {
   const int gid = blockIdx.x * blockDim.x + threadIdx.x;
   const int grid_size = blockDim.x * gridDim.x;
-  const float scale_val = 1.0f / (*scale);
   constexpr uint32_t vec_size = 16 / sizeof(T);
   using vec_t = flashinfer::vec_t<T, vec_size>;
@@ -125,12 +121,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
       per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
           static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
     }
+    float scale_val = 1.0f / (*static_cast<float*>(output_s.data_ptr()));
     per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        static_cast<scalar_t*>(input.data_ptr()),
-        static_cast<FP8_TYPE*>(output_q.data_ptr()),
-        static_cast<float*>(output_s.data_ptr()),
-        num_elements);
+        static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()), scale_val, num_elements);
     return true;
   });
 }
```
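The change hoists the scale inversion out of the kernel: instead of every thread dereferencing a device pointer and computing `1.0f / (*scale)`, the launcher computes the reciprocal once and passes it to the kernel by value. A minimal sketch of that pattern, not the repository's code: `quantize_by_value_kernel` and the host setup below are hypothetical, `FP8_TYPE` stands in for the e4m3 type used upstream, and the vectorized `flashinfer::vec_t` load path is omitted for brevity.

```cuda
// Sketch: pass the precomputed reciprocal scale by value instead of a
// device pointer, so each thread does one multiply and no division.
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cstdint>

using FP8_TYPE = __nv_fp8_e4m3;

__global__ void quantize_by_value_kernel(
    const float* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t n) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = blockDim.x * gridDim.x;
  // Grid-stride loop; scale_val already holds 1.0f / scale.
  for (int64_t i = gid; i < n; i += stride) {
    output[i] = static_cast<FP8_TYPE>(input[i] * scale_val);
  }
}

int main() {
  const int64_t n = 1024;
  float* d_in;
  FP8_TYPE* d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(FP8_TYPE));
  // Suppose an absmax pass produced scale = 2.0f; invert it once on the host.
  const float scale = 2.0f;
  const float scale_val = 1.0f / scale;
  quantize_by_value_kernel<<<(n + 255) / 256, 256>>>(d_in, d_out, scale_val, n);
  cudaDeviceSynchronize();
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

One caveat worth noting about the launcher change: the host-side dereference `*static_cast<float*>(output_s.data_ptr())` is only valid if `output_s` is host-accessible (a CPU tensor, or pinned/managed memory); for an ordinary device tensor the scale would first need to be copied back, e.g. with `cudaMemcpy`.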