OpenDAS / TransformerEngine · Commits

Commit a68e5f87, authored Feb 24, 2026 by wenjh

    Enable fp8 on nmz

    Signed-off-by: wenjh <wenjh@sugon.com>

Parent: 99a1c744
Pipeline #3434 failed with stages in 0 seconds
Changes: 7 · Pipelines: 1
Showing 7 changed files with 223 additions and 851 deletions (+223 −851):

    qa/L0_pytorch_unittest/test.sh                          +5    −3
    tests/pytorch/distributed/test_numerics.py              +16   −1
    tests/pytorch/test_float8_blockwise_gemm_exact.py       +3    −3
    tests/pytorch/test_float8_blockwise_scaling_exact.py    +116  −1
    tests/pytorch/test_int8_channelwise_gemm_exact.py       +0    −796
    transformer_engine/pytorch/quantization.py              +20   −22
    transformer_engine/pytorch/utils.py                     +63   −25
qa/L0_pytorch_unittest/test.sh  (view file @ a68e5f87)

@@ -36,10 +36,12 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PA
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py_int8"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
# channelwise int8 test
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py
NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py
python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py_int8"
NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8_tensorwise.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py_int8_tensorwise"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
...
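The new test invocations above are driven by two environment switches, NVTE_INT8_SIM_FP8 and NVTE_INT8_SIM_FP8_TENSORWISE. A minimal sketch (not part of the commit) of how such switches are consumed on the Python side, mirroring the bool(int(os.getenv(...))) pattern this commit uses in transformer_engine/pytorch/quantization.py:

```python
# Sketch only: read the int8-simulation switches the test script sets per invocation.
import os

int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))

if int8_simulation_fp8:
    print("FP8 is simulated with int8 kernels")
if int8_simulation_fp8_tensorwise:
    print("Per-tensor (tensorwise) scaling is used for the int8 simulation")
```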
tests/pytorch/distributed/test_numerics.py  (view file @ a68e5f87)

@@ -51,11 +51,26 @@ def _run_test(quantization):
all_boolean = [True, False]


@pytest.mark.parametrize(
    "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
)
def test_distributed(quantization):
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "fp8_cs" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if quantization == "nvfp4" and not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)
    _run_test(quantization)


@pytest.mark.parametrize(
    "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
)
def test_int8_distributed(quantization):
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "fp8_cs" and not fp8_available:
...
tests/pytorch/test_float8_blockwise_gemm_exact.py  (view file @ a68e5f87)

@@ -47,7 +47,7 @@ def cublas_gemm_fp8_blockwise_case(
    atol: float = 0.0,
    rtol: float = 0.0,
):
    if IS_HIP_EXTENSION and int8_simulation_fp8:
    if IS_HIP_EXTENSION:
        if use_bias or use_gelu:
            pytest.skip("Bias and GELU not supported in int8 simulation mode on ROCm.")
        if not ((not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled)
                or (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled)
                or (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled)):
...
@@ -249,7 +249,7 @@ def cublas_gemm_test_constraint_enforced(
    expected_err_cls=RuntimeError,
):
    if IS_HIP_EXTENSION:
        pytest.skip("ROCm does not support cuBLAS GEMM. No need to test constraint enforcement.")
        pytest.skip("ROCm does not support cuBLAS blockwise FP8 gemm. No need to test constraint enforcement.")
    if not fp8_blockwise_gemm_supported():
        pytest.skip("CUDA version does not support blockwise FP8 gemm.")
    # Setup device and random seed
...
tests/pytorch/test_float8_blockwise_scaling_exact.py  (view file @ a68e5f87)

@@ -9,7 +9,7 @@ import pathlib
import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch import (
...
@@ -507,6 +507,9 @@ def test_quantization_block_tiling_extrema_versus_reference(
        rtol=0.0,
    )


def fp8_blockwise_scaling_supported() -> bool:
    supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
    return supported


# FP8 per tesnor current scaling
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
...
@@ -541,12 +544,65 @@ class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
        out_size,
        dtype,
        use_bias=True,
    ):
        if not fp8_blockwise_scaling_supported():
            pytest.skip("CUDA version does not support blockwise FP8.")
        fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
        # if we cannot get all four tensors, then still set the tensor dump to None
        tensor_map = self._check_golden_tensor_dumps(
            TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
        )
        if tensor_map is not None:
            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map

        self.compare_recipe(
            recipe1,
            recipe2,
            batch_size,
            hidden_size,
            out_size,
            use_bias,
            seed=torch.initial_seed(),
            dtype=dtype,
            y_error=0.5,
            dgrad_error=1,
            wgrad_error=1,
            bgrad_error=0.5,
            recipe1_golden_tensors=None,
            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
        )

    @pytest.mark.parametrize(
        "batch_size, hidden_size, out_size",
        [
            (16, 256, 128),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
    @pytest.mark.parametrize(
        "recipe1, recipe2",
        [
            (GetRecipes.none, GetRecipes.fp8_blockwise),
        ],
    )
    def test_int8_current_scaling_with_linear_module(
        self,
        recipe1,
        recipe2,
        batch_size,
        hidden_size,
        out_size,
        dtype,
        use_bias=True,
    ):
        if IS_HIP_EXTENSION:
            import importlib
            ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
            os.environ["NVTE_INT8_SIM_FP8"] = "1"
            importlib.reload(te.pytorch.fp8)
        if not fp8_blockwise_scaling_supported():
            pytest.skip("CUDA version does not support blockwise FP8.")
        fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
        # if we cannot get all four tensors, then still set the tensor dump to None
...
@@ -612,12 +668,71 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
        out_size,
        dtype,
        use_bias=True,
    ):
        if not fp8_blockwise_scaling_supported():
            pytest.skip("CUDA version does not support blockwise FP8.")
        fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
        # if we cannot get all four tensors, then still set the tensor dump to None
        tensor_map = self._check_golden_tensor_dumps(
            TENSOR_DUMP_DIR,
            recipe2,
            (batch_size, hidden_size, out_size),
            dtype,
            use_bias,
            "LayerNorm",
        )
        if tensor_map is not None:
            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map

        self.compare_recipe(
            recipe1,
            recipe2,
            batch_size,
            hidden_size,
            out_size,
            use_bias,
            seed=torch.initial_seed(),
            dtype=dtype,
            y_error=0.5,
            ln_out_error=0.5,
            dgrad_error=1.6,
            wgrad_error=1,
            bgrad_error=0.5,
            recipe1_golden_tensors=None,
            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
        )

    @pytest.mark.parametrize(
        "batch_size, hidden_size, out_size",
        [
            (16, 256, 128),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
    @pytest.mark.parametrize(
        "recipe1, recipe2",
        [
            (GetRecipes.none, GetRecipes.fp8_blockwise),
        ],
    )
    def test_int8_current_scaling_with_layernorm_linear_module(
        self,
        recipe1,
        recipe2,
        batch_size,
        hidden_size,
        out_size,
        dtype,
        use_bias=True,
    ):
        if IS_HIP_EXTENSION:
            import importlib
            ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
            os.environ["NVTE_INT8_SIM_FP8"] = "1"
            importlib.reload(te.pytorch.fp8)
        if not fp8_blockwise_scaling_supported():
            pytest.skip("CUDA version does not support blockwise FP8.")
        fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
        # if we cannot get all four tensors, then still set the tensor dump to None
...
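The two new test_int8_current_scaling_* tests above force the int8-simulation code path by setting NVTE_INT8_SIM_FP8 and reloading the fp8 module so its module-level flags are re-evaluated. A sketch of that save/set/reload pattern, with a restore step that the visible hunks truncate (the alias te_fp8 and the try/finally cleanup are assumptions, not shown in the diff):

```python
# Sketch only: toggle the int8-simulation flag for one test and restore it afterwards.
import importlib
import os

import transformer_engine.pytorch.fp8 as te_fp8

ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
os.environ["NVTE_INT8_SIM_FP8"] = "1"
importlib.reload(te_fp8)  # re-evaluates the module-level os.getenv() flags
try:
    pass  # run the FP8-block-scaling comparison here
finally:
    if ori_int8_sim_fp8 is None:
        os.environ.pop("NVTE_INT8_SIM_FP8", None)
    else:
        os.environ["NVTE_INT8_SIM_FP8"] = ori_int8_sim_fp8
    importlib.reload(te_fp8)
```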
tests/pytorch/test_int8_channelwise_gemm_exact.py  deleted 100644 → 0  (view file @ 99a1c744)
# NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 test_int8_channelwise_gemm_exact.py
from collections.abc import Iterable
import io
from typing import Any, Dict, List, Tuple, Union, Optional

import pytest
import torch
import transformer_engine as te
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8Tensor,
    Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
from transformer_engine.pytorch.triton.per_token_group_quant import (
    per_token_quant_fp8_to_int8,
    per_token_quant_fp8_to_int8_opt,
    channelwise_dequantize,
    channelwise_dequantize_transA,
    channelwise_dequantize_transB,
    tensorwise_dequantize,
)
import time
import os

int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
tensorwise_int8_check = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE_CHECK", "0")))


def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsuppored dtype ({dtype})")
def assert_allclose(
    l1: List[torch.Tensor],
    l2: List[torch.Tensor],
    atol: float = None,
    rtol: float = None,
) -> bool:
    """Ensures two lists are equal."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dtype_tols(t2.dtype)
        if rtol is not None:
            tols["rtol"] = rtol
        if atol is not None:
            tols["atol"] = atol
        result = torch.allclose(t1, t2, **tols)
        if not result:
            diff = torch.abs(t1 - t2)
            tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
            exceed_mask = diff > tol
            if exceed_mask.any():
                indices = torch.nonzero(exceed_mask, as_tuple=True)
                max_diff = diff[exceed_mask].max()
                max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
                max_location = [idx[max_idx].item() for idx in indices]
                msg = (
                    f"Outputs not close enough in tensor at idx={i}. "
                    f"Maximum difference at location {max_location} "
                    f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                    f"(diff {max_diff.item()})."
                )
                raise AssertionError(msg)
# TN
m = 4096
k = 4096
n = 6144
seed = 4096
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = "cuda"
out_dtype = torch.int32
# Allocate cuBLAS workspace
workspace_size = 128
workspace = torch.empty(128, dtype=torch.uint8, device=device)
out_quantizer = None
accumulate = False
use_gelu = False
use_bias = False
bias = None
use_grad = False
assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM"
aux_tensor = torch.empty((m, n), dtype=out_dtype, device=device) if use_gelu else None
out = torch.empty((m, n), dtype=out_dtype, device=device) if accumulate else None
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
use_split_accumulator = False

# bf16 to int8
# transa = True
# transb = False
# x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
# w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
# bf16_out = torch.matmul(x_bf16, w_bf16.t())
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# # print("x_int8: ", x_int8)
# # print("w_int8: ", w_int8)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# y_int32 = tex.generic_gemm(
# w_int8,
# transa,
# x_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# # y_int32 = torch._int_mm(x_int8, w_int8.t())
# # print("y_int32: ", y_int32)
# output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# # print("out_scales.shape: ", out_scales.shape)
# # print("out_scales: ", out_scales)
# # print("bf16_out: ", bf16_out)
# # print("output: ", output)
# # NN
# transa = False
# transb = False
# dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
# w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
# bf16_dx = torch.matmul(dy_bf16, w_bf16)
# dy_int8, dy_scales = per_token_quant_int8(dy_bf16)
# w_int8, w_scales = per_token_quant_int8_v2(w_bf16)
# # print("dy_scales.shape: ", dy_scales.shape)
# # print("w_scales.shape: ", w_scales.shape)
# # print("dy_int8: ", dy_int8)
# # print("w_int8: ", w_int8)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# dx_int32 = tex.generic_gemm(
# w_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# # dx_int32 = torch._int_mm(dy_int8, w_int8)
# # print("dx_int32: ", dx_int32)
# dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# # print("dx_scales.shape: ", dx_scales.shape)
# # print("dx_scales: ", dx_scales)
# # print("bf16_dx: ", bf16_dx)
# # print("dx: ", dx)
# # NT
# transa = False
# transb = True
# dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
# x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
# bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
# dy_int8, dy_scales = per_token_quant_int8_v2(dy_bf16)
# x_int8, x_scales = per_token_quant_int8_v2(x_bf16)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# dw_int32 = tex.generic_gemm(
# x_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# # print("bf16_dw: ", bf16_dw)
# # print("dw: ", dw)
# fp8 to int8
quantizer_e5m2 = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E5M2,
    device="cuda",
    force_pow_2_scales=False,
    amax_epsilon=0.0,
)
quantizer_e4m3 = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    device="cuda",
    force_pow_2_scales=False,
    amax_epsilon=0.0,
)


# current scaling
def to_float8_CS(
    tensor: torch.Tensor,
    fp8_dtype: tex.DType = tex.DType.kFloat8E5M2,
    return_transpose: bool = False,
    force_pow_2_scales: bool = False,
    amax_epsilon: float = 0.0,
) -> Float8Tensor:
    """Cast tensor to FP8"""
    quantizer = quantizer_e5m2 if fp8_dtype == tex.DType.kFloat8E5M2 else quantizer_e4m3
    if return_transpose:
        quantizer.set_usage(rowwise=True, columnwise=True)
    else:
        quantizer.set_usage(rowwise=True, columnwise=False)
    return quantizer(tensor)
# TN
transa = True
transb = False
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
output = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
for i in range(20):
    bf16_out = torch.matmul(x_bf16, w_bf16.t())
print("bf16_out: ", bf16_out)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
end = time.time()
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)

# Cast to FP8 and back
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
if int8_simulation_fp8_tensorwise:
    x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
    w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
    x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
    w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)

# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y_int32 = tex.generic_gemm(
    w_int8,
    transa,
    x_int8,
    transb,
    out,
    out_quantizer,
    TE_DType[out_dtype],
    bias,
    bias_dtype,
    use_gelu,
    aux_tensor,
    use_grad,
    workspace,
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)
if int8_simulation_fp8_tensorwise:
    tensorwise_dequantize(x_scales, w_scales, y_int32, output)
else:
    output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
print("output: ", output)

if tensorwise_int8_check:
    lt_output = tex.generic_gemm(
        w_fp8,
        transa,
        x_fp8,
        transb,
        out,
        out_quantizer,
        TE_DType[torch.bfloat16],
        bias,
        bias_dtype,
        use_gelu,
        aux_tensor,
        use_grad,
        workspace,
        workspace.shape[0],
        accumulate,
        True,
    )[0]
    print("lt_output: ", lt_output)
    assert_allclose([output], [lt_output])
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# y_int32 = tex.generic_gemm(
# w_int8,
# transa,
# x_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# torch.cuda.synchronize()
# end = time.time()
# NN
# transa = True
transa = False
transb = False
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
dx = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
bf16_dx = torch.matmul(dy_bf16, w_bf16)
print("bf16_dx: ", bf16_dx)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    bf16_dx = torch.matmul(dy_bf16, w_bf16)
torch.cuda.synchronize()
end = time.time()

# Cast to FP8 and back
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
if int8_simulation_fp8_tensorwise:
    dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
    w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
    dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# print("dy_scales.shape: ", dy_scales.shape)
# print("w_scales.shape: ", w_scales.shape)
# print("dy_int8: ", dy_int8)
# print("w_int8: ", w_int8)
# print("w_scales: ", w_scales)

# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dx_int32 = tex.generic_gemm(
    w_int8,
    transa,
    dy_int8,
    transb,
    out,
    out_quantizer,
    TE_DType[out_dtype],
    bias,
    bias_dtype,
    use_gelu,
    aux_tensor,
    use_grad,
    workspace,
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
# dx_int32 = torch._int_mm(dy_int8, w_int8)
# print("dx_int32: ", dx_int32)
if int8_simulation_fp8_tensorwise:
    tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
else:
    dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
print("dx: ", dx)

if tensorwise_int8_check:
    lt_dx = tex.generic_gemm(
        w_fp8,
        transa,
        dy_fp8,
        transb,
        out,
        out_quantizer,
        TE_DType[torch.bfloat16],
        bias,
        bias_dtype,
        use_gelu,
        aux_tensor,
        use_grad,
        workspace,
        workspace.shape[0],
        accumulate,
        True,
    )[0]
    print("lt_dx: ", lt_dx)
    assert_allclose([dx], [lt_dx])
# print("dx_scales.shape: ", dx_scales.shape)
# print("dx_scales: ", dx_scales)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# # w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# # w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# dx_int32 = tex.generic_gemm(
# w_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# # dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
# torch.cuda.synchronize()
# end = time.time()
# NT
# transa = True
# transb = False
transa = False
transb = True
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
dw = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
print("bf16_dw: ", bf16_dw)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
torch.cuda.synchronize()
end = time.time()

# Cast to FP8 and back
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
if int8_simulation_fp8_tensorwise:
    dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
    x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
else:
    # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
    dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)

# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dw_int32 = tex.generic_gemm(
    x_int8,
    transa,
    dy_int8,
    transb,
    out,
    out_quantizer,
    TE_DType[out_dtype],
    bias,
    bias_dtype,
    use_gelu,
    aux_tensor,
    use_grad,
    workspace,
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
if int8_simulation_fp8_tensorwise:
    tensorwise_dequantize(dy_scales, x_scales, dw_int32, dw)
else:
    dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
print("dw: ", dw)

if tensorwise_int8_check:
    lt_dw = tex.generic_gemm(
        x_fp8,
        transa,
        dy_fp8,
        transb,
        out,
        out_quantizer,
        TE_DType[torch.bfloat16],
        bias,
        bias_dtype,
        use_gelu,
        aux_tensor,
        use_grad,
        workspace,
        workspace.shape[0],
        accumulate,
        True,
    )[0]
    print("lt_dw: ", lt_dw)
    assert_allclose([dw], [lt_dw])
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# # dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# # x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dw_int32 = tex.generic_gemm(
# x_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
# torch.cuda.synchronize()
# end = time.time()
# bacth gemm wgrad
m = 1024
k = 1024
n = 1024
b = 4
transa = False
transb = True
dy_int8 = (torch.randn((b, m, n), device=device)).to(dtype=torch.int8)
x_int8 = (torch.randn((b, m, k), device=device)).to(dtype=torch.int8)
int32_dw_list = []
for i in range(b):
    int32_dw = torch._int_mm(dy_int8[i].t(), x_int8[i])
    # bf16_dw = torch.matmul(dy_int8[i].t(), x_int8[i])
    int32_dw_list.append(int32_dw)
batched_int32_dw = torch.stack(int32_dw_list)
# print("batched_int32_dw.shape: ", batched_int32_dw.shape)
# print("batched_int32_dw: ", batched_int32_dw)
out_dtype = torch.int32
out = torch.empty((b, n, k), dtype=out_dtype, device=device)
te_dw = tex.generic_batchgemm(
    x_int8.view(-1, x_int8.size(-1)),
    transa,
    dy_int8.view(-1, dy_int8.size(-1)),
    transb,
    out.view(-1, out.size(-1)),
    b,
    out_quantizer,
    TE_DType[out_dtype],
    bias,
    bias_dtype,
    use_gelu,
    aux_tensor,
    use_grad,
    workspace,
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
# print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
# print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
torch.testing.assert_close(te_dw.view(b, -1, te_dw.size(-1)), batched_int32_dw, atol=0, rtol=0)
# NT
b = 4
transa = False
transb = True
dy_bf16 = [(torch.randn((m, n), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
x_bf16 = [(torch.randn((m, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
dw_ref = [(torch.randn((n, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
dw = [(torch.randn((n, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
# Cast to FP8 and back
dy_fp8 = [to_float8_CS(dy_bf16[i], fp8_dtype=tex.DType.kFloat8E5M2) for i in range(b)]
x_fp8 = [to_float8_CS(x_bf16[i], fp8_dtype=tex.DType.kFloat8E5M2) for i in range(b)]
if int8_simulation_fp8_tensorwise:
    dy_int8, dy_scales = [dy_fp8[i]._data.view(dtype=torch.int8) for i in range(b)], [dy_fp8[i]._scale_inv for i in range(b)]
    x_int8, x_scales = [x_fp8[i]._data.view(dtype=torch.int8) for i in range(b)], [x_fp8[i]._scale_inv for i in range(b)]
else:
    dy_int8, dy_scales = [], []
    x_int8, x_scales = [], []
    assert False
for i in range(b):
    # cuBLAS GEMM
    # return type is out, bias_grad, gelu_input, extra_output
    # We are just capturing out.
    dw_int32 = tex.generic_gemm(
        x_int8[i],
        transa,
        dy_int8[i],
        transb,
        None,
        None,
        TE_DType[out_dtype],
        bias,
        bias_dtype,
        use_gelu,
        aux_tensor,
        use_grad,
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
    )[0]
    if int8_simulation_fp8_tensorwise:
        tensorwise_dequantize(dy_scales[i], x_scales[i], dw_int32, dw_ref[i])
    else:
        assert False
dw_ref_tensor = torch.stack(dw_ref).contiguous().view(-1, dw_ref[0].size(-1))
# print("dw_ref_tensor: ", dw_ref_tensor)
torch.cuda.synchronize()
dy_int8_tensor = torch.stack(dy_int8).contiguous()
dy_scales_tensor = torch.stack(dy_scales).contiguous()
x_int8_tensor = torch.stack(x_int8).contiguous()
x_scales_tensor = torch.stack(x_scales).contiguous()
dw_tensor = torch.stack(dw).contiguous()
out_dtype = torch.bfloat16
dw_tensor = tex.tensorwise_int8_batchgemm(
    x_int8_tensor.view(-1, x_int8_tensor.size(-1)),
    transa,
    dy_int8_tensor.view(-1, dy_int8_tensor.size(-1)),
    transb,
    x_scales_tensor,
    dy_scales_tensor,
    dw_tensor.view(-1, dw_tensor.size(-1)),
    b,
    out_quantizer,
    TE_DType[out_dtype],
    bias,
    bias_dtype,
    use_gelu,
    aux_tensor,
    use_grad,
    workspace,
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
# print("dw_tensor: ", dw_tensor)
torch.testing.assert_close(dw_ref_tensor, dw_tensor, atol=1e-5, rtol=1e-5)
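For reference, the arithmetic this deleted script exercised — per-token (channelwise) int8 quantization, an integer GEMM, and channelwise dequantization — can be sketched in plain PyTorch. This is an illustration only; per_token_quant_int8_ref and channelwise_dequantize_ref below are hypothetical stand-ins for the Triton kernels the script imported:

```python
# Sketch only: per-row symmetric int8 quantization, integer GEMM, channelwise dequantization.
import torch

def per_token_quant_int8_ref(t: torch.Tensor):
    """Quantize each row of t to int8 with its own scale (row amax / 127)."""
    scales = t.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / 127.0
    q = torch.clamp(torch.round(t / scales), -127, 127).to(torch.int8)
    return q, scales  # scales has shape (rows, 1)

def channelwise_dequantize_ref(x_scales, w_scales, acc):
    """Undo the per-row scales of both operands: out[i, j] = x_s[i] * w_s[j] * acc[i, j]."""
    return acc.to(torch.float32) * x_scales * w_scales.transpose(0, 1)

x = torch.randn(64, 128)
w = torch.randn(32, 128)
x_q, x_s = per_token_quant_int8_ref(x)
w_q, w_s = per_token_quant_int8_ref(w)
# Stand-in for the int8 GEMM (exact in float32 since |values| <= 127 and K = 128).
acc = x_q.to(torch.float32) @ w_q.to(torch.float32).T
y = channelwise_dequantize_ref(x_s, w_s, acc)
print("max abs error vs bf16 reference:", (y - x @ w.T).abs().max().item())
```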
transformer_engine/pytorch/quantization.py  (view file @ a68e5f87)

@@ -28,7 +28,7 @@ from transformer_engine.common.recipe import (
)
from .constants import dist_group_type
from .utils import get_device_compute_capability
from .utils import (get_device_compute_capability, is_gfx928, is_gfx936, is_gfx938)
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION

int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
...
@@ -45,18 +45,14 @@ __all__ = [
    "get_default_recipe",
]

if IS_HIP_EXTENSION:
    from transformer_engine.pytorch.utils import is_K100_AI, is_BW


@functools.lru_cache(maxsize=None)
def check_fp8_support() -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    if IS_HIP_EXTENSION:
        if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
            return True, "DCU turn on fp8 simulation with int8"
        else:
            return False, "DCU not support fp8 for now"
    else:
        if is_gfx938():
            return True, ""
        if (is_gfx928() or is_gfx936()) and int8_simulation_fp8 and int8_simulation_fp8_tensorwise:
            return True, ""
    if get_device_compute_capability() >= (9, 0):
        # hopper and above
        return True, ""
    if get_device_compute_capability() < (8, 9):
        # pre-ada
...
@@ -71,6 +67,8 @@ def check_fp8_support() -> Tuple[bool, str]:
@functools.lru_cache(maxsize=None)
def check_mxfp8_support() -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    if IS_HIP_EXTENSION:
        return False, "DCU not support mxfp8 for now"
    if get_device_compute_capability() >= (12, 0):
        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
    if get_device_compute_capability() >= (10, 0):
        # blackwell and above
...
@@ -83,7 +81,6 @@ def check_nvfp4_support() -> Tuple[bool, str]:
    """Return if nvfp4 support is available"""
    if IS_HIP_EXTENSION:
        return False, "NVFP4 is not supported on rocm platform."
    else:
        if get_device_compute_capability() >= (10, 0):
            # blackwell and above
            return True, ""
    return False, "Device compute capability 10.0 or higher required for NVFP4 execution."
...
@@ -93,9 +90,10 @@ def check_nvfp4_support() -> Tuple[bool, str]:
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
    """Return if fp8 block scaling support is available"""
    if IS_HIP_EXTENSION:
        if is_K100_AI() or is_BW() and int8_simulation_fp8:
        if is_gfx938():
            return True, ""
        if (is_gfx928() or is_gfx936()) and int8_simulation_fp8:
            return True, ""
        else:
            return False, "DCU not support block_scaling fp8 for now"
    if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9:
        return True, ""
...
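A hedged usage sketch of the capability checks edited above. The import path follows the file modified here and is an assumption about the public API; because the NVTE_* flags are read at module import time (and the checks are lru_cache'd), they must be set before the first import:

```python
# Sketch only: exercise the capability checks with the int8-simulation flags enabled.
import os

os.environ.setdefault("NVTE_INT8_SIM_FP8", "1")
os.environ.setdefault("NVTE_INT8_SIM_FP8_TENSORWISE", "1")

from transformer_engine.pytorch.quantization import (  # assumed import path
    check_fp8_support,
    check_fp8_block_scaling_support,
)

ok, why = check_fp8_support()
print("fp8:", ok, why)
ok, why = check_fp8_block_scaling_support()
print("fp8 block scaling:", ok, why)
```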
transformer_engine/pytorch/utils.py  (view file @ a68e5f87)

@@ -10,10 +10,9 @@ import os
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

from . import torch_version
from .quantized_tensor import Quantizer
from torch.utils.cpp_extension import IS_HIP_EXTENSION

__all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"]
...
@@ -445,20 +444,64 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
        )


if IS_HIP_EXTENSION:
    def is_mi200():
        """check whether this machine is mi200/210/250"""
    @functools.lru_cache(maxsize=None)
    def _get_gcn_arch_impl(device: torch.device) -> int:
        props = torch.cuda.get_device_properties(device)
        import re
        return (re.search('AMD Instinct MI2.0', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
        if re.search('gfx906', props.gcnArchName) is not None:
            return 906
        if re.search('gfx926', props.gcnArchName) is not None:
            return 926
        if re.search('gfx928', props.gcnArchName) is not None:
            return 928
        if re.search('gfx936', props.gcnArchName) is not None:
            return 936
        if re.search('gfx938', props.gcnArchName) is not None:
            return 938
        raise RuntimeError(f"Unsupported GCN Arch {props.gcnArchName}")

    def _get_gcn_arch() -> int:
        return _get_gcn_arch_impl(torch.cuda.current_device())

    def is_gfx906() -> bool:
        """check whether this machine is gfx906"""
        return _get_gcn_arch() == 906

    def is_gfx926() -> bool:
        """check whether this machine is gfx926"""
        return _get_gcn_arch() == 926

    def is_gfx928() -> bool:
        """check whether this machine is gfx928"""
        return _get_gcn_arch() == 928

    def is_gfx936() -> bool:
        """check whether this machine is gfx928"""
        return _get_gcn_arch() == 936

    def is_gfx938() -> bool:
        """check whether this machine is gfx928"""
        return _get_gcn_arch() == 938
else:
    def is_gfx906() -> bool:
        """gfx906 is only available on ROCm"""
        return False

    def is_K100_AI():
        """check whether this machine is K100_AI"""
        import re
        return (re.search('K100_AI', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)

    def is_gfx926() -> bool:
        """gfx926 is only available on ROCm"""
        return False

    def is_BW():
        """check whether this machine is BW"""
        import re
        return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)

    def is_gfx928() -> bool:
        """gfx928 is only available on ROCm"""
        return False

    def is_gfx936() -> bool:
        """gfx936 is only available on ROCm"""
        return False

    def is_gfx938() -> bool:
        """gfx938 is only available on ROCm"""
        return False


def assert_dim_for_all_gather(
    tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
...
@@ -475,12 +518,8 @@ def is_bf16_compatible() -> bool:
    check on device compute capability to enforce sm_80 or higher.
    """
    if IS_HIP_EXTENSION:
        # only MI200 and MI300 machines support bf16
        if get_device_compute_capability() >= (9, 4) or is_mi200() or is_K100_AI() or is_BW():
            return True
        else:
            return False
    else:
        # only these arch support bf16
        return is_gfx928() or is_gfx936() or is_gfx938()
    return torch.cuda.get_device_capability()[0] >= 8
...
@@ -515,7 +554,6 @@ def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
    if IS_HIP_EXTENSION:
        if is_blockwise:
            return False
        else:
            return True
    device_capability = torch.cuda.get_device_capability()
    return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
...
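A short usage sketch (not part of the commit) of the new gfx-arch helpers added above; on CUDA builds they simply return False, while on ROCm they dispatch on torch.cuda.get_device_properties(...).gcnArchName:

```python
# Sketch only: branch on the DCU architecture without touching IS_HIP_EXTENSION directly.
from transformer_engine.pytorch.utils import is_gfx928, is_gfx936, is_gfx938

if is_gfx938():
    print("native FP8 path")
elif is_gfx928() or is_gfx936():
    print("FP8 only via the NVTE_INT8_SIM_FP8 simulation flags")
else:
    print("not a DCU gfx92x/93x device")
```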