Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
d86ee4c8
Commit
d86ee4c8
authored
Aug 28, 2025
by
yuguo
Browse files
[DCU] fix quantize bug
parent
546bb548
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
transformer_engine/common/util/vectorized_pointwise.h
transformer_engine/common/util/vectorized_pointwise.h
+4
-4
No files found.
transformer_engine/common/util/vectorized_pointwise.h
View file @
d86ee4c8
...
...
@@ -201,7 +201,7 @@ __launch_bounds__(unary_kernel_threads) __global__
__builtin_assume
(
max
>=
0
);
max
=
fmaxf
(
fabsf
(
temp
),
max
);
}
if
constexpr
(
is_fp8
<
OutputType
>::
value
)
{
if
constexpr
(
is_fp8
<
OutputType
>::
value
||
is_int8
<
OutputType
>::
value
)
{
temp
=
temp
*
s
;
}
if
constexpr
(
is_int8
<
OutputType
>::
value
)
{
...
...
@@ -222,7 +222,7 @@ __launch_bounds__(unary_kernel_threads) __global__
}
}
if
constexpr
(
is_fp8
<
OutputType
>::
value
)
{
if
constexpr
(
is_fp8
<
OutputType
>::
value
||
is_int8
<
OutputType
>::
value
)
{
// Update scale-inverse
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
&&
scale_inv
!=
nullptr
)
{
reciprocal
<
ComputeType
>
(
scale_inv
,
s
);
...
...
@@ -262,7 +262,7 @@ __launch_bounds__(unary_kernel_threads) __global__
__builtin_assume
(
max
>=
0
);
max
=
fmaxf
(
fabsf
(
temp
),
max
);
}
if
constexpr
(
is_fp8
<
OutputType
>::
value
)
{
if
constexpr
(
is_fp8
<
OutputType
>::
value
||
is_int8
<
OutputType
>::
value
)
{
temp
=
temp
*
s
;
}
if
constexpr
(
is_int8
<
OutputType
>::
value
)
{
...
...
@@ -283,7 +283,7 @@ __launch_bounds__(unary_kernel_threads) __global__
}
}
if
constexpr
(
is_fp8
<
OutputType
>::
value
)
{
if
constexpr
(
is_fp8
<
OutputType
>::
value
||
is_int8
<
OutputType
>::
value
)
{
// Update scale-inverse
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
&&
scale_inv
!=
nullptr
)
{
reciprocal
<
ComputeType
>
(
scale_inv
,
s
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment