Commit 1446ae62 authored by yuguo's avatar yuguo
Browse files

Merge branch 'develop_v2.7' of...

Merge branch 'develop_v2.7' of http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine into release_v2.7
parents 65e6a5e0 d86ee4c8
......@@ -201,7 +201,7 @@ __launch_bounds__(unary_kernel_threads) __global__
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
}
if constexpr (is_fp8<OutputType>::value) {
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
temp = temp * s;
}
if constexpr (is_int8<OutputType>::value) {
......@@ -222,7 +222,7 @@ __launch_bounds__(unary_kernel_threads) __global__
}
}
if constexpr (is_fp8<OutputType>::value) {
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
......@@ -262,7 +262,7 @@ __launch_bounds__(unary_kernel_threads) __global__
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
}
if constexpr (is_fp8<OutputType>::value) {
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
temp = temp * s;
}
if constexpr (is_int8<OutputType>::value) {
......@@ -283,7 +283,7 @@ __launch_bounds__(unary_kernel_threads) __global__
}
}
if constexpr (is_fp8<OutputType>::value) {
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment