Commit 063ef88d authored by wenjh's avatar wenjh
Browse files

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
......@@ -18,11 +18,11 @@ from flax.training import train_state
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
DIR = str(Path(__file__).resolve().parents[1])
sys.path.append(str(DIR))
from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string
from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string
IMAGE_H = 28
IMAGE_W = 28
......@@ -189,12 +189,12 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
with te.autocast(
enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
......@@ -308,8 +308,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......
......@@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None):
)
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
"--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
)
parser.add_argument(
"--no-comm-overlap",
......@@ -299,7 +299,7 @@ def _train(opts):
dist_print(" |-- Forward pass", group=tp_group, debug=True)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world):
y = model(x)
if isinstance(y, tuple):
out, *_ = y
......
......@@ -49,5 +49,5 @@ $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsd
# ...
```
**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without Fp8 support
**NOTE:** This example has `autocast()` enabled by default. To run on GPUs without Fp8 support
(e.g.: A100), add the `--no-fp8` option to the commands shown above.
......@@ -173,7 +173,7 @@ def parse_fsdp_args():
"--no-fp8",
action="store_true",
default=False,
help="Disables the te.fp8_autocast() context.",
help="Disables the te.autocast() context.",
)
parser.add_argument(
"--no-defer-init",
......@@ -284,11 +284,11 @@ def train(opts):
dtype=opts.dtype,
device="cuda",
)
# fp8_autocast needs to be given the FSDP process group for amax reductions
with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
# autocast needs to be given the FSDP process group for amax reductions
with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus):
y = te_model(x)
loss = y.sum()
# calculate gradient and take training step outside the fp8_autocast context
# calculate gradient and take training step outside the autocast context
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)
......
......@@ -52,7 +52,7 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
with te.fp8_autocast(enabled=use_fp8):
with te.autocast(enabled=use_fp8):
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
......@@ -76,7 +76,7 @@ def calibrate(model, device, test_loader, fp8):
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=fp8, calibrating=True):
with te.autocast(enabled=fp8, calibrating=True):
output = model(data)
......@@ -88,7 +88,7 @@ def test(model, device, test_loader, use_fp8):
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8):
with te.autocast(enabled=use_fp8):
output = model(data)
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
......@@ -29,6 +29,10 @@ wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
wait
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
......@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
......@@ -7,6 +7,8 @@
: ${TE_PATH:=/opt/transformerengine}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
......@@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect
pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
# standard sanity and numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
exit $FAIL
......@@ -31,6 +31,7 @@ ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0
ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
......
......@@ -9,4 +9,4 @@ set -xe
mkdir -p "$XML_LOG_DIR"
NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
......@@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
......@@ -47,9 +48,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
# standard numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
......@@ -3,9 +3,11 @@
# See LICENSE for license information.
pip3 install onnxruntime==1.20.1
pip3 install onnxruntime_extensions==0.13.0
pip3 install onnxruntime
pip3 install onnxruntime_extensions
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
......@@ -23,6 +23,7 @@ from build_tools.utils import (
cuda_version,
get_frameworks,
remove_dups,
min_python_version_str,
)
frameworks = get_frameworks()
......@@ -211,7 +212,7 @@ if __name__ == "__main__":
long_description_content_type="text/x-rst",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8",
python_requires=f">={min_python_version_str()}",
classifiers=["Programming Language :: Python :: 3"],
install_requires=install_requires,
license_files=("LICENSE",),
......
......@@ -66,6 +66,13 @@ else()
add_executable(test_operator ${test_hip_sources})
endif()
# Add profiling and debug flags for CUDA compilation
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Generate line info for device code
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") # Add debug symbols for host code
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage
# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping
# Find required packages
find_package(OpenMP REQUIRED)
if(USE_CUDA)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
......
......@@ -529,6 +529,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 2u;
// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
// which requires using power of two scaling factors. Skip unsupported tests.
if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
GTEST_SKIP();
}
if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not
// handle this case.
......@@ -580,6 +586,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 1u;
// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
// which requires using power of two scaling factors. Skip unsupported tests.
if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
GTEST_SKIP();
}
if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not
// handle this case.
......
......@@ -81,6 +81,7 @@ void compute_ref(const ProcessingMethod processing_method,
// Cache computations
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
......@@ -310,12 +311,13 @@ void performTest_x1(const ProcessingMethod processing_method,
const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales = 0;
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts = 32 * mismatches_scales;
auto [atol, rtol] = getTolerances(otype);
......@@ -481,22 +483,22 @@ void performTest_x2(const ProcessingMethod processing_method,
const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
......
......@@ -267,19 +267,20 @@ void performTest_x1(const size_t rows,
? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
if (rowwise) {
compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
} else {
compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
}
const size_t mismatches_elts = 32 * mismatches_scales;
......@@ -378,21 +379,22 @@ void performTest_x2(const size_t rows,
const double rel_tolerable_mismatches_limit = 1.0e-4;
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
......
This diff is collapsed.
......@@ -111,6 +111,10 @@ size_t DIVUP(const size_t &x, const size_t &y){
return (((x) + ((y)-1)) / (y));
}
size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){
return DIVUP(x, y) * y;
}
struct scale_inv_meta {
std::vector<size_t> shape;
DType type;
......@@ -147,21 +151,71 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
auto block_alignment = std::vector<size_t>{128ul, 4ul};
{
auto alignment = block_alignment[0];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
const size_t block_size_X_rowwise = 32;
size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
const size_t block_size_Y_colwise = 32;
size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
ret_rowwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
ret_colwise.type = DType::kFloat8E8M0;
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
{
auto alignment = block_alignment[1];
auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
alignment = block_alignment[0];
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
NVTE_CHECK(last_dim % 32 == 0);
NVTE_CHECK(first_dim % 32 == 0);
scale_inv_meta ret_rowwise, ret_colwise;
size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y, scale_dim_X};
size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};
ret_rowwise.type = DType::kFloat8E4M3;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
ret_colwise.type = DType::kFloat8E4M3;
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
return {ret_rowwise, ret_colwise};
}
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; ++i) {
shape_vec.push_back(shape.data[i]);
}
size_t first_dim = first_dimension(shape_vec);
size_t last_dim = last_dimension(shape_vec);
scale_inv_meta ret_rowwise, ret_colwise;
const size_t block_size_X_rowwise = 32;
size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
const size_t block_size_Y_colwise = 32;
size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
ret_rowwise.type = DType::kFloat8E8M0;
ret_colwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
......@@ -254,14 +308,15 @@ Tensor::Tensor(const std::string& name,
NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING
|| scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// Transpose when tensor scaling
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
} else {
// Same shape for MX
// Same shape for MX and NVFP4
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
......@@ -287,10 +342,13 @@ Tensor::Tensor(const std::string& name,
std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
}
}
tensor_.set_rowwise_data(dptr_rowwise, type, shape);
tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
if (isFp8Type(type)) {
const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
if (isFp8Type(type) || isFp4Type(type)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
cudaMemset(amax, 0, sizeof(float));
......@@ -309,13 +367,19 @@ Tensor::Tensor(const std::string& name,
}
if (columnwise) {
tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
std::vector<size_t>{1});
std::vector<size_t>{1});
columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
}
} else {
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(normalized_shape, tensor_.scaling_mode());
if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
// Used for NVFP4 second stage scaling
cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*)
cudaMemset(scale, 0, sizeof(float));
scale_cpu_data_ = std::make_shared<float>(0);
tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
}
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape;
......@@ -350,13 +414,16 @@ void Tensor::to_cpu() const {
cudaMemcpyDeviceToHost);
}
if (columnwise_) {
const DType colwise_type = tensor_.dtype();
const size_t colwise_size = bytes(s, colwise_type);
cudaMemcpy(cpu_data_columnwise_.get(),
tensor_.get_columnwise_data().data_ptr,
size,
cudaMemcpyDeviceToHost);
tensor_.get_columnwise_data().data_ptr,
colwise_size,
cudaMemcpyDeviceToHost);
}
if (isFp8Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) {
if (tensor_.amax() != nullptr){
cudaMemcpy(amax_cpu_data_.get(),
tensor_.amax(),
......@@ -368,8 +435,7 @@ void Tensor::to_cpu() const {
sizeof(float),
cudaMemcpyDeviceToHost);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
......@@ -398,15 +464,15 @@ void Tensor::from_cpu() const {
cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
cudaMemcpyHostToDevice);
}
if (isFp8Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
|| (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
if (tensor_.amax() != nullptr){
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
......@@ -423,7 +489,7 @@ void Tensor::from_cpu() const {
}
void Tensor::set_scale(float scale) {
if (isFp8Type(dtype())) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
NVTE_CHECK(scale_cpu_data_);
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
*scale_cpu_data_ = scale;
......@@ -433,7 +499,7 @@ void Tensor::set_scale(float scale) {
}
void Tensor::set_scale_inv(float scale_inv) {
if (isFp8Type(dtype())) {
if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if (rowwise_) {
NVTE_CHECK(rowwise_scale_inv_cpu_data_);
}
......@@ -441,8 +507,7 @@ void Tensor::set_scale_inv(float scale_inv) {
NVTE_CHECK(columnwise_scale_inv_cpu_data_);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(tensor_.shape(), tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
if (rowwise_) {
auto num_scales = product(rowwise_scale_meta.shape);
if (num_scales == 1) {
......@@ -472,7 +537,8 @@ void Tensor::set_scale_inv(float scale_inv) {
}
void Tensor::shareFP8Meta(const Tensor &other) {
if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
|| isFp4Type(dtype()) && isFp4Type(other.dtype())) {
auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
auto my_rowwise_data = tensor_.get_rowwise_data();
new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
......@@ -724,12 +790,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
}
}
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit)
template <typename T>
struct CastToType;
template <>
struct CastToType<uint8_t> {
using type = int;
};
template <>
struct CastToType<fp8e4m3> {
using type = float;
};
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit)
{
using UpcastType = typename CastToType<T>::type;
auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);
const size_t N = row_blocks * col_blocks;
const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
std::floor(N * rel_tolerable_mismatches_limit));
......@@ -739,11 +823,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
for (int i = 0; i < row_blocks; ++i) {
for (int j = 0; j < col_blocks; ++j) {
const int idx = i * stride + j;
const int test_val = static_cast<int>(test[idx]);
const int ref_val = static_cast<int>(ref[idx]);
const int abs_delta = std::abs(test_val - ref_val);
float t, r;
bool assertion = false;
if (abs_delta > atol) {
if (std::is_same<T, uint8_t>::value) {
t = static_cast<float>(test[idx]);
r = static_cast<float>(ref[idx]);
assertion = std::abs(t - r) > atol;
} else {
t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
const bool mismatch = (fabs(t - r) > atol_fp8e4m3)
&& (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
if (mismatch) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
}
if (assertion) {
mismatches_num++;
mismatch_indices.push_back(idx);
}
......@@ -751,8 +855,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
std::cout << "Error in " << name << std::endl;
for (const int index : mismatch_indices) {
std::cout << "Mismatch at (" << index << "):"
<< static_cast<int>(test[index]) << " vs "
<< static_cast<int>(ref[index]) << std::endl;
<< static_cast<UpcastType>(test[index]) << " vs "
<< static_cast<UpcastType>(ref[index]) << std::endl;
}
GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
<< tolerable_mismatches_limit << ".";
......@@ -761,6 +865,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
}
}
// Instantiate templates
template
void compare_scaling_factors<uint8_t>(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit);
template
void compare_scaling_factors<fp8e4m3>(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit);
std::pair<double, double> getTolerances(const DType type) {
switch(type) {
case DType::kFloat32:
......@@ -920,6 +1040,10 @@ bool isFp8Type(DType type) {
return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
}
bool isFp4Type(DType type) {
return type == DType::kFloat4E2M1;
}
int32_t getDeviceComputeCapability() {
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
......@@ -941,7 +1065,8 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols) {
const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32);
const bool is_rowwise = (block_size_rows == 1)
&& ((block_size_cols == 32) || (block_size_cols == 16));
const size_t alignment_Y = is_rowwise
? scale_tensor_alignment_Y_rowwise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment