OpenDAS / TransformerEngine / Commits

Commit 063ef88d, authored Dec 03, 2025 by wenjh

    Merge nv main up to v2.10.0.dev0

    Signed-off-by: wenjh <wenjh@sugon.com>

parents 91670b05 5624dbb4
Changes: 298 (showing 20 changed files, with 1038 additions and 129 deletions: +1038 / −129)
examples/jax/mnist/test_single_gpu_mnist.py                      +7   −7
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py      +2   −2
examples/pytorch/fsdp/README.md                                  +1   −1
examples/pytorch/fsdp/fsdp.py                                    +4   −4
examples/pytorch/mnist/main.py                                   +3   −3
pyproject.toml                                                   +10  −0
qa/L0_jax_distributed_unittest/test.sh                           +4   −0
qa/L0_jax_unittest/test.sh                                       +1   −1
qa/L0_pytorch_debug_unittest/test.sh                             +10  −9
qa/L0_pytorch_unittest/test.sh                                   +1   −0
qa/L1_jax_distributed_unittest/test.sh                           +1   −1
qa/L1_pytorch_distributed_unittest/test.sh                       +3   −2
qa/L1_pytorch_onnx_unittest/test.sh                              +5   −3
setup.py                                                         +2   −1
tests/cpp/operator/CMakeLists.txt                                +7   −0
tests/cpp/operator/test_cast_float8blockwise.cu                  +12  −0
tests/cpp/operator/test_cast_mxfp8.cu                            +22  −20
tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu               +28  −26
tests/cpp/operator/test_cast_nvfp4_transpose.cu (new)            +741 −0
tests/cpp/test_common.cu                                         +174 −49
examples/jax/mnist/test_single_gpu_mnist.py

@@ -18,11 +18,11 @@ from flax.training import train_state
 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
-from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
+from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode

 DIR = str(Path(__file__).resolve().parents[1])
 sys.path.append(str(DIR))
-from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string
+from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string

 IMAGE_H = 28
 IMAGE_W = 28

@@ -189,12 +189,12 @@ def train_and_evaluate(args):
     label_shape = [args.batch_size]

     if args.use_fp8:
-        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+        fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
     else:
         fp8_recipe = None

-    with te.fp8_autocast(
-        enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+    with te.autocast(
+        enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
     ):
         cnn = Net(args.use_te)
         var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))

@@ -308,8 +308,8 @@ def mnist_parser(args):
 class TestMNIST(unittest.TestCase):
     """MNIST unittests"""

-    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
-    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
+    is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)

     @classmethod
     def setUpClass(cls):
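For reference, the JAX-side pattern after this rename looks like the sketch below: `is_fp8_available` becomes `is_scaling_mode_supported`, and `te.fp8_autocast(..., fp8_recipe=...)` becomes `te.autocast(..., recipe=...)`. This is a minimal sketch assembled from the hunks above, not a documented reference; `get_quantization_recipe_from_name_string` is the example-local helper from `encoder.common`, and the recipe name string is an assumed placeholder.

```python
# Minimal sketch of the renamed JAX API (assumptions: the recipe name string,
# and the example-local helper imported from encoder.common).
import transformer_engine.jax as te
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
from encoder.common import get_quantization_recipe_from_name_string

supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
if supported:
    recipe = get_quantization_recipe_from_name_string("DelayedScaling")  # assumed name
    with te.autocast(enabled=True, recipe=recipe,
                     mesh_resource=te.sharding.MeshResource()):
        pass  # build/init/run the model here, as in the example above
else:
    print(f"Skipping quantized run: {reason}")
```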
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py

@@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None):
     )
     parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
     parser.add_argument(
-        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
+        "--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
     )
     parser.add_argument(
         "--no-comm-overlap",

@@ -299,7 +299,7 @@ def _train(opts):
         dist_print("  |-- Forward pass", group=tp_group, debug=True)
         with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
+            with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world):
                 y = model(x)
                 if isinstance(y, tuple):
                     out, *_ = y
examples/pytorch/fsdp/README.md

@@ -49,5 +49,5 @@ $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsd
 # ...
 ```
-**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without Fp8 support
+**NOTE:** This example has `autocast()` enabled by default. To run on GPUs without Fp8 support
 (e.g.: A100), add the `--no-fp8` option to the commands shown above.
examples/pytorch/fsdp/fsdp.py

@@ -173,7 +173,7 @@ def parse_fsdp_args():
         "--no-fp8",
         action="store_true",
         default=False,
-        help="Disables the te.fp8_autocast() context.",
+        help="Disables the te.autocast() context.",
     )
     parser.add_argument(
         "--no-defer-init",

@@ -284,11 +284,11 @@ def train(opts):
         dtype=opts.dtype,
         device="cuda",
     )
-    # fp8_autocast needs to be given the FSDP process group for amax reductions
-    with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
+    # autocast needs to be given the FSDP process group for amax reductions
+    with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus):
         y = te_model(x)
         loss = y.sum()
-    # calculate gradient and take training step outside the fp8_autocast context
+    # calculate gradient and take training step outside the autocast context
     loss.backward()
     optim.step()
     optim.zero_grad(set_to_none=True)
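The PyTorch-side hunks in this merge all follow the same mechanical rename: `te.fp8_autocast(...)` becomes `te.autocast(...)`, `fp8_recipe=` becomes `recipe=`, and `fp8_group=` becomes `amax_reduction_group=`. A minimal before/after sketch using a plain `te.Linear` module (the model and recipe here are illustrative, not taken from the diff):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

model = te.Linear(128, 128).cuda()
x = torch.randn(32, 128, device="cuda")
recipe = DelayedScaling()

# Old API (removed in this merge):
#   with te.fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=group):
# New API, as used in the hunks above (amax_reduction_group is only needed
# for multi-GPU amax reductions; single-GPU runs can omit it):
with te.autocast(enabled=True, recipe=recipe):
    y = model(x)
y.sum().backward()  # gradients and optimizer step stay outside the context
```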
examples/pytorch/mnist/main.py

@@ -52,7 +52,7 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
-        with te.fp8_autocast(enabled=use_fp8):
+        with te.autocast(enabled=use_fp8):
             output = model(data)
             loss = F.nll_loss(output, target)
             loss.backward()

@@ -76,7 +76,7 @@ def calibrate(model, device, test_loader, fp8):
     with torch.no_grad():
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
-            with te.fp8_autocast(enabled=fp8, calibrating=True):
+            with te.autocast(enabled=fp8, calibrating=True):
                 output = model(data)

@@ -88,7 +88,7 @@ def test(model, device, test_loader, use_fp8):
     with torch.no_grad():
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
-            with te.fp8_autocast(enabled=use_fp8):
+            with te.autocast(enabled=use_fp8):
                 output = model(data)
                 test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
                 pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
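Note the `calibrating=True` path in `calibrate()` above: the context runs the model while collecting the quantization statistics (amax history) that a recipe needs, without committing to quantized execution, so statistics can be gathered before enabling the recipe for inference. A hedged sketch of that flow (the loop variables, `recipe`, and `sample` are assumptions; the `calibrating` semantics are as used in the example above):

```python
# Calibration sketch: observe statistics first, then run with the recipe enabled.
with torch.no_grad():
    for data, _ in test_loader:            # assumed DataLoader
        with te.autocast(enabled=False, calibrating=True):
            model(data.cuda())
with torch.no_grad(), te.autocast(enabled=True, recipe=recipe):  # assumed recipe
    output = model(sample.cuda())          # assumed sample batch
```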
pyproject.toml (new file, mode 100755)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

[build-system]
requires = [
    "setuptools>=61.0",
    "cmake>=3.21",
    "wheel",
    "pybind11[global]",
    "ninja",
    "nvidia-mathdx==25.1.1",
    "pip",
    "torch>=2.1",
    "jax>=0.5.0",
    "flax>=0.7.1",
]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
qa/L0_jax_distributed_unittest/test.sh

@@ -29,6 +29,10 @@ wait
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
 wait
+TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
+wait
+TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
+wait
 if [ $RET -ne 0 ]; then
     echo "Error: some sub-tests failed: $FAILED_CASES"
qa/L0_jax_unittest/test.sh

@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"

 # Test without custom calls
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
+NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"

 if [ $RET -ne 0 ]; then
     echo "Error: some sub-tests failed: $FAILED_CASES"
qa/L0_pytorch_debug_unittest/test.sh

@@ -7,6 +7,8 @@
 : ${TE_PATH:=/opt/transformerengine}
 : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
 : ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"

 # Config with the dummy feature which prevents nvinspect from being disabled.
 # Nvinspect will be disabled if no feature is active.

@@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect
 pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
 pip install pytest==8.2.1

-pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1

 # standard sanity and numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1

 exit $FAIL
qa/L0_pytorch_unittest/test.sh

@@ -31,6 +31,7 @@ ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0
 ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
qa/L1_jax_distributed_unittest/test.sh

@@ -9,4 +9,4 @@ set -xe
 mkdir -p "$XML_LOG_DIR"

 NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
-SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
+SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
qa/L1_pytorch_distributed_unittest/test.sh

@@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
 python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"

@@ -47,9 +48,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_
 : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
 : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}

-pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
 # standard numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
+NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"

 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
qa/L1_pytorch_onnx_unittest/test.sh

@@ -3,9 +3,11 @@
 # See LICENSE for license information.

-pip3 install onnxruntime==1.20.1
-pip3 install onnxruntime_extensions==0.13.0
+pip3 install onnxruntime
+pip3 install onnxruntime_extensions

 : ${TE_PATH:=/opt/transformerengine}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"

-python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
setup.py

@@ -23,6 +23,7 @@ from build_tools.utils import (
     cuda_version,
     get_frameworks,
     remove_dups,
+    min_python_version_str,
 )

 frameworks = get_frameworks()

@@ -211,7 +212,7 @@ if __name__ == "__main__":
         long_description_content_type="text/x-rst",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
-        python_requires=">=3.8",
+        python_requires=f">={min_python_version_str()}",
         classifiers=["Programming Language :: Python :: 3"],
         install_requires=install_requires,
         license_files=("LICENSE",),
tests/cpp/operator/CMakeLists.txt

@@ -66,6 +66,13 @@ else()
   add_executable(test_operator ${test_hip_sources})
 endif()

+# Add profiling and debug flags for CUDA compilation
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo")          # Generate line info for device code
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g")                 # Add debug symbols for host code
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage
+# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping
+
 # Find required packages
 find_package(OpenMP REQUIRED)

 if(USE_CUDA)
   list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
tests/cpp/operator/test_cast_float8blockwise.cu

@@ -529,6 +529,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
   q_opts.amax_epsilon = eps;
   q_opts.block_scaling_dim = 2u;

+  // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
+  // which requires using power of two scaling factors. Skip unsupported tests.
+  if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
+    GTEST_SKIP();
+  }
+
   if (colwise && matrix_size.size() < 2) {
     // test_common Tensor initialization code does not
     // handle this case.

@@ -580,6 +586,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
   q_opts.amax_epsilon = eps;
   q_opts.block_scaling_dim = 1u;

+  // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
+  // which requires using power of two scaling factors. Skip unsupported tests.
+  if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
+    GTEST_SKIP();
+  }
+
   if (colwise && matrix_size.size() < 2) {
     // test_common Tensor initialization code does not
     // handle this case.
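The skip added above encodes a hardware constraint: on Blackwell the FP8 block-scaling recipe is emulated via MXFP8, whose scale factors are exponent-only (E8M0-style), so non-power-of-two scales cannot be represented and those parameter combinations are skipped. A toy illustration of what restricting scales to powers of two means (illustrative only, not the kernel's actual rounding mode):

```python
import math

def to_pow2_scale(scale: float) -> float:
    """Keep only the exponent of a scale factor, roughly what an
    exponent-only (E8M0-style) scale can represent."""
    if scale <= 0.0 or math.isinf(scale):
        return 0.0
    return 2.0 ** math.floor(math.log2(scale))

assert to_pow2_scale(0.3) == 0.25
assert to_pow2_scale(5.0) == 4.0
```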
tests/cpp/operator/test_cast_mxfp8.cu

@@ -81,6 +81,7 @@ void compute_ref(const ProcessingMethod processing_method,
   // Cache computations
   for (size_t i = i_min; i < i_max; ++i) {
     for (size_t j = j_min; j < j_max; ++j) {
       const size_t idx = i * cols + j;
       const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);

@@ -310,7 +311,8 @@ void performTest_x1(const ProcessingMethod processing_method,
   const double rel_tolerable_mismatches_limit = 0.0;

   size_t mismatches_scales = 0;
-  compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
+  compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
                           unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
                           mismatches_scales, scale_diff_abs_tolerance,

@@ -481,7 +483,7 @@ void performTest_x2(const ProcessingMethod processing_method,
   const double rel_tolerable_mismatches_limit = 0.0;

   size_t mismatches_scales_rowwise = 0;
-  compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
+  compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
                           ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
                           unpadded_blocks_X_rowwise, scales_stride_rowwise,
                           mismatches_scales_rowwise,

@@ -490,7 +492,7 @@ void performTest_x2(const ProcessingMethod processing_method,
                           rel_tolerable_mismatches_limit);
   size_t mismatches_scales_colwise = 0;
-  compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
+  compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
                           ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
                           unpadded_blocks_X_colwise, scales_stride_colwise,
                           mismatches_scales_colwise,
tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu

@@ -267,19 +267,20 @@ void performTest_x1(const size_t rows,
       ? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
       : output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
   if (rowwise) {
-    compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
+    compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
                             unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
                             mismatches_scales,
                             scale_diff_abs_tolerance, abs_tolerable_mismatches_limit,
                             rel_tolerable_mismatches_limit);
   } else {
-    compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
+    compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
                             unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
                             mismatches_scales,
                             scale_diff_abs_tolerance, abs_tolerable_mismatches_limit,
                             rel_tolerable_mismatches_limit);
   }
   const size_t mismatches_elts = 32 * mismatches_scales;

@@ -378,7 +379,7 @@ void performTest_x2(const size_t rows,
   const double rel_tolerable_mismatches_limit = 1.0e-4;

   size_t mismatches_scales_rowwise = 0;
-  compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
+  compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
                           ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
                           unpadded_blocks_X_rowwise, scales_stride_rowwise,
                           mismatches_scales_rowwise,

@@ -386,7 +387,7 @@ void performTest_x2(const size_t rows,
                           abs_tolerable_mismatches_limit,
                           rel_tolerable_mismatches_limit);
   size_t mismatches_scales_colwise = 0;
-  compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
+  compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
                           ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
                           unpadded_blocks_X_colwise, scales_stride_colwise,
                           mismatches_scales_colwise,

@@ -394,6 +395,7 @@ void performTest_x2(const size_t rows,
                           abs_tolerable_mismatches_limit,
                           rel_tolerable_mismatches_limit);

   const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
   const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
tests/cpp/operator/test_cast_nvfp4_transpose.cu (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>

#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"

#include <fstream>

using namespace transformer_engine;
using namespace test;

namespace {

enum ActivationType {
  Identity,
  GeLU,
  SiLU,
  ReLU,
  QGeLU,
  SReLU
};

double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) {
  const __half2_raw raw_truncated_to_fp4e2m1_pair =
      __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1);
  const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair);
  const double truncated_to_fp4e2m1_x = static_cast<double>(truncated_to_fp4e2m1_pair.x);
  const double truncated_to_fp4e2m1_y = static_cast<double>(truncated_to_fp4e2m1_pair.y);
  return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y};
}

template <typename InputType>
std::vector<InputType> create_transpose(const InputType* const input,
                                        const size_t rows, size_t cols) {
  std::vector<InputType> input_t(cols * rows);
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      const size_t idx = i * cols + j;
      const size_t idx_t = j * rows + i;
      input_t[idx_t] = input[idx];
    }
  }
  return input_t;
}

// Compute the global encode scale factor for a given global amax
float compute_global_encode_scaling_factor_FP4(const float global_amax) {
  constexpr float fp8_max = 448.0f;
  constexpr float fp4_max = 6.0f;
  float global_encode_scale = fp8_max * fp4_max / global_amax;
  // If scale is infinity, return max value of float32
  global_encode_scale = fminf(global_encode_scale, Numeric_Traits<float>::maxNorm);
  // If global amax is 0 or infinity, return 1
  if (global_amax == 0.0f || global_encode_scale == 0.0f) {
    return 1.0f;
  }
  return global_encode_scale;
}

// 1D Scaling: Original implementation with 1x16 blocks
template <typename InputType>
void quantize_nvfp4_1d(float (*OP)(const float),
                       const InputType* const input,
                       fp4e2m1x2* const output,
                       fp8e4m3* const scales,
                       const size_t rows,
                       const size_t cols,
                       const size_t scales_stride,
                       const float global_amax) {
  // Compute a global encoding/decoding scaling factor for all S_dec_b
  const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);

  constexpr size_t block_size_X = 16;
  const size_t blocks_X = divide_round_up(cols, block_size_X);

  std::array<float, block_size_X> cache_buffer;
  for (size_t i = 0; i < block_size_X; ++i) {
    cache_buffer[i] = 0.0f;
  }

  for (size_t i = 0; i < rows; ++i) {
    for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
      const size_t j_min = block_X * block_size_X;
      const size_t j_max = j_min + block_size_X;

      // Find block amax
      float block_amax = 0.0f;
      for (size_t j = j_min; j < j_max; ++j) {
        const size_t idx = i * cols + j;
        const size_t cache_idx = j - j_min;
        const float input_elt = static_cast<float>(input[idx]);
        const float act_elt = OP(input_elt);
        // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
        const float elt = static_cast<float>(static_cast<InputType>(act_elt));
        cache_buffer[cache_idx] = elt;
        block_amax = std::max(block_amax, std::abs(elt));
      }

      // Compute per-block (E4M3) encoding/decoding scaling factor
      const float S_dec_b = block_amax / 6.0f;
      // Scale & store per-block decoding scaling factor
      const float S_dec_b_fp8 = S_dec_b * S_enc;
      // Compute "correct" per-block encoding scaling factor
      const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;

      const size_t scale_idx = i * scales_stride + block_X;
      scales[scale_idx] = static_cast<fp8e4m3>(S_dec_b_fp8);
      const float scale_reciprocal = S_enc_b_fp8;

      for (size_t j = j_min; j < j_max; j += 2) {
        const int idx_pair = (i * cols + j) / 2;
        const int cache_idx_x = j - j_min;
        const int cache_idx_y = cache_idx_x + 1;
        const float cached_x = cache_buffer[cache_idx_x];
        const float cached_y = cache_buffer[cache_idx_y];
        const float scaled_elt_x = cached_x * scale_reciprocal;
        const float scaled_elt_y = cached_y * scale_reciprocal;
        const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y};
        fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
        output[idx_pair] = casted_to_e2m1_pair;
      }
    }
  }
}

// Compute 2D mathematical scaling factors (8x8 for 128x128 input)
template <typename InputType>
void compute_2d_mathematical_scales(float (*OP)(const float),
                                    const InputType* const input,
                                    const size_t rows,
                                    const size_t cols,
                                    const float global_amax,
                                    std::vector<std::vector<fp8e4m3>>& math_scales) {
  const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);

  constexpr size_t block_size_Y = 16;
  constexpr size_t block_size_X = 16;
  const size_t blocks_Y = divide_round_up(rows, block_size_Y);
  const size_t blocks_X = divide_round_up(cols, block_size_X);

  math_scales.resize(blocks_Y, std::vector<fp8e4m3>(blocks_X));

  for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) {
    for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
      const size_t i_min = block_Y * block_size_Y;
      const size_t i_max = std::min(i_min + block_size_Y, rows);
      const size_t j_min = block_X * block_size_X;
      const size_t j_max = std::min(j_min + block_size_X, cols);

      // Find 2D block amax over entire 16x16 region
      float block_amax = 0.0f;
      for (size_t i = i_min; i < i_max; ++i) {
        for (size_t j = j_min; j < j_max; ++j) {
          const size_t idx = i * cols + j;
          const float input_elt = static_cast<float>(input[idx]);
          const float act_elt = OP(input_elt);
          const float elt = static_cast<float>(static_cast<InputType>(act_elt));
          block_amax = std::max(block_amax, std::abs(elt));
        }
      }

      // Compute E4M3 scaling factor for this 16x16 block
      const float S_dec_b = block_amax / 6.0f;
      const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
      math_scales[block_Y][block_X] = S_dec_b_fp8;
    }
  }
}

// 2D Scaling: NEW implementation with proper replication
template <typename InputType>
void quantize_nvfp4_2d(float (*OP)(const float),
                       const InputType* const input,
                       fp4e2m1x2* const output,
                       fp8e4m3* const scales,
                       const size_t rows,
                       const size_t cols,
                       const size_t scales_stride,
                       const float global_amax) {
  // Step 1: Compute mathematical 8x8 scaling factors
  std::vector<std::vector<fp8e4m3>> math_scales;
  compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);

  const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
  constexpr size_t block_size_Y = 16;
  constexpr size_t block_size_X = 16;
  const size_t blocks_Y = divide_round_up(rows, block_size_Y);
  const size_t blocks_X = divide_round_up(cols, block_size_X);

  // Step 2: Replicate scaling factors row-wise (128x8 storage) - only if scales is not nullptr
  if (scales != nullptr) {
    // Each of the 128 rows gets scaling factors from its corresponding 16x16 block
    for (size_t i = 0; i < rows; ++i) {
      const size_t block_Y = i / block_size_Y;
      for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
        const size_t scale_idx = i * scales_stride + block_X;
        scales[scale_idx] = math_scales[block_Y][block_X];
      }
    }
  }

  // Step 3: Apply quantization using the mathematical scaling factors
  std::array<std::array<float, block_size_X>, block_size_Y> cache_buffer;

  for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) {
    for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
      const size_t i_min = block_Y * block_size_Y;
      const size_t i_max = std::min(i_min + block_size_Y, rows);
      const size_t j_min = block_X * block_size_X;
      const size_t j_max = std::min(j_min + block_size_X, cols);

      // Get the scaling factor for this block
      const float S_dec_b_fp8 = static_cast<float>(math_scales[block_Y][block_X]);
      const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
      const float scale_reciprocal = S_enc_b_fp8;

      // Process and cache data for this 16x16 block
      for (size_t i = i_min; i < i_max; ++i) {
        for (size_t j = j_min; j < j_max; ++j) {
          const size_t idx = i * cols + j;
          const size_t cache_idx_y = i - i_min;
          const size_t cache_idx_x = j - j_min;
          const float input_elt = static_cast<float>(input[idx]);
          const float act_elt = OP(input_elt);
          const float elt = static_cast<float>(static_cast<InputType>(act_elt));
          cache_buffer[cache_idx_y][cache_idx_x] = elt;
        }
      }

      // Apply scaling to all elements in this 16x16 block
      for (size_t i = i_min; i < i_max; ++i) {
        for (size_t j = j_min; j < j_max; j += 2) {
          const int idx_pair = (i * cols + j) / 2;
          const size_t cache_idx_y = i - i_min;
          const size_t cache_idx_x1 = j - j_min;
          const size_t cache_idx_x2 = std::min(cache_idx_x1 + 1, block_size_X - 1);
          const float cached_x = cache_buffer[cache_idx_y][cache_idx_x1];
          const float cached_y = ((j + 1) < j_max && cache_idx_x2 < block_size_X)
                                     ? cache_buffer[cache_idx_y][cache_idx_x2]
                                     : 0.0f;
          const float scaled_elt_x = cached_x * scale_reciprocal;
          const float scaled_elt_y = cached_y * scale_reciprocal;
          const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y};
          fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
          output[idx_pair] = casted_to_e2m1_pair;
        }
      }
    }
  }
}

// Wrapper function that calls the appropriate implementation based on the 2D flag
template <typename InputType>
void quantize_nvfp4(float (*OP)(const float),
                    const InputType* const input,
                    fp4e2m1x2* const output,
                    fp8e4m3* const scales,
                    const size_t rows,
                    const size_t cols,
                    const size_t scales_stride,
                    const float global_amax,
                    const bool use_2d_quantization = false) {
  if (use_2d_quantization) {
    quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
  } else {
    quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
  }
}

template <typename InputType>
void compute_ref(float (*OP)(const float),
                 const InputType* input,
                 fp4e2m1x2* output,
                 fp4e2m1x2* output_t,
                 fp8e4m3* scales,
                 fp8e4m3* scales_t,
                 const float global_amax,
                 const size_t rows,
                 const size_t cols,
                 const size_t scales_stride,
                 const size_t scales_stride_t,
                 const bool use_2d_quantization = false) {
  std::vector<InputType> input_t = create_transpose(input, rows, cols);

  if (use_2d_quantization) {
    // Step 1: Compute mathematical 8x8 scaling factors
    std::vector<std::vector<fp8e4m3>> math_scales;
    compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);

    constexpr size_t block_size_Y = 16;
    constexpr size_t block_size_X = 16;
    const size_t blocks_Y = divide_round_up(rows, block_size_Y);
    const size_t blocks_X = divide_round_up(cols, block_size_X);

    // Step 2: Generate scales (128x8) by replicating row-wise
    for (size_t i = 0; i < rows; ++i) {
      const size_t block_Y = i / block_size_Y;
      for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
        const size_t scale_idx = i * scales_stride + block_X;
        scales[scale_idx] = math_scales[block_Y][block_X];
      }
    }

    // Step 3: Generate scales_t (128x8) with proper transposed block mapping
    for (size_t i = 0; i < cols; ++i) {  // cols = 128, which becomes rows of transposed data
      // i was column index in original, so maps to block_X
      const size_t block_X_orig = i / block_size_X;
      for (size_t block_Y_new = 0; block_Y_new < blocks_Y; ++block_Y_new) {
        // block in transposed coordinate
        const size_t scale_idx = i * scales_stride_t + block_Y_new;
        scales_t[scale_idx] = math_scales[block_Y_new][block_X_orig];
      }
    }

    // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d
    // (This part processes the actual FP4 data using the mathematical scaling factors)
    quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols,
                      scales_stride, global_amax);          // scales already filled
    quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows,
                      scales_stride_t, global_amax);        // scales_t already filled
  } else {
    quantize_nvfp4(OP, input, output, scales, rows, cols,
                   scales_stride, global_amax, use_2d_quantization);
    quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows,
                   scales_stride_t, global_amax, use_2d_quantization);
  }
}

void compare_nvfp4_tensors(const std::string& name,
                           const fp4e2m1* test_data,
                           const fp4e2m1* ref_data,
                           const int rows,
                           const int cols,
                           double atol = 1e-5,
                           double rtol = 1e-8) {
  std::vector<std::string> mismatch_messages;
  size_t total_mismatches = 0;

  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; j += 2) {
      const int idx = i * cols + j;
      double2 test_data_pair =
          cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[idx / 2]));
      double2 ref_data_pair =
          cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[idx / 2]));
      for (int k = 0; k < 2; ++k) {
        const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
        const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
        bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
        /* For Float32 the floating point comparison is enough to error out */
        bool assertion = false;
        if (mismatch && !assertion) {
          /* Check if it is just a failure of round to nearest choosing different
             side of the real value */
          const double mean = (t + r) / 2;
          const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
          const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
          const double cast_mean_p = static_cast<double>(static_cast<fp4e2m1>(mean_p));
          const double cast_mean_m = static_cast<double>(static_cast<fp4e2m1>(mean_m));
          assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r));
        }
        if (assertion) {
          total_mismatches++;
          std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " +
                            std::to_string(t) + " vs " + std::to_string(r) +
                            " (abs_diff: " + std::to_string(fabs(t - r)) +
                            ", rel_diff: " +
                            std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
          mismatch_messages.push_back(msg);
          // Optional: limit number of detailed messages to avoid overwhelming output
          if (mismatch_messages.size() <= 100) {
            std::cout << "Error in tensor " << name << ": " << msg << std::endl;
          }
        }
      }
    }
  }

  // Always report summary - either success or failure
  std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl;
  std::cout << "Total elements checked: " << (rows * cols) << std::endl;
  if (total_mismatches > 0) {
    std::cout << "STATUS: FAILED for output" << std::endl;
    std::cout << "Total mismatches found: " << total_mismatches << std::endl;
    std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols)
              << "%" << std::endl;
    if (mismatch_messages.size() > 100) {
      std::cout << "... and " << (mismatch_messages.size() - 100)
                << " more mismatches (showing first 100)" << std::endl;
    }
    std::cout << "============================" << std::endl;
    GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name;
  } else {
    std::cout << "STATUS: PASSED for output" << std::endl;
    std::cout << "All elements match within tolerance!" << std::endl;
    std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl;
    std::cout << "============================" << std::endl;
  }
}

// Optional: Function to dump tensor data to files for detailed analysis
void dump_nvfp4_tensor_data(const std::string& prefix,
                            const fp4e2m1* test_data,
                            const fp4e2m1* ref_data,
                            const int rows,
                            const int cols) {
  std::string test_file = prefix + "_test.txt";
  std::string ref_file = prefix + "_ref.txt";
  std::string diff_file = prefix + "_diff.txt";

  std::ofstream test_out(test_file);
  std::ofstream ref_out(ref_file);
  std::ofstream diff_out(diff_file);

  if (test_out.is_open() && ref_out.is_open() && diff_out.is_open()) {
    for (int i = 0; i < rows; ++i) {
      for (int j = 0; j < cols; j += 2) {
        const int idx = i * cols + j;
        double2 test_data_pair =
            cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[idx / 2]));
        double2 ref_data_pair =
            cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[idx / 2]));
        for (int k = 0; k < 2; ++k) {
          const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
          const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
          const int pos = idx + k;
          test_out << "pos[" << pos << "] = " << t << std::endl;
          ref_out << "pos[" << pos << "] = " << r << std::endl;
          diff_out << "pos[" << pos << "] test=" << t << " ref=" << r
                   << " abs_diff=" << fabs(t - r)
                   << " rel_diff=" << (r == 0 ? 0.0 : fabs((t - r) / r)) << std::endl;
        }
      }
    }
    std::cout << "DEBUG: Dumped tensor data to files: " << test_file << ", "
              << ref_file << ", " << diff_file << std::endl;
  } else {
    std::cout << "WARNING: Could not open files for tensor data dump" << std::endl;
  }
}

void print_detailed_tensor_comparison(const std::string& name,
                                      const fp4e2m1* test_data,
                                      const fp4e2m1* ref_data,
                                      const int rows,
                                      const int cols) {
  printf("\n=== DETAILED COMPARISON for %s (%d×%d = %d elements) ===\n",
         name.c_str(), rows, cols, rows * cols);
  const int total_elements = rows * cols;
  const int check_count = 128;

  printf("--- FIRST %d ELEMENTS ---\n", check_count);
  printf("Index | Test_Value    | Ref_Value     | Match\n");
  printf("------|---------------|---------------|-------\n");
  for (int i = 0; i < std::min(check_count, total_elements); ++i) {
    double2 test_pair =
        cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[i / 2]));
    double2 ref_pair =
        cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[i / 2]));
    double t = (i % 2 == 0) ? test_pair.x : test_pair.y;
    double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y;
    bool match = (fabs(t - r) < 1e-6);
    printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗");
  }

  if (total_elements > 2 * check_count) {
    printf("\n--- LAST %d ELEMENTS ---\n", check_count);
    printf("Index | Test_Value    | Ref_Value     | Match\n");
    printf("------|---------------|---------------|-------\n");
    for (int i = total_elements - check_count; i < total_elements; ++i) {
      double2 test_pair =
          cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[i / 2]));
      double2 ref_pair =
          cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[i / 2]));
      double t = (i % 2 == 0) ? test_pair.x : test_pair.y;
      double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y;
      bool match = (fabs(t - r) < 1e-6);
      printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗");
    }
  }
  printf("==================================\n");
}

void compareResults_nvfp4(const Tensor& test, const void* ref, const void* ref_t,
                          const int rows, const int cols,
                          double atol = 1e-5, double rtol = 1e-8,
                          bool if_on_gpus = true, bool dump_data = false) {
  if (if_on_gpus) test.to_cpu();
  const fp4e2m1* test_data = test.rowwise_cpu_dptr<fp4e2m1>();
  const fp4e2m1* test_data_t = test.columnwise_cpu_dptr<fp4e2m1>();
  const fp4e2m1* ref_data = reinterpret_cast<const fp4e2m1*>(ref);
  const fp4e2m1* ref_data_t = reinterpret_cast<const fp4e2m1*>(ref_t);

  // Print detailed element-by-element comparison
  // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols);
  // print_detailed_tensor_comparison("output_t", test_data_t, ref_data_t, cols, rows);

  // Optionally dump tensor data to files for detailed analysis
  if (dump_data) {
    dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols);
    dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows);
  }

  compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol);
  compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
}

template <typename InputType>
void performTest(float (*OP)(const float), const std::vector<size_t>& shape) {
  using namespace test;

  DType itype = TypeInfo<InputType>::dtype;
  DType otype = DType::kFloat4E2M1;
  const size_t rows = first_dimension(shape);
  const size_t cols = last_dimension(shape);

  // Use get_scale_tensor_dims for NVFP4 scale tensor dimensions
  // Now that CheckScaleTensorShape is fixed, this should work correctly
  const std::array<size_t, 4> scale_dims = get_scale_tensor_dims(rows, cols, 1, 16);
  const std::array<size_t, 4> scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16);

  const size_t unpadded_blocks_Y = scale_dims[0];
  const size_t unpadded_blocks_X = scale_dims[1];
  const size_t blocks_Y = scale_dims[2];
  const size_t blocks_X = scale_dims[3];
  const size_t scales_stride = blocks_X;

  const size_t unpadded_blocks_Y_t = scale_dims_t[0];
  const size_t unpadded_blocks_X_t = scale_dims_t[1];
  const size_t blocks_Y_t = scale_dims_t[2];
  const size_t blocks_X_t = scale_dims_t[3];
  const size_t scales_stride_t = blocks_X_t;

  Tensor input("input", shape, itype);
  Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING);

  std::unique_ptr<fp4e2m1x2[]> ref_output = std::make_unique<fp4e2m1x2[]>(rows * (cols / 2));
  std::unique_ptr<fp4e2m1x2[]> ref_output_t = std::make_unique<fp4e2m1x2[]>(cols * (rows / 2));
  std::unique_ptr<fp8e4m3[]> ref_scales = std::make_unique<fp8e4m3[]>(blocks_Y * blocks_X);
  std::unique_ptr<fp8e4m3[]> ref_scales_t = std::make_unique<fp8e4m3[]>(blocks_Y_t * blocks_X_t);

  fillCase<fp32>(&input, InputsFillCase::uniform);

  // Find global amax
  float amax = 0.0f;
  const InputType* input_dptr = input.rowwise_cpu_dptr<InputType>();
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      const size_t idx = i * cols + j;
      amax = fmaxf(amax, static_cast<float>(input_dptr[idx]));
    }
  }
  // Set 2nd stage NVFP4 scaling factor
  output.set_scale(amax);

  bool use_2d_quantization = false;

  compute_ref<InputType>(OP, input.rowwise_cpu_dptr<InputType>(),
                         ref_output.get(), ref_output_t.get(),
                         ref_scales.get(), ref_scales_t.get(),
                         output.scale(), rows, cols,
                         scales_stride, scales_stride_t, use_2d_quantization);

  QuantizationConfigWrapper quant_config;

  // Initialize stochastic rounding
  Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
  rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123;  // rng_seed
  rng_state.rowwise_cpu_dptr<int64_t>()[1] = 321;  // rng_sequence
  rng_state.from_cpu();
  quant_config.set_stochastic_rounding(false);
  quant_config.set_rng_state(rng_state.data());

  // Set 2D quantization based on compile-time flag
  quant_config.set_nvfp4_2d_quantization(use_2d_quantization);

  // Call appropriate function based on operation type
  // Activation functions take 3 parameters (input, output, stream)
  // nvte_quantize_v2 takes 4 parameters (input, output, quant_config, stream)
  if (OP == &gelu) {
    nvte_gelu(input.data(), output.data(), 0);
  } else if (OP == &silu) {
    nvte_silu(input.data(), output.data(), 0);
  } else if (OP == &relu) {
    nvte_relu(input.data(), output.data(), 0);
  } else if (OP == &qgelu) {
    nvte_qgelu(input.data(), output.data(), 0);
  } else if (OP == &srelu) {
    nvte_srelu(input.data(), output.data(), 0);
  } else {
    nvte_quantize_v2(input.data(), output.data(), quant_config, 0);
  }

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("DEBUG: CUDA error detected: %s\n", cudaGetErrorString(err));
  }
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  const double atol = 0.05;
  const double rtol = 0.1;
  // Set dump_data=true to enable dumping tensor data to files for analysis
  compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(),
                       rows, cols, atol, rtol, true, false);

  const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr<fp8e4m3>();
  const fp8e4m3* ref_scales_ptr = ref_scales.get();
  const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr<fp8e4m3>();
  const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get();

  size_t scale_mismatches_num = 0;
  compare_scaling_factors<fp8e4m3>("scales", output.rowwise_cpu_scale_inv_ptr<fp8e4m3>(),
                                   ref_scales.get(), unpadded_blocks_Y, unpadded_blocks_X,
                                   scales_stride, scale_mismatches_num);
  compare_scaling_factors<fp8e4m3>("scales_t", output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(),
                                   ref_scales_t.get(), unpadded_blocks_Y_t, unpadded_blocks_X_t,
                                   scales_stride_t, scale_mismatches_num);
}

std::vector<std::vector<size_t>> tensor_dims = {
    {32, 32},
    {32, 64},
    {64, 32},
    {64, 96},
    {128, 128},
    {256, 256},
    {512, 512},
    {1024, 1024},
    {2048, 2048},
    {128, 256},
    {8192, 128},
    {2048, 160},
    {8, 32, 1024},
    {16, 8, 4, 512},
    {1024, 16384},
    {4096, 13312},
};

// Activation operations tested alongside the cast
std::vector<ActivationType> Activation_types = {
    ActivationType::Identity,
    ActivationType::GeLU,
    ActivationType::SiLU,
    ActivationType::ReLU,
    ActivationType::QGeLU,
    ActivationType::SReLU,
};

}  // namespace

class FusedCastTransposeNVFP4TestSuite
    : public ::testing::TestWithParam<
          std::tuple<ActivationType, std::vector<size_t>, transformer_engine::DType>> {};

TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
  // Skip tests for pre-Blackwell architectures
  if (getDeviceComputeCapability() < blackwellComputeCapability) {
    GTEST_SKIP();
  }

  using namespace transformer_engine;
  using namespace test;

  const ActivationType Act_type = std::get<0>(GetParam());
  const auto tensor_dims = std::get<1>(GetParam());
  const DType input_type = std::get<2>(GetParam());

  // Skip tests if the input tensor is 1D
  if (tensor_dims.size() < 2) {
    GTEST_SKIP();
  }

  // Forward activations
  auto OP = &identity;
  switch (Act_type) {
    case ActivationType::GeLU:  OP = &gelu;  break;
    case ActivationType::SiLU:  OP = &silu;  break;
    case ActivationType::ReLU:  OP = &relu;  break;
    case ActivationType::QGeLU: OP = &qgelu; break;
    case ActivationType::SReLU: OP = &srelu; break;
  }

  TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(
      input_type, InputType, performTest<InputType>(OP, tensor_dims););
}

std::string to_string(const ActivationType Act_type) {
  switch (Act_type) {
    case ActivationType::Identity: return "CAST_ONLY";
    case ActivationType::GeLU:     return "GeLU";
    case ActivationType::SiLU:     return "SiLU";
    case ActivationType::ReLU:     return "ReLU";
    case ActivationType::QGeLU:    return "QGeLU";
    case ActivationType::SReLU:    return "SReLU";
    default: return "";
  }
}

INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    FusedCastTransposeNVFP4TestSuite,
    ::testing::Combine(
        ::testing::ValuesIn(Activation_types),
        ::testing::ValuesIn(tensor_dims),
        ::testing::Values(DType::kBFloat16)),
    [](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
      std::string name = to_string(std::get<0>(info.param));
      const auto& shape = std::get<1>(info.param);
      for (const auto& s : shape) {
        name += "X" + std::to_string(s);
      }
      name += "X" + test::typeName(std::get<2>(info.param));
      return name;
    });
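The reference code above implements NVFP4's two-level scaling: a global FP32 encode factor S_enc = fp8_max * fp4_max / global_amax, plus per-1x16-block decode factors S_dec_b = block_amax / 6 stored in E4M3, with elements mapped into the E2M1 range by S_enc / S_dec_b_fp8. A NumPy sketch of the same 1D math on a single row may make the data flow easier to follow; the E4M3 cast of the block scale is approximated as identity and the round-to-nearest onto the E2M1 grid ignores tie-breaking, so treat it as illustrative rather than bit-exact:

```python
import numpy as np

FP8_MAX, FP4_MAX = 448.0, 6.0
FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # |E2M1| magnitudes

def quantize_row_nvfp4_1d(row, global_amax, block=16):
    """Mirror of quantize_nvfp4_1d above for one row (illustrative)."""
    s_enc = FP8_MAX * FP4_MAX / global_amax if global_amax > 0.0 else 1.0
    out = np.empty_like(row, dtype=np.float64)
    scales = []
    for j in range(0, row.size, block):
        blk = row[j:j + block].astype(np.float64)
        s_dec_b_fp8 = (np.max(np.abs(blk)) / FP4_MAX) * s_enc  # stored as E4M3 in the test
        scales.append(s_dec_b_fp8)
        s = 0.0 if s_dec_b_fp8 == 0.0 else s_enc / s_dec_b_fp8
        scaled = blk * s                                       # mapped into [-6, 6]
        idx = np.argmin(np.abs(np.abs(scaled)[:, None] - FP4_GRID[None, :]), axis=1)
        out[j:j + block] = np.sign(scaled) * FP4_GRID[idx]     # "cast" to E2M1 values
    return out, np.array(scales)

vals, scales = quantize_row_nvfp4_1d(np.random.randn(64).astype(np.float32), global_amax=3.0)
```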
tests/cpp/test_common.cu
View file @
063ef88d
...
...
@@ -111,6 +111,10 @@ size_t DIVUP(const size_t &x, const size_t &y){
return
(((
x
)
+
((
y
)
-
1
))
/
(
y
));
}
size_t
DIVUP_TO_MULTIPLE
(
const
size_t
&
x
,
const
size_t
&
y
){
return
DIVUP
(
x
,
y
)
*
y
;
}
struct
scale_inv_meta
{
std
::
vector
<
size_t
>
shape
;
DType
type
;
...
...
@@ -147,21 +151,71 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
  scale_inv_meta ret_rowwise, ret_colwise;

  auto block_alignment = std::vector<size_t>{128ul, 4ul};
  {
  auto alignment = block_alignment[0];
  auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
  alignment = block_alignment[1];
  auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
  ret_rowwise.shape = {scale_dim_0, scale_dim_1};
  const size_t block_size_X_rowwise = 32;
  size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
  size_t scale_dim_X_rowwise =
      DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
  ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
  const size_t block_size_Y_colwise = 32;
  size_t scale_dim_Y_colwise =
      DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
  size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
  ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
  ret_rowwise.type = DType::kFloat8E8M0;
  ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
  ret_colwise.type = DType::kFloat8E8M0;
  ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
  return {ret_rowwise, ret_colwise};
  }
  {
  auto alignment = block_alignment[1];
  auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
  alignment = block_alignment[0];
  auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
  ret_colwise.shape = {scale_dim_0, scale_dim_1};
  if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);
    NVTE_CHECK(last_dim % 32 == 0);
    NVTE_CHECK(first_dim % 32 == 0);
    scale_inv_meta ret_rowwise, ret_colwise;
    size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y, scale_dim_X};
    size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
    ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};
    ret_rowwise.type = DType::kFloat8E4M3;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
    ret_colwise.type = DType::kFloat8E4M3;
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);
    scale_inv_meta ret_rowwise, ret_colwise;
    const size_t block_size_X_rowwise = 32;
    size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_rowwise =
        DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
    const size_t block_size_Y_colwise = 32;
    size_t scale_dim_Y_colwise =
        DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
    size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
    ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
    ret_rowwise.type = DType::kFloat8E8M0;
    ret_colwise.type = DType::kFloat8E8M0;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
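To make the shape arithmetic concrete, here is a sketch that evaluates the MXFP8 and NVFP4 row-wise/column-wise branches for a 128 x 96 input. The alignment constants are assumptions (scale_tensor_alignment_Y_rowwise = 128, scale_tensor_alignment_X_rowwise = 4, and the mirrored values for the column-wise side); they are not read from this diff.

// Sketch only: reproduces the MXFP8 (32-element blocks) and NVFP4 (16-element
// blocks) scale-inv shape arithmetic above. Alignment values are assumptions.
#include <cstddef>
#include <iostream>

static std::size_t DIVUP(std::size_t x, std::size_t y) { return (x + y - 1) / y; }
static std::size_t DIVUP_TO_MULTIPLE(std::size_t x, std::size_t y) { return DIVUP(x, y) * y; }

int main() {
  const std::size_t align_Y_rowwise = 128, align_X_rowwise = 4;    // assumed
  const std::size_t align_Y_colwise = 4,   align_X_colwise = 128;  // assumed
  const std::size_t first_dim = 128, last_dim = 96;

  // MXFP8: one E8M0 scale per 32-element block along the scaled dimension.
  std::size_t mx_row_Y = DIVUP_TO_MULTIPLE(first_dim, align_Y_rowwise);             // 128
  std::size_t mx_row_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 32), align_X_rowwise);   // 4
  std::size_t mx_col_Y = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 32), align_Y_colwise);  // 4
  std::size_t mx_col_X = DIVUP_TO_MULTIPLE(last_dim, align_X_colwise);              // 128
  std::cout << "MXFP8 rowwise " << mx_row_Y << "x" << mx_row_X
            << ", colwise " << mx_col_Y << "x" << mx_col_X << "\n";

  // NVFP4: one E4M3 scale per 16-element block; the colwise side swaps the dims.
  std::size_t fp4_row_Y = DIVUP_TO_MULTIPLE(first_dim, align_Y_rowwise);            // 128
  std::size_t fp4_row_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16), align_X_rowwise);  // 8
  std::cout << "NVFP4 rowwise " << fp4_row_Y << "x" << fp4_row_X << "\n";
  return 0;
}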
...
...
@@ -254,14 +308,15 @@ Tensor::Tensor(const std::string& name,
  NVTEShape columnwise_shape = {};
  std::vector<size_t> columnwise_shape_vec;
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D ||
      scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // Transpose when tensor scaling
    columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
    for (size_t i = 0; i < shape.ndim - 1; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  } else {
    // Same shape for MX
    // Same shape for MX and NVFP4
    for (size_t i = 0; i < shape.ndim; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
...
...
@@ -287,10 +342,13 @@ Tensor::Tensor(const std::string& name,
      std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
    }
  }
  tensor_.set_rowwise_data(dptr_rowwise, type, shape);
  tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
  if (isFp8Type(type)) {
  const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
  const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
  tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
  tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
  if (isFp8Type(type) || isFp4Type(type)) {
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
      cudaMemset(amax, 0, sizeof(float));
...
...
@@ -314,8 +372,14 @@ Tensor::Tensor(const std::string& name,
      std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
    }
  } else {
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
    if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
      // Used for NVFP4 second stage scaling
      cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
      cudaMemset(scale, 0, sizeof(float));
      scale_cpu_data_ = std::make_shared<float>(0);
      tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
    }
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
    auto rowwise_scale_size = rowwise_scale_meta.bytes();
    auto columnwise_scale_size = colwise_scale_meta.bytes();
    auto scale_shape = rowwise_scale_meta.shape;
...
...
@@ -350,13 +414,16 @@ void Tensor::to_cpu() const {
               cudaMemcpyDeviceToHost);
  }
  if (columnwise_) {
    const DType colwise_type = tensor_.dtype();
    const size_t colwise_size = bytes(s, colwise_type);
    cudaMemcpy(cpu_data_columnwise_.get(), tensor_.get_columnwise_data().data_ptr,
               size,
               colwise_size,
               cudaMemcpyDeviceToHost);
  }
  if (isFp8Type(dtype())) {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) {
      if (tensor_.amax() != nullptr){
        cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(),
...
...
@@ -368,8 +435,7 @@ void Tensor::to_cpu() const {
                   sizeof(float), cudaMemcpyDeviceToHost);
      }
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
    if (rowwise_) {
      auto scale_size = rowwise_scale_meta.bytes();
      cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
...
...
@@ -398,15 +464,15 @@ void Tensor::from_cpu() const {
    cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
               cudaMemcpyHostToDevice);
  }
  if (isFp8Type(dtype())) {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) ||
        (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
      if (tensor_.amax() != nullptr){
        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
      }
      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
    }
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
    if (rowwise_) {
      auto scale_size = rowwise_scale_meta.bytes();
      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
...
...
@@ -423,7 +489,7 @@ void Tensor::from_cpu() const {
}

void Tensor::set_scale(float scale) {
  if (isFp8Type(dtype())) {
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    NVTE_CHECK(scale_cpu_data_);
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      *scale_cpu_data_ = scale;
...
...
@@ -433,7 +499,7 @@ void Tensor::set_scale(float scale) {
}

void Tensor::set_scale_inv(float scale_inv) {
  if (isFp8Type(dtype())) {
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if (rowwise_) {
      NVTE_CHECK(rowwise_scale_inv_cpu_data_);
    }
...
...
@@ -441,8 +507,7 @@ void Tensor::set_scale_inv(float scale_inv) {
      NVTE_CHECK(columnwise_scale_inv_cpu_data_);
    }
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
    if (rowwise_) {
      auto num_scales = product(rowwise_scale_meta.shape);
      if (num_scales == 1) {
...
...
@@ -472,7 +537,8 @@ void Tensor::set_scale_inv(float scale_inv) {
}

void Tensor::shareFP8Meta(const Tensor &other) {
  if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
  if ((isFp8Type(dtype()) && isFp8Type(other.dtype())) ||
      isFp4Type(dtype()) && isFp4Type(other.dtype())) {
    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
    auto my_rowwise_data = tensor_.get_rowwise_data();
    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
...
...
@@ -724,12 +790,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
  }
}

void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
template <typename T>
struct CastToType;

template <>
struct CastToType<uint8_t> {
  using type = int;
};

template <>
struct CastToType<fp8e4m3> {
  using type = float;
};
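The trait maps each scale storage type to a wider type that prints as a number: uint8_t (E8M0 exponents) widens to int, fp8e4m3 widens to float. A small illustrative check, using a stand-in struct for the project's fp8e4m3 type:

// Illustrative only; fp8e4m3_stub stands in for the real fp8e4m3 type.
#include <cstdint>
#include <type_traits>

template <typename T> struct CastToType;
template <> struct CastToType<std::uint8_t> { using type = int; };

struct fp8e4m3_stub {};
template <> struct CastToType<fp8e4m3_stub> { using type = float; };

static_assert(std::is_same<CastToType<std::uint8_t>::type, int>::value,
              "byte-sized scales print as int, not as characters");
static_assert(std::is_same<CastToType<fp8e4m3_stub>::type, float>::value,
              "FP8 scales print as float");

int main() { return 0; }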
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
                             const size_t row_blocks, const size_t col_blocks, const size_t stride,
                             size_t &mismatches_num, const size_t atol,
                             const double abs_tolerable_mismatches_limit,
                             const double rel_tolerable_mismatches_limit) {
  using UpcastType = typename CastToType<T>::type;
  auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);
  const size_t N = row_blocks * col_blocks;
  const size_t tolerable_mismatches_limit =
      std::min(abs_tolerable_mismatches_limit, std::floor(N * rel_tolerable_mismatches_limit));
...
...
@@ -739,11 +823,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
  for (int i = 0; i < row_blocks; ++i) {
    for (int j = 0; j < col_blocks; ++j) {
      const int idx = i * stride + j;
      const int test_val = static_cast<int>(test[idx]);
      const int ref_val = static_cast<int>(ref[idx]);
      const int abs_delta = std::abs(test_val - ref_val);
      float t, r;
      bool assertion = false;
      if (abs_delta > atol) {
      if (std::is_same<T, uint8_t>::value) {
        t = static_cast<float>(test[idx]);
        r = static_cast<float>(ref[idx]);
        assertion = std::abs(t - r) > atol;
      } else {
        t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
        r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
        const bool mismatch = (fabs(t - r) > atol_fp8e4m3) &&
                              (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
        if (mismatch) {
          /* Check if it is just a failure of round to nearest choosing different
             side of the real value */
          const double mean = (t + r) / 2;
          const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
          const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
          const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
          const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
          assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r));
        }
      }
      if (assertion) {
        mismatches_num++;
        mismatch_indices.push_back(idx);
      }
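The fallback check above asks whether the two FP8 values merely landed on opposite neighbors of the true value under round-to-nearest: it nudges the midpoint slightly up and down, casts both, and accepts the pair if the casts land on the smaller and larger value respectively. A sketch of the same idea, using round-to-nearest integers as a stand-in for the FP8 cast:

// Sketch only: std::nearbyint plays the role of the round-to-nearest FP8 cast.
#include <algorithm>
#include <cassert>
#include <cmath>

int main() {
  // Two kernels rounded the true value 2.5 to adjacent representable values.
  const double t = 2.0, r = 3.0;
  const double mean   = (t + r) / 2;  // presumed true value between them
  const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
  const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
  const double cast_p = std::nearbyint(mean_p);  // nudged up   -> rounds to 3
  const double cast_m = std::nearbyint(mean_m);  // nudged down -> rounds to 2
  // If the nudged midpoints round onto {min, max}, the disagreement is only a
  // tie-breaking artifact and is not counted as a real mismatch.
  assert(cast_m == std::min(t, r) && cast_p == std::max(t, r));
  return 0;
}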
...
...
@@ -751,8 +855,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
    std::cout << "Error in " << name << std::endl;
    for (const int index : mismatch_indices) {
      std::cout << "Mismatch at (" << index << "):"
                << static_cast<int>(test[index]) << " vs "
                << static_cast<int>(ref[index]) << std::endl;
                << static_cast<UpcastType>(test[index]) << " vs "
                << static_cast<UpcastType>(ref[index]) << std::endl;
    }
    GTEST_FAIL() << mismatches_num
                 << " mismatche(s) which is more than tolerable mismatch limit of "
                 << tolerable_mismatches_limit << ".";
...
...
@@ -761,6 +865,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
}
}
// Instantiate templates
template void compare_scaling_factors<uint8_t>(
    const std::string &name, const uint8_t *test, const uint8_t *ref, const size_t row_blocks,
    const size_t col_blocks, const size_t stride, size_t &mismatches_num, const size_t atol,
    const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit);
template void compare_scaling_factors<fp8e4m3>(
    const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref, const size_t row_blocks,
    const size_t col_blocks, const size_t stride, size_t &mismatches_num, const size_t atol,
    const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit);
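These explicit instantiations are what let other translation units link against compare_scaling_factors for the two supported scale types even though the template body lives in this .cu file. A generic single-file sketch of the mechanism (not project code):

// Generic C++ sketch of explicit template instantiation.
#include <iostream>

template <typename T>
void print_twice(T v) {
  std::cout << v << " " << v << "\n";
}

// Forces the compiler to emit the T = int definition in this translation unit,
// so a caller that only sees a declaration can still link against it.
template void print_twice<int>(int);

int main() {
  print_twice(7);  // uses the explicitly instantiated symbol
  return 0;
}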
std::pair<double, double> getTolerances(const DType type) {
  switch (type) {
    case DType::kFloat32:
...
...
@@ -920,6 +1040,10 @@ bool isFp8Type(DType type) {
  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
}

bool isFp4Type(DType type) {
  return type == DType::kFloat4E2M1;
}

int32_t getDeviceComputeCapability() {
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
...
...
@@ -941,7 +1065,8 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
                                            const size_t cols,
                                            const size_t block_size_rows,
                                            const size_t block_size_cols) {
  const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32);
  const bool is_rowwise = (block_size_rows == 1) &&
                          ((block_size_cols == 32) || (block_size_cols == 16));
  const size_t alignment_Y = is_rowwise ? scale_tensor_alignment_Y_rowwise
...
...