OpenDAS / TransformerEngine · Commits

Commit e4f5325e — authored Feb 24, 2026 by wenjh
Merge branch 'develop_v2.10' into release_v2.10
Parents: eebc98fc, a68e5f87
Showing 7 changed files with 225 additions and 854 deletions:

- qa/L0_pytorch_unittest/test.sh (+5, −3)
- tests/pytorch/distributed/test_numerics.py (+16, −1)
- tests/pytorch/test_float8_blockwise_gemm_exact.py (+3, −3)
- tests/pytorch/test_float8_blockwise_scaling_exact.py (+116, −1)
- tests/pytorch/test_int8_channelwise_gemm_exact.py (+0, −796, deleted)
- transformer_engine/pytorch/quantization.py (+21, −24)
- transformer_engine/pytorch/utils.py (+64, −26)
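The thread running through all seven files is an INT8-simulated-FP8 path for ROCm/DCU parts, gated by two environment toggles that recur throughout the diff. As a minimal sketch of the gating idiom (variable names as in quantization.py below; the print branches are illustrative only, not code from this commit):

```python
import os

# Both toggles are read at import time with the same bool(int(...)) idiom
# used in quantization.py below; "0" disables, "1" enables.
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))

if int8_simulation_fp8 and int8_simulation_fp8_tensorwise:
    print("simulate FP8 with INT8, tensorwise scales")
elif int8_simulation_fp8:
    print("simulate FP8 with INT8, channelwise scales")
else:
    print("native FP8 path")
```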
qa/L0_pytorch_unittest/test.sh

```diff
@@ -36,10 +36,12 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PA
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
-NVTE_INT8_SIM_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
+NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py_int8"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
 # channelwise int8 test
-NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py
-NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py
+python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py"
+NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py_int8"
+NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8_tensorwise.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py_int8_tensorwise"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
```
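Note that the new int8 runs write distinct JUnit files (`..._int8.xml`, `..._int8_tensorwise.xml`), so they no longer overwrite the report of the plain run, and every invocation now carries a `|| test_fail` guard so CI actually records failures. To reproduce one of the gated runs locally, something like `NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s tests/pytorch/test_float8_blockwise_gemm_exact.py` should match what the script does (the path assumes you run from `$TE_PATH`).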
tests/pytorch/distributed/test_numerics.py

```diff
@@ -51,11 +51,26 @@ def _run_test(quantization):
 all_boolean = [True, False]
 
 @pytest.mark.parametrize(
     "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
 )
 def test_distributed(quantization):
+    if quantization == "fp8" and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if quantization == "fp8_cs" and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if quantization == "mxfp8" and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if quantization == "nvfp4" and not nvfp4_available:
+        pytest.skip(reason_for_no_nvfp4)
+    _run_test(quantization)
+
+
+@pytest.mark.parametrize(
+    "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
+)
+def test_int8_distributed(quantization):
     if quantization == "fp8" and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if quantization == "fp8_cs" and not fp8_available:
```
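For readers unfamiliar with the `*_available` flags used in the skips: they follow the `(supported, reason)` convention of the `check_*` helpers changed in quantization.py later in this commit. A sketch of how such flags are typically initialized at module top (the exact import path in this fork is an assumption):

```python
# Hypothetical module-level setup mirroring the (bool, str) convention of
# check_fp8_support() in transformer_engine/pytorch/quantization.py:
from transformer_engine.pytorch.quantization import (
    check_fp8_support,
    check_mxfp8_support,
)

fp8_available, reason_for_no_fp8 = check_fp8_support()
mxfp8_available, reason_for_no_mxfp8 = check_mxfp8_support()
```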
tests/pytorch/test_float8_blockwise_gemm_exact.py

```diff
@@ -47,7 +47,7 @@ def cublas_gemm_fp8_blockwise_case(
     atol: float = 0.0,
     rtol: float = 0.0
 ):
-    if IS_HIP_EXTENSION and int8_simulation_fp8:
+    if IS_HIP_EXTENSION:
         if use_bias or use_gelu:
             pytest.skip("Bias and GELU not supported in int8 simulation mode on ROCm.")
     if not ((not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled)
             or (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled)
             or (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled)):
@@ -249,7 +249,7 @@ def cublas_gemm_test_constraint_enforced(
     expected_err_cls=RuntimeError
 ):
     if IS_HIP_EXTENSION:
-        pytest.skip("ROCm does not support cuBLAS GEMM. No need to test constraint enforcement.")
+        pytest.skip("ROCm does not support cuBLAS blockwise FP8 gemm. No need to test constraint enforcement.")
     if not fp8_blockwise_gemm_supported():
        pytest.skip("CUDA version does not support blockwise FP8 gemm.")
     # Setup device and random seed
```
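The three-way condition in the first hunk is easier to audit as an explicit allow-list. A sketch that restates it (flag names as in the test; the tuple ordering is an editorial choice):

```python
# Allowed (x_columnwise, w_columnwise, is_x_1d_scaled, is_w_1d_scaled) combos,
# transcribed from the guard above; every other combination is skipped.
ALLOWED = {
    (False, False, True, False),
    (False, True, True, False),
    (True, True, True, True),
}

def case_is_supported(x_columnwise, w_columnwise, is_x_1d_scaled, is_w_1d_scaled):
    return (x_columnwise, w_columnwise, is_x_1d_scaled, is_w_1d_scaled) in ALLOWED
```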
tests/pytorch/test_float8_blockwise_scaling_exact.py

```diff
@@ -9,7 +9,7 @@ import pathlib
 import pytest
 import torch
 import transformer_engine.pytorch as te
-from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
+from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
 from transformer_engine.common.recipe import Float8BlockScaling
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch import (
@@ -507,6 +507,9 @@ def test_quantization_block_tiling_extrema_versus_reference(
         rtol=0.0,
     )
 
+def fp8_blockwise_scaling_supported() -> bool:
+    supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+    return supported
 
 # FP8 per tesnor current scaling
 @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@@ -541,12 +544,65 @@ class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
         out_size,
         dtype,
         use_bias=True,
     ):
+        if not fp8_blockwise_scaling_supported():
+            pytest.skip("CUDA version does not support blockwise FP8.")
+        fp8_zero_tolerance_tensor_dumps_recipe2 = None
+        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
+        # if we cannot get all four tensors, then still set the tensor dump to None
+        tensor_map = self._check_golden_tensor_dumps(
+            TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
+        )
+        if tensor_map is not None:
+            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
+        self.compare_recipe(
+            recipe1,
+            recipe2,
+            batch_size,
+            hidden_size,
+            out_size,
+            use_bias,
+            seed=torch.initial_seed(),
+            dtype=dtype,
+            y_error=0.5,
+            dgrad_error=1,
+            wgrad_error=1,
+            bgrad_error=0.5,
+            recipe1_golden_tensors=None,
+            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
+        )
+
+    @pytest.mark.parametrize(
+        "batch_size, hidden_size, out_size",
+        [
+            (16, 256, 128),
+        ],
+    )
+    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
+    @pytest.mark.parametrize(
+        "recipe1, recipe2",
+        [
+            (GetRecipes.none, GetRecipes.fp8_blockwise),
+        ],
+    )
+    def test_int8_current_scaling_with_linear_module(
+        self,
+        recipe1,
+        recipe2,
+        batch_size,
+        hidden_size,
+        out_size,
+        dtype,
+        use_bias=True,
+    ):
         if IS_HIP_EXTENSION:
             import importlib
             ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
             os.environ["NVTE_INT8_SIM_FP8"] = "1"
             importlib.reload(te.pytorch.fp8)
+        if not fp8_blockwise_scaling_supported():
+            pytest.skip("CUDA version does not support blockwise FP8.")
         fp8_zero_tolerance_tensor_dumps_recipe2 = None
         # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
         # if we cannot get all four tensors, then still set the tensor dump to None
@@ -612,12 +668,71 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
         out_size,
         dtype,
         use_bias=True,
     ):
+        if not fp8_blockwise_scaling_supported():
+            pytest.skip("CUDA version does not support blockwise FP8.")
+        fp8_zero_tolerance_tensor_dumps_recipe2 = None
+        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
+        # if we cannot get all four tensors, then still set the tensor dump to None
+        tensor_map = self._check_golden_tensor_dumps(
+            TENSOR_DUMP_DIR,
+            recipe2,
+            (batch_size, hidden_size, out_size),
+            dtype,
+            use_bias,
+            "LayerNorm",
+        )
+        if tensor_map is not None:
+            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
+        self.compare_recipe(
+            recipe1,
+            recipe2,
+            batch_size,
+            hidden_size,
+            out_size,
+            use_bias,
+            seed=torch.initial_seed(),
+            dtype=dtype,
+            y_error=0.5,
+            ln_out_error=0.5,
+            dgrad_error=1.6,
+            wgrad_error=1,
+            bgrad_error=0.5,
+            recipe1_golden_tensors=None,
+            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
+        )
+
+    @pytest.mark.parametrize(
+        "batch_size, hidden_size, out_size",
+        [
+            (16, 256, 128),
+        ],
+    )
+    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
+    @pytest.mark.parametrize(
+        "recipe1, recipe2",
+        [
+            (GetRecipes.none, GetRecipes.fp8_blockwise),
+        ],
+    )
+    def test_int8_current_scaling_with_layernorm_linear_module(
+        self,
+        recipe1,
+        recipe2,
+        batch_size,
+        hidden_size,
+        out_size,
+        dtype,
+        use_bias=True,
+    ):
         if IS_HIP_EXTENSION:
             import importlib
             ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
             os.environ["NVTE_INT8_SIM_FP8"] = "1"
             importlib.reload(te.pytorch.fp8)
+        if not fp8_blockwise_scaling_supported():
+            pytest.skip("CUDA version does not support blockwise FP8.")
         fp8_zero_tolerance_tensor_dumps_recipe2 = None
         # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
         # if we cannot get all four tensors, then still set the tensor dump to None
```
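The new `test_int8_*` variants flip the environment toggle at call time, which only works because the flag is read at import time; hence the `importlib.reload` dance. A minimal sketch of the pattern under that assumption (the tests capture the old value in `ori_int8_sim_fp8`; restoring it afterwards, as below, is advisable):

```python
import importlib
import os

import transformer_engine.pytorch.fp8 as te_fp8

prev = os.environ.get("NVTE_INT8_SIM_FP8", None)
os.environ["NVTE_INT8_SIM_FP8"] = "1"
importlib.reload(te_fp8)  # re-executes the module so int8_simulation_fp8 is re-read
try:
    ...  # run the INT8-simulated test body
finally:
    if prev is None:
        os.environ.pop("NVTE_INT8_SIM_FP8", None)
    else:
        os.environ["NVTE_INT8_SIM_FP8"] = prev
    importlib.reload(te_fp8)
```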
tests/pytorch/test_int8_channelwise_gemm_exact.py (deleted, 100644 → 0)

```python
# NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 test_int8_channelwise_gemm_exact.py
from collections.abc import Iterable
import io
from typing import Any, Dict, List, Tuple, Union, Optional

import pytest
import torch
import transformer_engine as te
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8Tensor,
    Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
from transformer_engine.pytorch.triton.per_token_group_quant import (
    per_token_quant_fp8_to_int8,
    per_token_quant_fp8_to_int8_opt,
    channelwise_dequantize,
    channelwise_dequantize_transA,
    channelwise_dequantize_transB,
    tensorwise_dequantize,
)
import time
import os

int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
tensorwise_int8_check = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE_CHECK", "0")))


def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
    l1: List[torch.Tensor],
    l2: List[torch.Tensor],
    atol: float = None,
    rtol: float = None,
) -> bool:
    """Ensures two lists are equal."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dtype_tols(t2.dtype)
        if rtol is not None:
            tols["rtol"] = rtol
        if atol is not None:
            tols["atol"] = atol
        result = torch.allclose(t1, t2, **tols)
        if not result:
            diff = torch.abs(t1 - t2)
            tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
            exceed_mask = diff > tol
            if exceed_mask.any():
                indices = torch.nonzero(exceed_mask, as_tuple=True)
                max_diff = diff[exceed_mask].max()
                max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
                max_location = [idx[max_idx].item() for idx in indices]
                msg = (
                    f"Outputs not close enough in tensor at idx={i}. "
                    f"Maximum difference at location {max_location} "
                    f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                    f"(diff {max_diff.item()})."
                )
                raise AssertionError(msg)


# TN
m = 4096
k = 4096
n = 6144
seed = 4096
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = "cuda"
out_dtype = torch.int32
# Allocate cuBLAS workspace
workspace_size = 128
workspace = torch.empty(128, dtype=torch.uint8, device=device)
out_quantizer = None
accumulate = False
use_gelu = False
use_bias = False
bias = None
use_grad = False
assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM"
aux_tensor = torch.empty((m, n), dtype=out_dtype, device=device) if use_gelu else None
out = torch.empty((m, n), dtype=out_dtype, device=device) if accumulate else None
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
use_split_accumulator = False
# bf16 to int8
# transa = True
# transb = False
# x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
# w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
# bf16_out = torch.matmul(x_bf16, w_bf16.t())
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# # print("x_int8: ", x_int8)
# # print("w_int8: ", w_int8)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# y_int32 = tex.generic_gemm(
# w_int8,
# transa,
# x_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# # y_int32 = torch._int_mm(x_int8, w_int8.t())
# # print("y_int32: ", y_int32)
# output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# # print("out_scales.shape: ", out_scales.shape)
# # print("out_scales: ", out_scales)
# # print("bf16_out: ", bf16_out)
# # print("output: ", output)
# # NN
# transa = False
# transb = False
# dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
# w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
# bf16_dx = torch.matmul(dy_bf16, w_bf16)
# dy_int8, dy_scales = per_token_quant_int8(dy_bf16)
# w_int8, w_scales = per_token_quant_int8_v2(w_bf16)
# # print("dy_scales.shape: ", dy_scales.shape)
# # print("w_scales.shape: ", w_scales.shape)
# # print("dy_int8: ", dy_int8)
# # print("w_int8: ", w_int8)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# dx_int32 = tex.generic_gemm(
# w_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# # dx_int32 = torch._int_mm(dy_int8, w_int8)
# # print("dx_int32: ", dx_int32)
# dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# # print("dx_scales.shape: ", dx_scales.shape)
# # print("dx_scales: ", dx_scales)
# # print("bf16_dx: ", bf16_dx)
# # print("dx: ", dx)
# # NT
# transa = False
# transb = True
# dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
# x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
# bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
# dy_int8, dy_scales = per_token_quant_int8_v2(dy_bf16)
# x_int8, x_scales = per_token_quant_int8_v2(x_bf16)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# dw_int32 = tex.generic_gemm(
# x_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# # print("bf16_dw: ", bf16_dw)
# # print("dw: ", dw)
# fp8 to int8
quantizer_e5m2 = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E5M2,
    device="cuda",
    force_pow_2_scales=False,
    amax_epsilon=0.0,
)
quantizer_e4m3 = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    device="cuda",
    force_pow_2_scales=False,
    amax_epsilon=0.0,
)


# current scaling
def to_float8_CS(
    tensor: torch.Tensor,
    fp8_dtype: tex.DType = tex.DType.kFloat8E5M2,
    return_transpose: bool = False,
    force_pow_2_scales: bool = False,
    amax_epsilon: float = 0.0,
) -> Float8Tensor:
    """Cast tensor to FP8"""
    quantizer = quantizer_e5m2 if fp8_dtype == tex.DType.kFloat8E5M2 else quantizer_e4m3
    if return_transpose:
        quantizer.set_usage(rowwise=True, columnwise=True)
    else:
        quantizer.set_usage(rowwise=True, columnwise=False)
    return quantizer(tensor)


# TN
transa = True
transb = False
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
output = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
for i in range(20):
    bf16_out = torch.matmul(x_bf16, w_bf16.t())
print("bf16_out: ", bf16_out)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
end = time.time()
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# Cast to FP8 and back
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
if int8_simulation_fp8_tensorwise:
    x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
    w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
    x_int8, x_scales = per_token_quant_fp8_to_int8(
        x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False
    )
    w_int8, w_scales = per_token_quant_fp8_to_int8(
        w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False
    )
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y_int32 = tex.generic_gemm(
    w_int8, transa, x_int8, transb,
    out, out_quantizer, TE_DType[out_dtype],
    bias, bias_dtype, use_gelu, aux_tensor, use_grad,
    workspace, workspace.shape[0], accumulate, use_split_accumulator,
)[0]
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)
if int8_simulation_fp8_tensorwise:
    tensorwise_dequantize(x_scales, w_scales, y_int32, output)
else:
    output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
print("output: ", output)
if tensorwise_int8_check:
    lt_output = tex.generic_gemm(
        w_fp8, transa, x_fp8, transb,
        out, out_quantizer, TE_DType[torch.bfloat16],
        bias, bias_dtype, use_gelu, aux_tensor, use_grad,
        workspace, workspace.shape[0], accumulate, True,
    )[0]
    print("lt_output: ", lt_output)
    assert_allclose([output], [lt_output])
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# y_int32 = tex.generic_gemm(
# w_int8,
# transa,
# x_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# torch.cuda.synchronize()
# end = time.time()
# NN
# transa = True
transa = False
transb = False
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
dx = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
bf16_dx = torch.matmul(dy_bf16, w_bf16)
print("bf16_dx: ", bf16_dx)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    bf16_dx = torch.matmul(dy_bf16, w_bf16)
torch.cuda.synchronize()
end = time.time()
# Cast to FP8 and back
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
if int8_simulation_fp8_tensorwise:
    dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
    w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
    dy_int8, dy_scales = per_token_quant_fp8_to_int8(
        dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False
    )
    w_int8, w_scales = per_token_quant_fp8_to_int8_opt(
        w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False
    )
# w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# print("dy_scales.shape: ", dy_scales.shape)
# print("w_scales.shape: ", w_scales.shape)
# print("dy_int8: ", dy_int8)
# print("w_int8: ", w_int8)
# print("w_scales: ", w_scales)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dx_int32 = tex.generic_gemm(
    w_int8, transa, dy_int8, transb,
    out, out_quantizer, TE_DType[out_dtype],
    bias, bias_dtype, use_gelu, aux_tensor, use_grad,
    workspace, workspace.shape[0], accumulate, use_split_accumulator,
)[0]
# dx_int32 = torch._int_mm(dy_int8, w_int8)
# print("dx_int32: ", dx_int32)
if int8_simulation_fp8_tensorwise:
    tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
else:
    dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
print("dx: ", dx)
if tensorwise_int8_check:
    lt_dx = tex.generic_gemm(
        w_fp8, transa, dy_fp8, transb,
        out, out_quantizer, TE_DType[torch.bfloat16],
        bias, bias_dtype, use_gelu, aux_tensor, use_grad,
        workspace, workspace.shape[0], accumulate, True,
    )[0]
    print("lt_dx: ", lt_dx)
    assert_allclose([dx], [lt_dx])
# print("dx_scales.shape: ", dx_scales.shape)
# print("dx_scales: ", dx_scales)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# # w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# # w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# dx_int32 = tex.generic_gemm(
# w_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# # dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
# torch.cuda.synchronize()
# end = time.time()
# NT
# transa = True
# transb = False
transa = False
transb = True
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
dw = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
print("bf16_dw: ", bf16_dw)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
torch.cuda.synchronize()
end = time.time()
# Cast to FP8 and back
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
if int8_simulation_fp8_tensorwise:
    dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
    x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
else:
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
    dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(
        dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False
    )
    x_int8, x_scales = per_token_quant_fp8_to_int8_opt(
        x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False
    )
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dw_int32 = tex.generic_gemm(
    x_int8, transa, dy_int8, transb,
    out, out_quantizer, TE_DType[out_dtype],
    bias, bias_dtype, use_gelu, aux_tensor, use_grad,
    workspace, workspace.shape[0], accumulate, use_split_accumulator,
)[0]
if int8_simulation_fp8_tensorwise:
    tensorwise_dequantize(dy_scales, x_scales, dw_int32, dw)
else:
    dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
print("dw: ", dw)
if tensorwise_int8_check:
    lt_dw = tex.generic_gemm(
        x_fp8, transa, dy_fp8, transb,
        out, out_quantizer, TE_DType[torch.bfloat16],
        bias, bias_dtype, use_gelu, aux_tensor, use_grad,
        workspace, workspace.shape[0], accumulate, True,
    )[0]
    print("lt_dw: ", lt_dw)
    assert_allclose([dw], [lt_dw])
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# # dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# # x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dw_int32 = tex.generic_gemm(
# x_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
# torch.cuda.synchronize()
# end = time.time()
# batch gemm wgrad
m = 1024
k = 1024
n = 1024
b = 4
transa = False
transb = True
dy_int8 = (torch.randn((b, m, n), device=device)).to(dtype=torch.int8)
x_int8 = (torch.randn((b, m, k), device=device)).to(dtype=torch.int8)
int32_dw_list = []
for i in range(b):
    int32_dw = torch._int_mm(dy_int8[i].t(), x_int8[i])
    # bf16_dw = torch.matmul(dy_int8[i].t(), x_int8[i])
    int32_dw_list.append(int32_dw)
batched_int32_dw = torch.stack(int32_dw_list)
# print("batched_int32_dw.shape: ", batched_int32_dw.shape)
# print("batched_int32_dw: ", batched_int32_dw)
out_dtype = torch.int32
out = torch.empty((b, n, k), dtype=out_dtype, device=device)
te_dw = tex.generic_batchgemm(
    x_int8.view(-1, x_int8.size(-1)), transa,
    dy_int8.view(-1, dy_int8.size(-1)), transb,
    out.view(-1, out.size(-1)), b,
    out_quantizer, TE_DType[out_dtype],
    bias, bias_dtype, use_gelu, aux_tensor, use_grad,
    workspace, workspace.shape[0], accumulate, use_split_accumulator,
)[0]
# print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
# print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
torch.testing.assert_close(te_dw.view(b, -1, te_dw.size(-1)), batched_int32_dw, atol=0, rtol=0)
# NT
b = 4
transa = False
transb = True
dy_bf16 = [(torch.randn((m, n), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
x_bf16 = [(torch.randn((m, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
dw_ref = [(torch.randn((n, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
dw = [(torch.randn((n, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
# Cast to FP8 and back
dy_fp8 = [to_float8_CS(dy_bf16[i], fp8_dtype=tex.DType.kFloat8E5M2) for i in range(b)]
x_fp8 = [to_float8_CS(x_bf16[i], fp8_dtype=tex.DType.kFloat8E5M2) for i in range(b)]
if int8_simulation_fp8_tensorwise:
    dy_int8, dy_scales = [dy_fp8[i]._data.view(dtype=torch.int8) for i in range(b)], [
        dy_fp8[i]._scale_inv for i in range(b)
    ]
    x_int8, x_scales = [x_fp8[i]._data.view(dtype=torch.int8) for i in range(b)], [
        x_fp8[i]._scale_inv for i in range(b)
    ]
else:
    dy_int8, dy_scales = [], []
    x_int8, x_scales = [], []
    assert False
for i in range(b):
    # cuBLAS GEMM
    # return type is out, bias_grad, gelu_input, extra_output
    # We are just capturing out.
    dw_int32 = tex.generic_gemm(
        x_int8[i], transa, dy_int8[i], transb,
        None, None, TE_DType[out_dtype],
        bias, bias_dtype, use_gelu, aux_tensor, use_grad,
        workspace, workspace.shape[0], accumulate, use_split_accumulator,
    )[0]
    if int8_simulation_fp8_tensorwise:
        tensorwise_dequantize(dy_scales[i], x_scales[i], dw_int32, dw_ref[i])
    else:
        assert False
dw_ref_tensor = torch.stack(dw_ref).contiguous().view(-1, dw_ref[0].size(-1))
# print("dw_ref_tensor: ", dw_ref_tensor)
torch.cuda.synchronize()
dy_int8_tensor = torch.stack(dy_int8).contiguous()
dy_scales_tensor = torch.stack(dy_scales).contiguous()
x_int8_tensor = torch.stack(x_int8).contiguous()
x_scales_tensor = torch.stack(x_scales).contiguous()
dw_tensor = torch.stack(dw).contiguous()
out_dtype = torch.bfloat16
dw_tensor = tex.tensorwise_int8_batchgemm(
    x_int8_tensor.view(-1, x_int8_tensor.size(-1)), transa,
    dy_int8_tensor.view(-1, dy_int8_tensor.size(-1)), transb,
    x_scales_tensor, dy_scales_tensor,
    dw_tensor.view(-1, dw_tensor.size(-1)), b,
    out_quantizer, TE_DType[out_dtype],
    bias, bias_dtype, use_gelu, aux_tensor, use_grad,
    workspace, workspace.shape[0], accumulate, use_split_accumulator,
)[0]
# print("dw_tensor: ", dw_tensor)
torch.testing.assert_close(dw_ref_tensor, dw_tensor, atol=1e-5, rtol=1e-5)
```
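What the deleted script exercised, in one line of math: quantize A rowwise and B rowwise to INT8, run the INT8 GEMM into an INT32 accumulator, then rescale. A hypothetical pure-PyTorch reference for the channelwise dequantization step (the Triton helpers it imported are not shown in this commit, so the exact signatures are assumptions):

```python
import torch

def channelwise_dequantize_transB_ref(
    a_scales: torch.Tensor,  # per-row scales of A, shape (M,)
    b_scales: torch.Tensor,  # per-row scales of B (B enters transposed), shape (N,)
    y_int32: torch.Tensor,   # INT32 accumulator, shape (M, N)
) -> torch.Tensor:
    # out[i, j] = y[i, j] * a_scale[i] * b_scale[j]
    return y_int32.to(torch.float32) * a_scales.view(-1, 1) * b_scales.view(1, -1)
```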
transformer_engine/pytorch/quantization.py

```diff
@@ -15,6 +15,7 @@ from collections import deque
 from typing import Callable, List, Optional, Dict, Any, Tuple, Union
 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import (
     Recipe,
@@ -27,10 +28,8 @@ from transformer_engine.common.recipe import (
     CustomRecipe,
 )
 from .constants import dist_group_type
-from .utils import get_device_compute_capability
+from .utils import (get_device_compute_capability, is_gfx928, is_gfx936, is_gfx938)
 from .jit import jit_fuser
-from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
 int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
 blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
@@ -45,18 +44,14 @@ __all__ = [
     "get_default_recipe",
 ]
-if IS_HIP_EXTENSION:
-    from transformer_engine.pytorch.utils import is_K100_AI, is_BW
 
 @functools.lru_cache(maxsize=None)
 def check_fp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
     if IS_HIP_EXTENSION:
-        if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
-            return True, "DCU turn on fp8 simulation with int8"
-        else:
-            return False, "DCU not support fp8 for now"
+        if is_gfx938():
+            return True, ""
+        if (is_gfx928() or is_gfx936()) and int8_simulation_fp8 and int8_simulation_fp8_tensorwise:
+            return True, ""
     if get_device_compute_capability() >= (9, 0):  # hopper and above
         return True, ""
     if get_device_compute_capability() < (8, 9):  # pre-ada
@@ -71,6 +66,8 @@ def check_fp8_support() -> Tuple[bool, str]:
 @functools.lru_cache(maxsize=None)
 def check_mxfp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
+    if IS_HIP_EXTENSION:
+        return False, "DCU not support mxfp8 for now"
     if get_device_compute_capability() >= (12, 0):
         return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
     if get_device_compute_capability() >= (10, 0):  # blackwell and above
@@ -83,7 +80,6 @@ def check_nvfp4_support() -> Tuple[bool, str]:
     """Return if nvfp4 support is available"""
     if IS_HIP_EXTENSION:
         return False, "NVFP4 is not supported on rocm platform."
-    else:
-        if get_device_compute_capability() >= (10, 0):  # blackwell and above
-            return True, ""
+    if get_device_compute_capability() >= (10, 0):  # blackwell and above
+        return True, ""
     return False, "Device compute capability 10.0 or higher required for NVFP4 execution."
@@ -93,9 +89,10 @@ def check_nvfp4_support() -> Tuple[bool, str]:
 def check_fp8_block_scaling_support() -> Tuple[bool, str]:
     """Return if fp8 block scaling support is available"""
     if IS_HIP_EXTENSION:
-        if is_K100_AI() or is_BW() and int8_simulation_fp8:
-            return True, ""
-        else:
-            return False, "DCU not support block_scaling fp8 for now"
+        if is_gfx938():
+            return True, ""
+        if (is_gfx928() or is_gfx936()) and int8_simulation_fp8:
+            return True, ""
+        return False, "DCU not support block_scaling fp8 for now"
     if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9:
         return True, ""
```
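One pre-existing wrinkle worth flagging in the last hunk's context line: `float(torch.version.cuda) >= 12.9` orders versions as floats, so a hypothetical "12.10" would parse as 12.1 and compare below 12.9. This commit does not touch that line; a tuple comparison would be the usual fix, sketched here as an illustration only:

```python
import torch

def cuda_version_at_least(major: int, minor: int) -> bool:
    # "12.9" -> (12, 9); tuple comparison orders (12, 10) > (12, 9) correctly.
    ver = tuple(int(part) for part in torch.version.cuda.split("."))
    return ver >= (major, minor)
```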
transformer_engine/pytorch/utils.py

```diff
@@ -11,11 +11,10 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 from contextlib import nullcontext
 import numpy as np
 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+from .quantized_tensor import Quantizer
 from .torch_version import torch_version
-from .quantized_tensor import Quantizer
 from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
-from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 __all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"]
@@ -447,20 +446,64 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
     )
 
 if IS_HIP_EXTENSION:
-    def is_mi200():
-        """check whether this machine is mi200/210/250"""
-        import re
-        return (re.search('AMD Instinct MI2.0', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
-    def is_K100_AI():
-        """check whether this machine is K100_AI"""
-        import re
-        return (re.search('K100_AI', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
-    def is_BW():
-        """check whether this machine is BW"""
-        import re
-        return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
+    @functools.lru_cache(maxsize=None)
+    def _get_gcn_arch_impl(device: torch.device) -> int:
+        props = torch.cuda.get_device_properties(device)
+        import re
+        if re.search('gfx906', props.gcnArchName) is not None:
+            return 906
+        if re.search('gfx926', props.gcnArchName) is not None:
+            return 926
+        if re.search('gfx928', props.gcnArchName) is not None:
+            return 928
+        if re.search('gfx936', props.gcnArchName) is not None:
+            return 936
+        if re.search('gfx938', props.gcnArchName) is not None:
+            return 938
+        raise RuntimeError(f"Unsupported GCN Arch {props.gcnArchName}")
+    def _get_gcn_arch() -> int:
+        return _get_gcn_arch_impl(torch.cuda.current_device())
+    def is_gfx906() -> bool:
+        """check whether this machine is gfx906"""
+        return _get_gcn_arch() == 906
+    def is_gfx926() -> bool:
+        """check whether this machine is gfx926"""
+        return _get_gcn_arch() == 926
+    def is_gfx928() -> bool:
+        """check whether this machine is gfx928"""
+        return _get_gcn_arch() == 928
+    def is_gfx936() -> bool:
+        """check whether this machine is gfx936"""
+        return _get_gcn_arch() == 936
+    def is_gfx938() -> bool:
+        """check whether this machine is gfx938"""
+        return _get_gcn_arch() == 938
+else:
+    def is_gfx906() -> bool:
+        """gfx906 is only available on ROCm"""
+        return False
+    def is_gfx926() -> bool:
+        """gfx926 is only available on ROCm"""
+        return False
+    def is_gfx928() -> bool:
+        """gfx928 is only available on ROCm"""
+        return False
+    def is_gfx936() -> bool:
+        """gfx936 is only available on ROCm"""
+        return False
+    def is_gfx938() -> bool:
+        """gfx938 is only available on ROCm"""
+        return False
 
 def assert_dim_for_all_gather(
     tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
@@ -477,12 +520,8 @@ def is_bf16_compatible() -> bool:
     check on device compute capability to enforce sm_80 or higher.
     """
     if IS_HIP_EXTENSION:
-        # only MI200 and MI300 machines support bf16
-        if get_device_compute_capability() >= (9, 4) or is_mi200() or is_K100_AI() or is_BW():
-            return True
-        else:
-            return False
-    else:
-        return torch.cuda.get_device_capability()[0] >= 8
+        # only these arch support bf16
+        return is_gfx928() or is_gfx936() or is_gfx938()
+    return torch.cuda.get_device_capability()[0] >= 8
@@ -517,7 +556,6 @@ def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
     if IS_HIP_EXTENSION:
         if is_blockwise:
             return False
-        else:
-            return True
+        return True
     device_capability = torch.cuda.get_device_capability()
     return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
```
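Usage-wise, the new predicates compose the same way the old device-name checks did, but keyed on `gcnArchName` rather than marketing names. A minimal sketch of a caller, mirroring the dispatch that check_fp8_support() in quantization.py now performs:

```python
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import is_gfx928, is_gfx936, is_gfx938

def rocm_fp8_mode() -> str:
    # gfx938 runs FP8 natively; gfx928/gfx936 rely on the INT8 simulation
    # toggles; anything else has no FP8 path on ROCm.
    if not IS_HIP_EXTENSION:
        return "cuda"
    if is_gfx938():
        return "native-fp8"
    if is_gfx928() or is_gfx936():
        return "int8-simulated-fp8"
    return "unsupported"
```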