Commit 063ef88d authored by wenjh's avatar wenjh
Browse files

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
...@@ -18,11 +18,11 @@ from flax.training import train_state ...@@ -18,11 +18,11 @@ from flax.training import train_state
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
DIR = str(Path(__file__).resolve().parents[1]) DIR = str(Path(__file__).resolve().parents[1])
sys.path.append(str(DIR)) sys.path.append(str(DIR))
from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string
IMAGE_H = 28 IMAGE_H = 28
IMAGE_W = 28 IMAGE_W = 28
...@@ -189,12 +189,12 @@ def train_and_evaluate(args): ...@@ -189,12 +189,12 @@ def train_and_evaluate(args):
label_shape = [args.batch_size] label_shape = [args.batch_size]
if args.use_fp8: if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else: else:
fp8_recipe = None fp8_recipe = None
with te.fp8_autocast( with te.autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
): ):
cnn = Net(args.use_te) cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
...@@ -308,8 +308,8 @@ def mnist_parser(args): ...@@ -308,8 +308,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase): class TestMNIST(unittest.TestCase):
"""MNIST unittests""" """MNIST unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
......
...@@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None): ...@@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None):
) )
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument( parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." "--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
) )
parser.add_argument( parser.add_argument(
"--no-comm-overlap", "--no-comm-overlap",
...@@ -299,7 +299,7 @@ def _train(opts): ...@@ -299,7 +299,7 @@ def _train(opts):
dist_print(" |-- Forward pass", group=tp_group, debug=True) dist_print(" |-- Forward pass", group=tp_group, debug=True)
with torch.amp.autocast("cuda", dtype=torch.bfloat16): with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world):
y = model(x) y = model(x)
if isinstance(y, tuple): if isinstance(y, tuple):
out, *_ = y out, *_ = y
......
...@@ -49,5 +49,5 @@ $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsd ...@@ -49,5 +49,5 @@ $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsd
# ... # ...
``` ```
**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without Fp8 support **NOTE:** This example has `autocast()` enabled by default. To run on GPUs without Fp8 support
(e.g.: A100), add the `--no-fp8` option to the commands shown above. (e.g.: A100), add the `--no-fp8` option to the commands shown above.
...@@ -173,7 +173,7 @@ def parse_fsdp_args(): ...@@ -173,7 +173,7 @@ def parse_fsdp_args():
"--no-fp8", "--no-fp8",
action="store_true", action="store_true",
default=False, default=False,
help="Disables the te.fp8_autocast() context.", help="Disables the te.autocast() context.",
) )
parser.add_argument( parser.add_argument(
"--no-defer-init", "--no-defer-init",
...@@ -284,11 +284,11 @@ def train(opts): ...@@ -284,11 +284,11 @@ def train(opts):
dtype=opts.dtype, dtype=opts.dtype,
device="cuda", device="cuda",
) )
# fp8_autocast needs to be given the FSDP process group for amax reductions # autocast needs to be given the FSDP process group for amax reductions
with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus): with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus):
y = te_model(x) y = te_model(x)
loss = y.sum() loss = y.sum()
# calculate gradient and take training step outside the fp8_autocast context # calculate gradient and take training step outside the autocast context
loss.backward() loss.backward()
optim.step() optim.step()
optim.zero_grad(set_to_none=True) optim.zero_grad(set_to_none=True)
......
...@@ -52,7 +52,7 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8): ...@@ -52,7 +52,7 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
for batch_idx, (data, target) in enumerate(train_loader): for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
optimizer.zero_grad() optimizer.zero_grad()
with te.fp8_autocast(enabled=use_fp8): with te.autocast(enabled=use_fp8):
output = model(data) output = model(data)
loss = F.nll_loss(output, target) loss = F.nll_loss(output, target)
loss.backward() loss.backward()
...@@ -76,7 +76,7 @@ def calibrate(model, device, test_loader, fp8): ...@@ -76,7 +76,7 @@ def calibrate(model, device, test_loader, fp8):
with torch.no_grad(): with torch.no_grad():
for data, target in test_loader: for data, target in test_loader:
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=fp8, calibrating=True): with te.autocast(enabled=fp8, calibrating=True):
output = model(data) output = model(data)
...@@ -88,7 +88,7 @@ def test(model, device, test_loader, use_fp8): ...@@ -88,7 +88,7 @@ def test(model, device, test_loader, use_fp8):
with torch.no_grad(): with torch.no_grad():
for data, target in test_loader: for data, target in test_loader:
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8): with te.autocast(enabled=use_fp8):
output = model(data) output = model(data)
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
...@@ -29,6 +29,10 @@ wait ...@@ -29,6 +29,10 @@ wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
wait
if [ $RET -ne 0 ]; then if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES" echo "Error: some sub-tests failed: $FAILED_CASES"
......
...@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" ...@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls # Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
if [ $RET -ne 0 ]; then if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES" echo "Error: some sub-tests failed: $FAILED_CASES"
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/} : ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
# Config with the dummy feature which prevents nvinspect from being disabled. # Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active. # Nvinspect will be disabled if no feature is active.
...@@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect ...@@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect
pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
pip install pytest==8.2.1 pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
# standard sanity and numerics tests with initialized debug # standard sanity and numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
exit $FAIL exit $FAIL
...@@ -31,6 +31,7 @@ ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 ...@@ -31,6 +31,7 @@ ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0
ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
......
...@@ -9,4 +9,4 @@ set -xe ...@@ -9,4 +9,4 @@ set -xe
mkdir -p "$XML_LOG_DIR" mkdir -p "$XML_LOG_DIR"
NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
...@@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" ...@@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
...@@ -47,9 +48,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_ ...@@ -47,9 +48,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
# standard numerics tests with initialized debug # standard numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
if [ "$RET" -ne 0 ]; then if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES" echo "Error in the following test cases:$FAILED_CASES"
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
# See LICENSE for license information. # See LICENSE for license information.
pip3 install onnxruntime==1.20.1 pip3 install onnxruntime
pip3 install onnxruntime_extensions==0.13.0 pip3 install onnxruntime_extensions
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
...@@ -23,6 +23,7 @@ from build_tools.utils import ( ...@@ -23,6 +23,7 @@ from build_tools.utils import (
cuda_version, cuda_version,
get_frameworks, get_frameworks,
remove_dups, remove_dups,
min_python_version_str,
) )
frameworks = get_frameworks() frameworks = get_frameworks()
...@@ -211,7 +212,7 @@ if __name__ == "__main__": ...@@ -211,7 +212,7 @@ if __name__ == "__main__":
long_description_content_type="text/x-rst", long_description_content_type="text/x-rst",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8", python_requires=f">={min_python_version_str()}",
classifiers=["Programming Language :: Python :: 3"], classifiers=["Programming Language :: Python :: 3"],
install_requires=install_requires, install_requires=install_requires,
license_files=("LICENSE",), license_files=("LICENSE",),
......
...@@ -66,6 +66,13 @@ else() ...@@ -66,6 +66,13 @@ else()
add_executable(test_operator ${test_hip_sources}) add_executable(test_operator ${test_hip_sources})
endif() endif()
# Add profiling and debug flags for CUDA compilation
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Generate line info for device code
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") # Add debug symbols for host code
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage
# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping
# Find required packages
find_package(OpenMP REQUIRED) find_package(OpenMP REQUIRED)
if(USE_CUDA) if(USE_CUDA)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
......
...@@ -529,6 +529,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { ...@@ -529,6 +529,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
q_opts.amax_epsilon = eps; q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 2u; q_opts.block_scaling_dim = 2u;
// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
// which requires using power of two scaling factors. Skip unsupported tests.
if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
GTEST_SKIP();
}
if (colwise && matrix_size.size() < 2) { if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not // test_common Tensor initialization code does not
// handle this case. // handle this case.
...@@ -580,6 +586,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { ...@@ -580,6 +586,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
q_opts.amax_epsilon = eps; q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 1u; q_opts.block_scaling_dim = 1u;
// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
// which requires using power of two scaling factors. Skip unsupported tests.
if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
GTEST_SKIP();
}
if (colwise && matrix_size.size() < 2) { if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not // test_common Tensor initialization code does not
// handle this case. // handle this case.
......
...@@ -81,6 +81,7 @@ void compute_ref(const ProcessingMethod processing_method, ...@@ -81,6 +81,7 @@ void compute_ref(const ProcessingMethod processing_method,
// Cache computations // Cache computations
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j; const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
...@@ -310,12 +311,13 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -310,12 +311,13 @@ void performTest_x1(const ProcessingMethod processing_method,
const double rel_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales = 0; size_t mismatches_scales = 0;
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride, compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
mismatches_scales, unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
scale_diff_abs_tolerance, mismatches_scales,
abs_tolerable_mismatches_limit, scale_diff_abs_tolerance,
rel_tolerable_mismatches_limit); abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts = 32 * mismatches_scales; const size_t mismatches_elts = 32 * mismatches_scales;
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
...@@ -481,22 +483,22 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -481,22 +483,22 @@ void performTest_x2(const ProcessingMethod processing_method,
const double rel_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales_rowwise = 0; size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise, mismatches_scales_rowwise,
scale_diff_abs_tolerance, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit); rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0; size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise, ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise, unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise, mismatches_scales_colwise,
scale_diff_abs_tolerance, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit); rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
......
...@@ -267,19 +267,20 @@ void performTest_x1(const size_t rows, ...@@ -267,19 +267,20 @@ void performTest_x1(const size_t rows,
? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>() ? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(); : output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
if (rowwise) { if (rowwise) {
compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride, unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales, mismatches_scales,
scale_diff_abs_tolerance, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit); rel_tolerable_mismatches_limit);
} else { } else {
compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride, unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales, mismatches_scales,
scale_diff_abs_tolerance, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit); rel_tolerable_mismatches_limit);
} }
const size_t mismatches_elts = 32 * mismatches_scales; const size_t mismatches_elts = 32 * mismatches_scales;
...@@ -378,21 +379,22 @@ void performTest_x2(const size_t rows, ...@@ -378,21 +379,22 @@ void performTest_x2(const size_t rows,
const double rel_tolerable_mismatches_limit = 1.0e-4; const double rel_tolerable_mismatches_limit = 1.0e-4;
size_t mismatches_scales_rowwise = 0; size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise, mismatches_scales_rowwise,
scale_diff_abs_tolerance, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit); rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0; size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise, ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise, unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise, mismatches_scales_colwise,
scale_diff_abs_tolerance, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit); rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
......
This diff is collapsed.
...@@ -111,6 +111,10 @@ size_t DIVUP(const size_t &x, const size_t &y){ ...@@ -111,6 +111,10 @@ size_t DIVUP(const size_t &x, const size_t &y){
return (((x) + ((y)-1)) / (y)); return (((x) + ((y)-1)) / (y));
} }
size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){
return DIVUP(x, y) * y;
}
struct scale_inv_meta { struct scale_inv_meta {
std::vector<size_t> shape; std::vector<size_t> shape;
DType type; DType type;
...@@ -147,21 +151,71 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, ...@@ -147,21 +151,71 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise; scale_inv_meta ret_rowwise, ret_colwise;
auto block_alignment = std::vector<size_t>{128ul, 4ul}; const size_t block_size_X_rowwise = 32;
{ size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
auto alignment = block_alignment[0]; size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment; ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
alignment = block_alignment[1];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment; const size_t block_size_Y_colwise = 32;
ret_rowwise.shape = {scale_dim_0, scale_dim_1}; size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
ret_rowwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
ret_colwise.type = DType::kFloat8E8M0;
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
} }
{ size_t first_dim = first_dimension(shape_vec);
auto alignment = block_alignment[1]; size_t last_dim = last_dimension(shape_vec);
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
alignment = block_alignment[0]; NVTE_CHECK(last_dim % 32 == 0);
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment; NVTE_CHECK(first_dim % 32 == 0);
ret_colwise.shape = {scale_dim_0, scale_dim_1};
scale_inv_meta ret_rowwise, ret_colwise;
size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y, scale_dim_X};
size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};
ret_rowwise.type = DType::kFloat8E4M3;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
ret_colwise.type = DType::kFloat8E4M3;
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
} }
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
scale_inv_meta ret_rowwise, ret_colwise;
const size_t block_size_X_rowwise = 32;
size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
const size_t block_size_Y_colwise = 32;
size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
ret_rowwise.type = DType::kFloat8E8M0; ret_rowwise.type = DType::kFloat8E8M0;
ret_colwise.type = DType::kFloat8E8M0; ret_colwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
...@@ -254,14 +308,15 @@ Tensor::Tensor(const std::string& name, ...@@ -254,14 +308,15 @@ Tensor::Tensor(const std::string& name,
NVTEShape columnwise_shape = {}; NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec; std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING
|| scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// Transpose when tensor scaling // Transpose when tensor scaling
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) { for (size_t i = 0; i < shape.ndim - 1; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]); columnwise_shape_vec.emplace_back(shape.data[i]);
} }
} else { } else {
// Same shape for MX // Same shape for MX and NVFP4
for (size_t i = 0; i < shape.ndim; ++i) { for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]); columnwise_shape_vec.emplace_back(shape.data[i]);
} }
...@@ -287,10 +342,13 @@ Tensor::Tensor(const std::string& name, ...@@ -287,10 +342,13 @@ Tensor::Tensor(const std::string& name,
std::fill_n(cpu_data_columnwise_.get(), total_size, 0); std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
} }
} }
tensor_.set_rowwise_data(dptr_rowwise, type, shape);
tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
if (isFp8Type(type)) { const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
if (isFp8Type(type) || isFp4Type(type)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
cudaMemset(amax, 0, sizeof(float)); cudaMemset(amax, 0, sizeof(float));
...@@ -309,13 +367,19 @@ Tensor::Tensor(const std::string& name, ...@@ -309,13 +367,19 @@ Tensor::Tensor(const std::string& name,
} }
if (columnwise) { if (columnwise) {
tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
std::vector<size_t>{1}); std::vector<size_t>{1});
columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float)); columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
} }
} else { } else {
auto [rowwise_scale_meta, colwise_scale_meta] = if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
get_scales(normalized_shape, tensor_.scaling_mode()); // Used for NVFP4 second stage scaling
cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*)
cudaMemset(scale, 0, sizeof(float));
scale_cpu_data_ = std::make_shared<float>(0);
tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
}
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = rowwise_scale_meta.bytes(); auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = colwise_scale_meta.bytes(); auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape; auto scale_shape = rowwise_scale_meta.shape;
...@@ -350,13 +414,16 @@ void Tensor::to_cpu() const { ...@@ -350,13 +414,16 @@ void Tensor::to_cpu() const {
cudaMemcpyDeviceToHost); cudaMemcpyDeviceToHost);
} }
if (columnwise_) { if (columnwise_) {
const DType colwise_type = tensor_.dtype();
const size_t colwise_size = bytes(s, colwise_type);
cudaMemcpy(cpu_data_columnwise_.get(), cudaMemcpy(cpu_data_columnwise_.get(),
tensor_.get_columnwise_data().data_ptr, tensor_.get_columnwise_data().data_ptr,
size, colwise_size,
cudaMemcpyDeviceToHost); cudaMemcpyDeviceToHost);
} }
if (isFp8Type(dtype())) { if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) {
if (tensor_.amax() != nullptr){ if (tensor_.amax() != nullptr){
cudaMemcpy(amax_cpu_data_.get(), cudaMemcpy(amax_cpu_data_.get(),
tensor_.amax(), tensor_.amax(),
...@@ -368,8 +435,7 @@ void Tensor::to_cpu() const { ...@@ -368,8 +435,7 @@ void Tensor::to_cpu() const {
sizeof(float), sizeof(float),
cudaMemcpyDeviceToHost); cudaMemcpyDeviceToHost);
} }
auto [rowwise_scale_meta, colwise_scale_meta] = auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
get_scales(s, tensor_.scaling_mode());
if (rowwise_) { if (rowwise_) {
auto scale_size = rowwise_scale_meta.bytes(); auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
...@@ -398,15 +464,15 @@ void Tensor::from_cpu() const { ...@@ -398,15 +464,15 @@ void Tensor::from_cpu() const {
cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
} }
if (isFp8Type(dtype())) { if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
|| (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
if (tensor_.amax() != nullptr){ if (tensor_.amax() != nullptr){
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
} }
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
} }
auto [rowwise_scale_meta, colwise_scale_meta] = auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
get_scales(s, tensor_.scaling_mode());
if (rowwise_) { if (rowwise_) {
auto scale_size = rowwise_scale_meta.bytes(); auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
...@@ -423,7 +489,7 @@ void Tensor::from_cpu() const { ...@@ -423,7 +489,7 @@ void Tensor::from_cpu() const {
} }
void Tensor::set_scale(float scale) { void Tensor::set_scale(float scale) {
if (isFp8Type(dtype())) { if (isFp8Type(dtype()) || isFp4Type(dtype())) {
NVTE_CHECK(scale_cpu_data_); NVTE_CHECK(scale_cpu_data_);
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
*scale_cpu_data_ = scale; *scale_cpu_data_ = scale;
...@@ -433,7 +499,7 @@ void Tensor::set_scale(float scale) { ...@@ -433,7 +499,7 @@ void Tensor::set_scale(float scale) {
} }
void Tensor::set_scale_inv(float scale_inv) { void Tensor::set_scale_inv(float scale_inv) {
if (isFp8Type(dtype())) { if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if (rowwise_) { if (rowwise_) {
NVTE_CHECK(rowwise_scale_inv_cpu_data_); NVTE_CHECK(rowwise_scale_inv_cpu_data_);
} }
...@@ -441,8 +507,7 @@ void Tensor::set_scale_inv(float scale_inv) { ...@@ -441,8 +507,7 @@ void Tensor::set_scale_inv(float scale_inv) {
NVTE_CHECK(columnwise_scale_inv_cpu_data_); NVTE_CHECK(columnwise_scale_inv_cpu_data_);
} }
auto [rowwise_scale_meta, colwise_scale_meta] = auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
get_scales(tensor_.shape(), tensor_.scaling_mode());
if (rowwise_) { if (rowwise_) {
auto num_scales = product(rowwise_scale_meta.shape); auto num_scales = product(rowwise_scale_meta.shape);
if (num_scales == 1) { if (num_scales == 1) {
...@@ -472,7 +537,8 @@ void Tensor::set_scale_inv(float scale_inv) { ...@@ -472,7 +537,8 @@ void Tensor::set_scale_inv(float scale_inv) {
} }
void Tensor::shareFP8Meta(const Tensor &other) { void Tensor::shareFP8Meta(const Tensor &other) {
if (isFp8Type(dtype()) && isFp8Type(other.dtype())) { if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
|| isFp4Type(dtype()) && isFp4Type(other.dtype())) {
auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
auto my_rowwise_data = tensor_.get_rowwise_data(); auto my_rowwise_data = tensor_.get_rowwise_data();
new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype), new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
...@@ -724,12 +790,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t ...@@ -724,12 +790,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
} }
} }
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, template <typename T>
const size_t row_blocks, const size_t col_blocks, const size_t stride, struct CastToType;
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit, template <>
const double rel_tolerable_mismatches_limit) struct CastToType<uint8_t> {
using type = int;
};
template <>
struct CastToType<fp8e4m3> {
using type = float;
};
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit)
{ {
using UpcastType = typename CastToType<T>::type;
auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);
const size_t N = row_blocks * col_blocks; const size_t N = row_blocks * col_blocks;
const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
std::floor(N * rel_tolerable_mismatches_limit)); std::floor(N * rel_tolerable_mismatches_limit));
...@@ -739,11 +823,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, ...@@ -739,11 +823,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
for (int i = 0; i < row_blocks; ++i) { for (int i = 0; i < row_blocks; ++i) {
for (int j = 0; j < col_blocks; ++j) { for (int j = 0; j < col_blocks; ++j) {
const int idx = i * stride + j; const int idx = i * stride + j;
const int test_val = static_cast<int>(test[idx]); float t, r;
const int ref_val = static_cast<int>(ref[idx]);
const int abs_delta = std::abs(test_val - ref_val); bool assertion = false;
if (abs_delta > atol) { if (std::is_same<T, uint8_t>::value) {
t = static_cast<float>(test[idx]);
r = static_cast<float>(ref[idx]);
assertion = std::abs(t - r) > atol;
} else {
t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
const bool mismatch = (fabs(t - r) > atol_fp8e4m3)
&& (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
if (mismatch) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
}
if (assertion) {
mismatches_num++; mismatches_num++;
mismatch_indices.push_back(idx); mismatch_indices.push_back(idx);
} }
...@@ -751,8 +855,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, ...@@ -751,8 +855,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
std::cout << "Error in " << name << std::endl; std::cout << "Error in " << name << std::endl;
for (const int index : mismatch_indices) { for (const int index : mismatch_indices) {
std::cout << "Mismatch at (" << index << "):" std::cout << "Mismatch at (" << index << "):"
<< static_cast<int>(test[index]) << " vs " << static_cast<UpcastType>(test[index]) << " vs "
<< static_cast<int>(ref[index]) << std::endl; << static_cast<UpcastType>(ref[index]) << std::endl;
} }
GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
<< tolerable_mismatches_limit << "."; << tolerable_mismatches_limit << ".";
...@@ -761,6 +865,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, ...@@ -761,6 +865,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
} }
} }
// Instantiate templates
template
void compare_scaling_factors<uint8_t>(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit);
template
void compare_scaling_factors<fp8e4m3>(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit);
std::pair<double, double> getTolerances(const DType type) { std::pair<double, double> getTolerances(const DType type) {
switch(type) { switch(type) {
case DType::kFloat32: case DType::kFloat32:
...@@ -920,6 +1040,10 @@ bool isFp8Type(DType type) { ...@@ -920,6 +1040,10 @@ bool isFp8Type(DType type) {
return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
} }
bool isFp4Type(DType type) {
return type == DType::kFloat4E2M1;
}
int32_t getDeviceComputeCapability() { int32_t getDeviceComputeCapability() {
cudaDeviceProp deviceProp; cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0); cudaGetDeviceProperties(&deviceProp, 0);
...@@ -941,7 +1065,8 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, ...@@ -941,7 +1065,8 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
const size_t cols, const size_t cols,
const size_t block_size_rows, const size_t block_size_rows,
const size_t block_size_cols) { const size_t block_size_cols) {
const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32); const bool is_rowwise = (block_size_rows == 1)
&& ((block_size_cols == 32) || (block_size_cols == 16));
const size_t alignment_Y = is_rowwise const size_t alignment_Y = is_rowwise
? scale_tensor_alignment_Y_rowwise ? scale_tensor_alignment_Y_rowwise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment