OpenDAS / TransformerEngine / Commits

Commit a9e430c0
Authored Jul 31, 2025 by yuguo

Merge branch 'develop_v2.5' of http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine

Parents: 9ab6cd98, a397dcb7

Showing 2 changed files with 23 additions and 15 deletions (+23, -15)
tests/pytorch/test_float8_current_scaling_exact.py   +1  -1
tests/pytorch/test_int8_channelwise_gemm_exact.py   +22 -14
tests/pytorch/test_float8_current_scaling_exact.py

```diff
@@ -801,7 +801,7 @@ class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBas
             use_bias,
             seed=torch.initial_seed(),
             dtype=dtype,
-            y_error=0.5,
+            y_error=0.98 if int8_simulation_fp8 else 0.5,
             ln_out_error=0.5,
             dgrad_error=1,
             wgrad_error=1,
```
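The only change here widens the allowed output error when INT8 is used to simulate FP8: y_error goes from 0.5 to 0.98 in that mode. For orientation, a minimal sketch of how a relative-error bound like this is typically asserted; the helper name and formula are hypothetical, not this test suite's actual check:

```python
import torch

def assert_within_error(y: torch.Tensor, y_ref: torch.Tensor, y_error: float) -> None:
    # Hypothetical tolerance check: bound the worst-case relative mismatch
    # between the quantized output and a high-precision reference.
    denom = y_ref.abs().max().clamp_min(1e-12)
    rel_err = (y - y_ref).abs().max() / denom
    assert rel_err <= y_error, f"relative error {rel_err:.4f} > bound {y_error}"
```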
tests/pytorch/test_int8_channelwise_gemm_exact.py

```diff
@@ -190,13 +190,20 @@ use_split_accumulator = False
 # fp8 to int8
-quantizer = Float8CurrentScalingQuantizer(
+quantizer_e5m2 = Float8CurrentScalingQuantizer(
     fp8_dtype=tex.DType.kFloat8E5M2,
     device="cuda",
     force_pow_2_scales=False,
     amax_epsilon=0.0,
 )
+quantizer_e4m3 = Float8CurrentScalingQuantizer(
+    fp8_dtype=tex.DType.kFloat8E4M3,
+    device="cuda",
+    force_pow_2_scales=False,
+    amax_epsilon=0.0,
+)
 # current scaling
 def to_float8_CS(
     tensor: torch.Tensor,
```
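This hunk adds a second quantizer so both FP8 formats are available: E5M2 keeps the wider dynamic range, E4M3 keeps more mantissa precision. The trade-off is easy to confirm in plain PyTorch, independent of TransformerEngine:

```python
import torch

# E4M3 (4 exponent / 3 mantissa bits): narrow range, finer precision.
# E5M2 (5 exponent / 2 mantissa bits): wide range, coarser precision.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
```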
```diff
@@ -206,6 +213,7 @@ def to_float8_CS(
     amax_epsilon: float = 0.0,
 ) -> Float8Tensor:
     """Cast tensor to FP8"""
+    quantizer = quantizer_e5m2 if fp8_dtype == tex.DType.kFloat8E5M2 else quantizer_e4m3
     if return_transpose:
         quantizer.set_usage(rowwise=True, columnwise=True)
     else:
```
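The added line selects between the two module-level quantizers based on the requested dtype, so to_float8_CS no longer assumes E5M2. A standalone sketch of the same dispatch; the import paths follow the upstream TransformerEngine layout and may differ in this fork:

```python
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

quantizer_e5m2 = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E5M2, device="cuda")
quantizer_e4m3 = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")

def pick_quantizer(fp8_dtype):
    # Same selection rule as the line added to to_float8_CS above.
    return quantizer_e5m2 if fp8_dtype == tex.DType.kFloat8E5M2 else quantizer_e4m3
```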
```diff
@@ -235,13 +243,13 @@ end = time.time()
 # print("w_int8: ", w_int8)
 # Cast to FP8 and back
-x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e5m2))
-# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e5m2))
+x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
+w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
+# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
+# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
-x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
-w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
+x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
+w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
 # print("x_int8: ", x_int8)
 # print("w_int8: ", w_int8)
```
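per_token_quant_fp8_to_int8 is a helper defined elsewhere in this test, so its kernel is not visible here. As a rough reference only, per-token (row-wise) INT8 quantization of an FP8 payload can be sketched in plain PyTorch; the name and the symmetric-127 scheme are assumptions, not this file's implementation:

```python
import torch

def per_token_int8_reference(fp8_data: torch.Tensor, scale_inv: torch.Tensor):
    # Dequantize the FP8 payload with its tensor-wide inverse scale...
    x = fp8_data.float() * scale_inv
    # ...then requantize each row (token) symmetrically to int8.
    row_amax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scales = row_amax / 127.0
    x_int8 = torch.clamp(torch.round(x / scales), -127, 127).to(torch.int8)
    return x_int8, scales.squeeze(-1)
```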
```diff
@@ -279,10 +287,10 @@ print("output: ", output)
 torch.cuda.synchronize()
 start = time.time()
 for i in range(20):
-    x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
-    w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
+    x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
+    # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
+    x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
+    w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
     y_int32 = tex.generic_gemm(
         w_int8,
         transa,
```
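The benchmark wraps the loop in torch.cuda.synchronize() calls so asynchronous kernel launches do not distort the wall-clock measurement. The general pattern, with the loop body reduced to a placeholder:

```python
import time
import torch

def bench(run_step, iters: int = 20) -> float:
    # run_step is a placeholder for the quantize + generic_gemm body above.
    torch.cuda.synchronize()        # drain previously queued work
    start = time.time()
    for _ in range(iters):
        run_step()
    torch.cuda.synchronize()        # wait for the timed kernels to finish
    return (time.time() - start) / iters
```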
```diff
@@ -376,7 +384,7 @@ torch.cuda.synchronize()
 start = time.time()
 for i in range(20):
     dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+    # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
     # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
     dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
     w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
```
```diff
@@ -473,8 +481,8 @@ start = time.time()
 for i in range(20):
     # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
     # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
-    dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+    # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+    # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
    # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
```
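Taken together, the hunks move the forward GEMM inputs (x, w) to E4M3 while the backward benchmarks keep dy in E5M2, which matches the usual FP8 training convention of E4M3 for forward tensors and E5M2 for gradients. Upstream TransformerEngine exposes that pairing as a named recipe format; this fork may differ:

```python
from transformer_engine.common import recipe

# HYBRID = E4M3 for forward tensors (activations, weights),
#          E5M2 for backward tensors (gradients).
fp8_format = recipe.Format.HYBRID
print(fp8_format)
```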