Unverified commit 1b0a1555, authored by Wentao Ye, committed by GitHub
Browse files

[Perf] Using `__nv_fp8_e4m3` instead of `c10::e4m3` for `per_token_group_quant` (#21867)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 44bc46da
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include "../per_token_group_quant_8bit.h"
#include <cmath>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/all.h>
......@@ -199,7 +197,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
if (dst_type == at::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
} else if (dst_type == at::ScalarType::Char) {
LAUNCH_KERNEL(scalar_t, int8_t);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment