Commit d86ee4c8 authored by yuguo
Browse files

[DCU] fix quantize bug

parent 546bb548
......@@ -201,7 +201,7 @@ __launch_bounds__(unary_kernel_threads) __global__
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
}
-if constexpr (is_fp8<OutputType>::value) {
+if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
temp = temp * s;
}
if constexpr (is_int8<OutputType>::value) {
......@@ -222,7 +222,7 @@ __launch_bounds__(unary_kernel_threads) __global__
}
}
-if constexpr (is_fp8<OutputType>::value) {
+if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
......@@ -262,7 +262,7 @@ __launch_bounds__(unary_kernel_threads) __global__
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
}
-if constexpr (is_fp8<OutputType>::value) {
+if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
temp = temp * s;
}
if constexpr (is_int8<OutputType>::value) {
......@@ -283,7 +283,7 @@ __launch_bounds__(unary_kernel_threads) __global__
}
}
-if constexpr (is_fp8<OutputType>::value) {
+if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment