Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
8a03ff34
Commit
8a03ff34
authored
Jun 18, 2025
by
wenjh
Browse files
Fix vector blockwise acc problem
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
d1bf39cf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
5 deletions
+20
-5
tests/pytorch/references/blockwise_quantizer_reference.py
tests/pytorch/references/blockwise_quantizer_reference.py
+5
-1
tests/pytorch/references/quantize_scale_calc.py
tests/pytorch/references/quantize_scale_calc.py
+6
-2
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
...e/common/transpose/quantize_transpose_square_blockwise.cu
+9
-2
No files found.
tests/pytorch/references/blockwise_quantizer_reference.py
View file @
8a03ff34
...
@@ -171,7 +171,11 @@ class BlockwiseQuantizerReference:
...
@@ -171,7 +171,11 @@ class BlockwiseQuantizerReference:
qx
=
x_tiled
*
scale
.
reshape
(
M
,
K
//
tile_len
,
1
)
qx
=
x_tiled
*
scale
.
reshape
(
M
,
K
//
tile_len
,
1
)
qx
=
torch
.
clamp
(
qx
,
min
=-
dtype_max
,
max
=
dtype_max
)
qx
=
torch
.
clamp
(
qx
,
min
=-
dtype_max
,
max
=
dtype_max
)
if
quant_dtype
==
torch
.
int8
:
if
quant_dtype
==
torch
.
int8
:
qx
=
torch
.
round
(
qx
)
positive_mask
=
qx
>=
0
negative_mask
=
~
positive_mask
pos_part
=
torch
.
where
(
positive_mask
,
torch
.
floor
(
qx
+
0.5
),
0
)
neg_part
=
torch
.
where
(
negative_mask
,
torch
.
ceil
(
qx
-
0.5
),
0
)
qx
=
pos_part
+
neg_part
qx
=
qx
.
to
(
dtype
=
quant_dtype
)
qx
=
qx
.
to
(
dtype
=
quant_dtype
)
qx
=
qx
.
reshape
(
M
,
K
)
qx
=
qx
.
reshape
(
M
,
K
)
return
qx
,
scale_inv
return
qx
,
scale_inv
...
...
tests/pytorch/references/quantize_scale_calc.py
View file @
8a03ff34
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
from
typing
import
Tuple
from
typing
import
Tuple
import
torch
import
torch
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
def
scale_from_amax_tensor
(
def
scale_from_amax_tensor
(
x_dtype
:
torch
.
dtype
,
x_dtype
:
torch
.
dtype
,
...
@@ -48,7 +48,11 @@ def scale_from_amax_tensor(
...
@@ -48,7 +48,11 @@ def scale_from_amax_tensor(
# No subnormals and zero.
# No subnormals and zero.
assert
(
exp
>
-
127
).
all
()
assert
(
exp
>
-
127
).
all
()
unity
=
torch
.
tensor
([
1.0
],
device
=
exp
.
device
)
unity
=
torch
.
tensor
([
1.0
],
device
=
exp
.
device
)
torch
.
ldexp
(
unity
,
exp
,
out
=
scale
)
if
IS_HIP_EXTENSION
:
host_scale
=
torch
.
ldexp
(
unity
.
cpu
(),
exp
.
cpu
())
scale
=
host_scale
.
to
(
exp
.
device
)
else
:
torch
.
ldexp
(
unity
,
exp
,
out
=
scale
)
# Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
# Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
# Return 0.0 for 0.0 scale for consistency with non-pow2 scale
# Return 0.0 for 0.0 scale for consistency with non-pow2 scale
# calculation.
# calculation.
...
...
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
View file @
8a03ff34
...
@@ -187,8 +187,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
...
@@ -187,8 +187,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
// Step 3: Store cast output
// Step 3: Store cast output
CType
scale_data
=
block_tile_scale
;
CType
scale_data
=
block_tile_scale
;
OType
scaled_elt
=
OType
scaled_elt
=
0
;
static_cast
<
OType
>
(
static_cast
<
CType
>
(
thrd_tile_input
[
i
].
data
.
elt
[
j
])
*
scale_data
);
if
constexpr
(
std
::
is_same_v
<
OType
,
int8_t
>
)
{
scaled_elt
=
static_cast
<
OType
>
(
lroundf
(
fmaxf
(
-
127.0
f
,
fminf
(
127.0
f
,
static_cast
<
CType
>
(
thrd_tile_input
[
i
].
data
.
elt
[
j
])
*
scale_data
))));
}
else
{
scaled_elt
=
static_cast
<
OType
>
(
static_cast
<
CType
>
(
thrd_tile_input
[
i
].
data
.
elt
[
j
])
*
scale_data
);
}
tmp_output_c
.
data
.
elt
[
j
]
=
scaled_elt
;
tmp_output_c
.
data
.
elt
[
j
]
=
scaled_elt
;
// Step 4: do transpose within thread tile
// Step 4: do transpose within thread tile
if
constexpr
(
kReturnTranspose
)
{
if
constexpr
(
kReturnTranspose
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment