Commit 9df0c4a3 authored by wenjh

Merge branch 'nv_main'

parents 0d874a4e f122b07d
@@ -23,8 +23,6 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr
- **FP8 Weight Caching**: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency.
- [Introduction to FP8](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/fp8_primer.ipynb)
  - Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage.
-- [TE Quickstart](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb)
-  - Introduction to TE, building a Transformer Layer using PyTorch, and instructions on integrating TE modules like Linear and LayerNorm.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/mnist)
# JAX
@@ -34,7 +32,9 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr
- Model Parallelism: Divide a model across multiple GPUs for parallel training.
- Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist)
+- [TE JAX Integration Tutorial](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_jax_integration.ipynb)
+  - Introduction to integrating TE into an existing JAX model framework, building a Transformer Layer, and instructions on integrating TE modules like Linear and LayerNorm.
# Third party
- [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine)
  - Scripts for training with Accelerate and TE. Supports single GPU, and multi-GPU via DDP, FSDP, and DeepSpeed ZeRO 1-3.
+# Datasets used by TE encoder tests. Pull these to pre-emptively cache datasets
+ylecun/mnist
+nyu-mll/glue
\ No newline at end of file
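The two dataset IDs listed above can be pulled ahead of time so the encoder tests run against a warm Hugging Face cache. A minimal sketch, assuming the `datasets` package is installed:

```python
# Pre-fetch the datasets named above into the local Hugging Face cache;
# subsequent load_dataset() calls in the tests then resolve from the cache.
from datasets import load_dataset

load_dataset("ylecun/mnist")          # all splits of MNIST
load_dataset("nyu-mll/glue", "cola")  # the GLUE CoLA subset used by the tests
```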
@@ -219,11 +219,11 @@ def get_datasets(max_seq_len):
    vocab = {}
    word_id = 0
-    train_ds = load_dataset("glue", "cola", split="train")
+    train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
    train_ds.set_format(type="np")
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
-    test_ds = load_dataset("glue", "cola", split="validation")
+    test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
    test_ds.set_format(type="np")
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id
...
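The same rename from the bare `glue` alias to the canonical `nyu-mll/glue` ID repeats across the encoder examples below. A hedged sanity check that the namespaced ID yields the CoLA splits the tests expect:

```python
from datasets import load_dataset

# The namespaced ID should point at the same data as the deprecated "glue" alias.
train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
print(len(train_ds), len(test_ds))  # expected: 8551 and 1043 examples
```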
@@ -197,11 +197,11 @@ def get_datasets(max_seq_len):
    vocab = {}
    word_id = 0
-    train_ds = load_dataset("glue", "cola", split="train")
+    train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
    train_ds.set_format(type="np")
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
-    test_ds = load_dataset("glue", "cola", split="validation")
+    test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
    test_ds.set_format(type="np")
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id
...
@@ -307,11 +307,11 @@ def get_datasets(max_seq_len):
    vocab = {}
    word_id = 0
-    train_ds = load_dataset("glue", "cola", split="train")
+    train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
    train_ds.set_format(type="np")
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
-    test_ds = load_dataset("glue", "cola", split="validation")
+    test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
    test_ds.set_format(type="np")
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id
...
@@ -195,11 +195,11 @@ def get_datasets(max_seq_len):
    vocab = {}
    word_id = 0
-    train_ds = load_dataset("glue", "cola", split="train")
+    train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
    train_ds.set_format(type="np")
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
-    test_ds = load_dataset("glue", "cola", split="validation")
+    test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
    test_ds.set_format(type="np")
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id
...
@@ -146,7 +146,7 @@ def eval_model(state, test_ds, batch_size, var_collect):
def get_datasets():
    """Load MNIST train and test datasets into memory."""
-    train_ds = load_dataset("mnist", split="train", trust_remote_code=True)
+    train_ds = load_dataset("ylecun/mnist", split="train", trust_remote_code=True)
    train_ds.set_format(type="np")
    batch_size = train_ds["image"].shape[0]
    shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
@@ -154,7 +154,7 @@ def get_datasets():
        "image": train_ds["image"].astype(np.float32).reshape(shape) / 255.0,
        "label": train_ds["label"],
    }
-    test_ds = load_dataset("mnist", split="test", trust_remote_code=True)
+    test_ds = load_dataset("ylecun/mnist", split="test", trust_remote_code=True)
    test_ds.set_format(type="np")
    batch_size = test_ds["image"].shape[0]
    shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
...
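For reference, the reshaping done in `get_datasets` above amounts to the following standalone sketch with synthetic data (the `IMAGE_*` constants are assumed to be the usual MNIST 28x28x1):

```python
import numpy as np

IMAGE_H, IMAGE_W, IMAGE_C = 28, 28, 1  # assumed MNIST dimensions
fake_images = np.random.randint(0, 256, size=(4, IMAGE_H, IMAGE_W), dtype=np.uint8)

# Cast uint8 pixels to float32, add the channel axis, and scale into [0, 1].
batch = fake_images.astype(np.float32).reshape(-1, IMAGE_H, IMAGE_W, IMAGE_C) / 255.0
assert batch.shape == (4, IMAGE_H, IMAGE_W, IMAGE_C)
assert 0.0 <= batch.min() and batch.max() <= 1.0
```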
@@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
+NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
...
@@ -51,11 +51,14 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
+NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
-mkdir -p $TE_PATH/artifacts/tests/pytorch/test_checkpoint && python $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all --checkpoint-dir $TE_PATH/artifacts/tests/pytorch/test_checkpoint/
-NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
-python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_batched_linear.xml $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
+export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint
+if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then
+    python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all || error_exit "Failed to generate checkpoint files"
+fi
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py"
...
@@ -28,11 +28,11 @@ WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation -vvv --wheel-dir ./dist . || error_exit "Failed to setup bdist_wheel"
-wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl"
+python3 -m wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl"
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info"
-wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}"
+python3 -m wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}"
rm dist/*.whl || error_exit "Failed to remove dist/*.whl"
mv *.whl dist/ || error_exit "Failed to move *.whl to dist/"
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup metapackage"
...
@@ -6,4 +6,5 @@
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
-python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
+# NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available
+NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
@@ -106,11 +106,6 @@ def setup_common_extension() -> CMakeExtension:
            f"nvidia-cublasmp-cu{cuda_version()[0]}"
        ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
        cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
-        nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
-            f"nvidia-nvshmem-cu{cuda_version()[0]}"
-        ).locate_file("nvidia/nvshmem")
-        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
-        print("CMAKE_FLAGS:", cmake_flags[-2:])
    # Add custom CMake arguments from environment variable
    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
...
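The deleted lines used the env-var-or-installed-wheel lookup pattern that the surrounding cuBLASMp code still uses. A standalone sketch of that pattern (the names here are illustrative, not part of the build):

```python
import os
from importlib import metadata

def locate_vendored_dir(env_var: str, dist_name: str, rel_path: str) -> str:
    """Prefer an explicit env var, else locate the directory shipped inside an
    installed wheel, e.g. ("NVSHMEM_HOME", "nvidia-nvshmem-cu12", "nvidia/nvshmem")."""
    return os.getenv(env_var) or str(metadata.distribution(dist_name).locate_file(rel_path))
```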
@@ -11,6 +11,7 @@ list(APPEND test_cuda_sources
    test_cast_mxfp8_gated_swiglu.cu
    test_qdq.cu
    test_cast_mxfp8.cu
+   test_cast_mxfp8_grouped.cu
    test_cast_float8blockwise.cu
    test_dequantize_mxfp8.cu
    test_transpose.cu
@@ -29,6 +30,7 @@ list(APPEND test_cuda_sources
    test_causal_softmax.cu
    test_swizzle.cu
    test_swap_first_dims.cu
+   test_grouped_gemm.cu
    ../test_common.cu)
if(USE_ROCM)
list(APPEND test_cuda_sources
...
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum ProcessingMethod {
CAST_ONLY,
CAST_DBIAS,
CAST_DBIAS_DACT,
CAST_DACT,
CAST_ACT
};
enum ActivationKind {
Identity,
GeLU,
SiLU,
ReLU,
QGeLU,
SReLU
};
enum ShapeRepresentation {
SAME_BOTH_DIMS = 0,
VARYING_FIRST_DIM = 1,
VARYING_LAST_DIM = 2,
VARYING_BOTH_DIMS = 3
};
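// Reference note: compute_ref below mirrors the grouped MXFP8 quantization on
// the CPU. The input is processed in 32x32 tiles; each 32-element row-wise or
// column-wise block is reduced to its amax, the amax is converted to a biased
// power-of-two (E8M0) scaling factor, and the block is rescaled by the
// reciprocal of that factor before the cast to FP8. The optional dbias path
// accumulates per-column sums of the processed elements.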
template <typename InputType, typename OutputType>
void compute_ref(const ProcessingMethod processing_method,
float (*OP)(const float),
const bool rowwise,
const bool colwise,
const InputType* input,
const InputType* grad,
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* output_scales_rowwise,
fp8e8m0* output_scales_colwise,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise,
const bool is_single_tensor)
{
const size_t tile_size_Y = 32;
const size_t tile_size_X = 32;
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
std::vector<float> output_dbias_fp32(cols, 0);
#pragma omp parallel proc_bind(spread)
{
// Buffers to cache intermediate computations
std::vector<float> cache_buffer(tile_size_Y * tile_size_X);
std::vector<float> thread_dbias(cols, 0);
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
const size_t i_min = tile_offset_Y;
const size_t i_max = std::min(i_min + tile_size_Y, rows);
const size_t j_min = tile_offset_X;
const size_t j_max = std::min(j_min + tile_size_X, cols);
// Cache computations
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
thread_dbias[j] += elt;
// Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
elt = static_cast<float>(static_cast<InputType>(elt));
cache_buffer[cache_idx] = elt;
if (isinf(elt) || isnan(elt)) {
continue;
}
}
}
if (rowwise) {
for (size_t i = i_min; i < i_max; ++i) {
float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t scale_idx = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
}
if (colwise) {
for (size_t j = j_min; j < j_max; ++j) {
float block_amax = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t scale_idx = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t i = i_min; i < i_max; ++i) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
}
}
#pragma omp critical
{
for (size_t j = 0; j < cols; ++j) {
output_dbias_fp32[j] += thread_dbias[j];
}
}
}
if (is_single_tensor) {
for (size_t j = 0; j < cols; ++j) {
output_dbias[j] = static_cast<InputType>(output_dbias_fp32[j]);
}
}
}
template <typename T>
void compare_scaled_elts(const std::string &name,
const T* ref_data,
const T* test_data,
const size_t rows,
const size_t cols,
const bool rowwise,
const size_t tolerable_mismatches_limit = 0,
const double atol = 1e-5,
const double rtol = 1e-8) {
size_t mismatches_num = 0;
int first_mismatch_idx = -1;
for (size_t i = 0; i < rows * cols; ++i) {
double t = static_cast<double>(test_data[i]);
double r = static_cast<double>(ref_data[i]);
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */
bool assertion = false;
if (mismatch && !assertion) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
std::string direction = rowwise ? "rowwise" : "columnwise";
if (assertion) {
mismatches_num++;
if (first_mismatch_idx == -1) {
first_mismatch_idx = i;
}
}
if (mismatches_num > tolerable_mismatches_limit) {
const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);
GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
<< tolerable_mismatches_limit << "." << std::endl
<< "Error in tensor " << name << " in "
<< direction << " direction." << std::endl
<< "First mismatch at index " << first_mismatch_idx << ": "
<< first_mismatch_t << " vs " << first_mismatch_r;
}
}
}
/**
* Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias):
* 1) Scaled rows + row-wise scaling factors
* OR
* 2) Scaled columns + column-wise scaling factors
*/
template <typename InputType, typename OutputType>
void performTest(const ProcessingMethod processing_method,
float (*OP)(const float),
const ShapeRepresentation shape_rep,
const size_t num_tensors,
const std::vector<size_t>& logical_shape_vec,
const std::vector<size_t>& first_dims_h,
const std::vector<size_t>& last_dims_h,
const std::vector<size_t>& offsets_h,
const bool rowwise,
const bool colwise) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
const size_t rows = logical_shape_vec[0];
const size_t cols = logical_shape_vec[1];
size_t elts_num = 0;
size_t rowwise_sfs_num = 0;
size_t colwise_sfs_num = 0;
std::vector<size_t> rowwise_scales_first_dim(num_tensors, 0);
std::vector<size_t> rowwise_scales_last_dim(num_tensors, 0);
std::vector<size_t> rowwise_scales_offset(num_tensors + 1, 0);
std::vector<size_t> colwise_scales_first_dim(num_tensors, 0);
std::vector<size_t> colwise_scales_last_dim(num_tensors, 0);
std::vector<size_t> colwise_scales_offset(num_tensors + 1, 0);
for (size_t t = 0; t < num_tensors; ++t) {
const size_t M = first_dims_h[t];
const size_t K = last_dims_h[t];
const size_t elts = M * K;
elts_num += elts;
const size_t unpadded_rowwise_blocks_Y = M;
const size_t unpadded_rowwise_blocks_X = divide_round_up(K, 32);
const size_t unpadded_colwise_blocks_Y = divide_round_up(M, 32);
const size_t unpadded_colwise_blocks_X = K;
rowwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_Y, 128);
rowwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4);
colwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_Y, 4);
colwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128);
const size_t rowwise_sfs = rowwise_scales_first_dim[t] * rowwise_scales_last_dim[t];
const size_t colwise_sfs = colwise_scales_first_dim[t] * colwise_scales_last_dim[t];
rowwise_sfs_num += rowwise_sfs;
colwise_sfs_num += colwise_sfs;
rowwise_scales_offset[t+1] = rowwise_sfs_num;
colwise_scales_offset[t+1] = colwise_sfs_num;
}
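// Per-tensor scaling-factor blocks are padded: row-wise scales to multiples of
// 128 (first dim) x 4 (last dim), column-wise scales to 4 x 128, matching the
// padded MXFP8 scale layout used by the kernels under test.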
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM);
std::vector<size_t> scales_rowwise_shape = {rowwise_sfs_num};
std::vector<size_t> scales_colwise_shape = {colwise_sfs_num};
std::mt19937 gen;
std::uniform_real_distribution<> dis(-2.0, 1.0);
std::vector<InputType> in_data(elts_num);
std::vector<InputType> grad_data(elts_num);
std::vector<OutputType> out_data_rowwise_h(rowwise ? elts_num : 0);
std::vector<OutputType> out_data_colwise_h(colwise ? elts_num : 0);
std::vector<fp8e8m0> out_scales_rowwise_h(rowwise ? rowwise_sfs_num : 0);
std::vector<fp8e8m0> out_scales_colwise_h(colwise ? colwise_sfs_num : 0);
std::vector<OutputType> out_data_rowwise_ref(rowwise ? elts_num : 0);
std::vector<OutputType> out_data_colwise_ref(colwise ? elts_num : 0);
std::vector<fp8e8m0> out_scales_rowwise_ref(rowwise ? rowwise_sfs_num : 0);
std::vector<fp8e8m0> out_scales_colwise_ref(colwise ? colwise_sfs_num : 0);
std::vector<InputType> ref_output_dbias(is_single_tensor ? cols : 0);
for (size_t i = 0; i < elts_num; ++i) {
const float val = dis(gen);
grad_data[i] = static_cast<InputType>(val);
in_data[i] = static_cast<InputType>(val);
}
const OutputType zero_elt = static_cast<OutputType>(0.0f);
const fp8e8m0 zero_SF = static_cast<fp8e8m0>(0.0f);
if (rowwise) {
std::fill(out_data_rowwise_h.begin(), out_data_rowwise_h.end(), zero_elt);
std::fill(out_data_rowwise_ref.begin(), out_data_rowwise_ref.end(), zero_elt);
std::fill(out_scales_rowwise_h.begin(), out_scales_rowwise_h.end(), zero_SF);
std::fill(out_scales_rowwise_ref.begin(), out_scales_rowwise_ref.end(), zero_SF);
}
if (colwise) {
std::fill(out_data_colwise_h.begin(), out_data_colwise_h.end(), zero_elt);
std::fill(out_data_colwise_ref.begin(), out_data_colwise_ref.end(), zero_elt);
std::fill(out_scales_colwise_h.begin(), out_scales_colwise_h.end(), zero_SF);
std::fill(out_scales_colwise_ref.begin(), out_scales_colwise_ref.end(), zero_SF);
}
const size_t in_data_size = elts_num * sizeof(InputType);
const size_t out_data_size = elts_num * sizeof(OutputType);
const size_t rowwise_scales_size = rowwise_sfs_num * sizeof(fp8e8m0);
const size_t colwise_scales_size = colwise_sfs_num * sizeof(fp8e8m0);
const size_t first_dims_size = num_tensors * sizeof(size_t);
const size_t last_dims_size = num_tensors * sizeof(size_t);
const size_t offsets_size = (num_tensors + 1) * sizeof(size_t);
InputType* grad_data_d;
InputType* in_data_d;
OutputType* out_data_rowwise_d;
OutputType* out_data_colwise_d;
fp8e8m0* out_scales_rowwise_d;
fp8e8m0* out_scales_colwise_d;
size_t* first_dims_d;
size_t* last_dims_d;
size_t* offsets_d;
cudaMalloc((void**)&grad_data_d, in_data_size);
cudaMalloc((void**)&in_data_d, in_data_size);
cudaMalloc((void**)&first_dims_d, first_dims_size);
cudaMalloc((void**)&last_dims_d, last_dims_size);
cudaMalloc((void**)&offsets_d, offsets_size);
cudaMemcpy(grad_data_d, grad_data.data(), in_data_size, cudaMemcpyHostToDevice);
cudaMemcpy(in_data_d, in_data.data(), in_data_size, cudaMemcpyHostToDevice);
cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice);
cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice);
cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice);
NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size());
NVTEShape first_dims_shape_;
NVTEShape last_dims_shape_;
NVTEShape offsets_shape_;
first_dims_shape_.ndim = 1;
last_dims_shape_.ndim = 1;
offsets_shape_.ndim = 1;
first_dims_shape_.data[0] = num_tensors;
last_dims_shape_.data[0] = num_tensors;
offsets_shape_.data[0] = num_tensors + 1;
NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_);
NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_);
NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_);
NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast<NVTEDType>(itype), logical_shape_};
NVTEBasicTensor in_data_tensor = {in_data_d, static_cast<NVTEDType>(itype), logical_shape_};
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor);
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor);
if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
}
if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
}
if (shape_rep != SAME_BOTH_DIMS) {
NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
}
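// A grouped tensor describes num_tensors member tensors living in one
// contiguous allocation: per-tensor first/last dimensions and element offsets
// are supplied as 1-D int64 device arrays through the kNVTEGroupedFirstDims,
// kNVTEGroupedLastDims, and kNVTEGroupedTensorOffsets parameters above.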
if (rowwise) {
cudaMalloc((void**)&out_data_rowwise_d, out_data_size);
cudaMalloc((void**)&out_scales_rowwise_d, rowwise_scales_size);
cudaMemset(out_data_rowwise_d, 0, out_data_size);
cudaMemset(out_scales_rowwise_d, 0, rowwise_scales_size);
NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast<NVTEDType>(otype), logical_shape_};
NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_rowwise_shape.data(), scales_rowwise_shape.size());
NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_};
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_rowwise_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scales_rowwise_tensor);
}
if (colwise) {
cudaMalloc((void**)&out_data_colwise_d, out_data_size);
cudaMalloc((void**)&out_scales_colwise_d, colwise_scales_size);
cudaMemset(out_data_colwise_d, 0, out_data_size);
cudaMemset(out_scales_colwise_d, 0, colwise_scales_size);
NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast<NVTEDType>(otype), logical_shape_};
NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_colwise_shape.data(), scales_colwise_shape.size());
NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_};
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, &out_data_colwise_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor);
}
Tensor output_dbias("output_dbias", std::vector<size_t>{ cols }, itype);
// Reference (CPU)
if (is_single_tensor) {
const size_t unpadded_rowwise_blocks_X = divide_round_up(cols, 32);
const size_t unpadded_colwise_blocks_X = cols;
const size_t scales_stride_rowwise = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4);
const size_t scales_stride_colwise = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128);
compute_ref<InputType, OutputType>(
processing_method, OP, rowwise, colwise, in_data.data(), grad_data.data(),
out_data_rowwise_ref.data(), out_data_colwise_ref.data(),
out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(),
ref_output_dbias.data(), rows, cols,
scales_stride_rowwise,
scales_stride_colwise,
is_single_tensor);
} else {
for (size_t t = 0; t < num_tensors; ++t) {
const size_t M = first_dims_h[t];
const size_t K = last_dims_h[t];
const size_t scales_stride_rowwise = rowwise_scales_last_dim[t];
const size_t scales_stride_colwise = colwise_scales_last_dim[t];
const size_t data_offset = offsets_h[t];
const size_t rowwise_sfs_offset = rowwise_scales_offset[t];
const size_t colwise_sfs_offset = colwise_scales_offset[t];
const InputType* const grad_ptr = grad_data.data() + data_offset;
const InputType* const in_ptr = in_data.data() + data_offset;
OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset;
OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset;
fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset;
fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset;
compute_ref<InputType, OutputType>(
processing_method, OP, rowwise, colwise, in_ptr, grad_ptr,
out_data_rowwise_ptr, out_data_colwise_ptr,
out_scales_rowwise_ptr, out_scales_colwise_ptr,
ref_output_dbias.data(), M, K,
scales_stride_rowwise,
scales_stride_colwise,
is_single_tensor);
}
}
// GPU
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_group_quantize(in_group_tensor, out_group_tensor, 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
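// The first call below only queries the required workspace size; the
// workspace is then allocated and the call is repeated to do the real work.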
nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0);
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
auto nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dgelu;
if (OP == &dsilu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsilu; }
else if (OP == &drelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_drelu; }
else if (OP == &dqgelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dqgelu; }
else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; }
nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor,
output_dbias.data(), workspace.data(), 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor,
output_dbias.data(), workspace.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
auto nvte_group_act = &nvte_group_gelu;
if (OP == &silu) { nvte_group_act = &nvte_group_silu; }
else if (OP == &relu) { nvte_group_act = &nvte_group_relu; }
else if (OP == &qgelu) { nvte_group_act = &nvte_group_qgelu; }
else if (OP == &srelu) { nvte_group_act = &nvte_group_srelu; }
nvte_group_act(in_group_tensor, out_group_tensor, 0);
break;
}
case ProcessingMethod::CAST_DACT: {
auto nvte_group_dact = &nvte_group_dgelu;
if (OP == &dsilu) { nvte_group_dact = &nvte_group_dsilu; }
else if (OP == &drelu) { nvte_group_dact = &nvte_group_drelu; }
else if (OP == &dqgelu) { nvte_group_dact = &nvte_group_dqgelu; }
else if (OP == &dsrelu) { nvte_group_dact = &nvte_group_dsrelu; }
nvte_group_dact(grad_group_tensor, in_group_tensor, out_group_tensor, 0);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(otype);
const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 0.0;
const double rel_tolerable_mismatches_limit = 0.0;
if (rowwise) {
cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost);
cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost);
size_t mismatches_scales = 0;
compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(),
1, rowwise_sfs_num, rowwise_sfs_num, mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);
const size_t mismatches_elts = 32 * mismatches_scales;
compare_scaled_elts<OutputType>("rowwise_output", out_data_rowwise_ref.data(),
out_data_rowwise_h.data(), rows, cols, true, mismatches_elts);
}
if (colwise) {
cudaMemcpy(out_data_colwise_h.data(), out_data_colwise_d, out_data_size, cudaMemcpyDeviceToHost);
cudaMemcpy(out_scales_colwise_h.data(), out_scales_colwise_d, colwise_scales_size, cudaMemcpyDeviceToHost);
size_t mismatches_scales = 0;
compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(),
1, colwise_sfs_num, colwise_sfs_num, mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);
const size_t mismatches_elts = 32 * mismatches_scales;
compare_scaled_elts<OutputType>("colwise_output", out_data_colwise_ref.data(),
out_data_colwise_h.data(), rows, cols, false, mismatches_elts);
}
if (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT)
{
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
rtol_dbias *= sqrt(static_cast<double>(rows)) ;
} else {
rtol_dbias *= 4;
}
compareResults("output_dbias", output_dbias, ref_output_dbias.data(), true, atol_dbias, rtol_dbias);
}
cudaFree(grad_data_d);
cudaFree(in_data_d);
cudaFree(first_dims_d);
cudaFree(last_dims_d);
cudaFree(offsets_d);
if (rowwise) {
cudaFree(out_data_rowwise_d);
cudaFree(out_scales_rowwise_d);
}
if (colwise) {
cudaFree(out_data_colwise_d);
cudaFree(out_scales_colwise_d);
}
}
std::vector<ProcessingMethod> processing_methods = {
ProcessingMethod::CAST_ONLY,
ProcessingMethod::CAST_DBIAS,
ProcessingMethod::CAST_DBIAS_DACT,
ProcessingMethod::CAST_DACT,
ProcessingMethod::CAST_ACT,
};
std::vector<ActivationKind> activation_kinds = {
ActivationKind::Identity,
ActivationKind::GeLU,
// ActivationKind::SiLU,
// ActivationKind::ReLU,
// ActivationKind::QGeLU,
// ActivationKind::SReLU,
};
enum ScalingDirection {
ROWWISE = 0,
COLWISE = 1,
BOTH = 2
};
std::vector<ScalingDirection> scaling_directions = {
ScalingDirection::ROWWISE,
ScalingDirection::COLWISE,
ScalingDirection::BOTH,
};
// {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]}
std::vector<std::vector<size_t>> input_config = {
{SAME_BOTH_DIMS, 1, 128,128},
{SAME_BOTH_DIMS, 2, 256,128},
{VARYING_FIRST_DIM, 2, 512,128, 128,384},
{VARYING_FIRST_DIM, 2, 384,160, 128,256},
{VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304},
{VARYING_LAST_DIM, 3, 256,896, 128,256,512},
{VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
{VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640},
};
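// Example: {VARYING_FIRST_DIM, 2, 512,128, 128,384} describes two member
// tensors of shapes 128x128 and 384x128 stored back to back, with a combined
// logical shape of 512x128.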
} // namespace
class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam
<std::tuple<ProcessingMethod,
ActivationKind,
ScalingDirection,
std::vector<size_t>, // Config
transformer_engine::DType, // InputType
transformer_engine::DType // OutputType
>> {};
TEST_P(GroupedFusedCastMXFP8TestSuite, Test) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const ProcessingMethod processing_method = std::get<0>(GetParam());
const ActivationKind activation = std::get<1>(GetParam());
const ScalingDirection scaling_direction = std::get<2>(GetParam());
const std::vector<size_t> input_config = std::get<3>(GetParam());
const DType input_type = std::get<4>(GetParam());
const DType output_type = std::get<5>(GetParam());
const ShapeRepresentation shape_rep = static_cast<ShapeRepresentation>(input_config[0]);
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM);
const size_t num_tensors = input_config[1];
const std::vector<size_t> logical_shape = {input_config[2], input_config[3]};
std::vector<size_t> first_dims(num_tensors);
std::vector<size_t> last_dims(num_tensors);
std::vector<size_t> offsets(num_tensors + 1, 0);
for (size_t t = 0; t < num_tensors; ++t) {
switch (shape_rep) {
case SAME_BOTH_DIMS: {
first_dims[t] = logical_shape[0] / num_tensors;
last_dims[t] = logical_shape[1];
break;
}
case VARYING_FIRST_DIM: {
first_dims[t] = input_config[t + 4];
last_dims[t] = logical_shape[1];
break;
}
case VARYING_LAST_DIM: {
first_dims[t] = logical_shape[0];
last_dims[t] = input_config[t + 4];
break;
}
case VARYING_BOTH_DIMS: {
first_dims[t] = input_config[t + 4];
last_dims[t] = input_config[t + (4 + num_tensors)];
break;
}
}
offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t];
// Skips tests if tensor shape is not as required by the kernel
if ((first_dims[t] % 128 != 0) || (last_dims[t] % 32 != 0)) {
GTEST_SKIP();
}
}
// Skips DBias tests if the last dimension of the tensors varies
if ((processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT)
&& !is_single_tensor) {
GTEST_SKIP();
}
// Skips non-Act tests if the activation type is not Identity
if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
&& activation != ActivationKind::Identity) {
GTEST_SKIP();
}
// Skips Act tests if the activation is Identity
if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
|| processing_method == ProcessingMethod::CAST_DACT
|| processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) {
GTEST_SKIP();
}
bool rowwise = false;
bool colwise = false;
switch (scaling_direction) {
case ScalingDirection::ROWWISE: rowwise = true; break;
case ScalingDirection::COLWISE: colwise = true; break;
case ScalingDirection::BOTH: rowwise = true; colwise = true; break;
}
auto OP = &identity;
if (processing_method == ProcessingMethod::CAST_ACT) {
switch (activation) {
case ActivationKind::GeLU: OP = &gelu; break;
case ActivationKind::SiLU: OP = &silu; break;
case ActivationKind::ReLU: OP = &relu; break;
case ActivationKind::QGeLU: OP = &qgelu; break;
case ActivationKind::SReLU: OP = &srelu; break;
}
} else if (processing_method == ProcessingMethod::CAST_DACT
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
switch (activation) {
case ActivationKind::GeLU: OP = &dgelu; break;
case ActivationKind::SiLU: OP = &dsilu; break;
case ActivationKind::ReLU: OP = &drelu; break;
case ActivationKind::QGeLU: OP = &dqgelu; break;
case ActivationKind::SReLU: OP = &dsrelu; break;
}
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
performTest<InputType, OutputType>(processing_method, OP, shape_rep, num_tensors,
logical_shape, first_dims, last_dims, offsets,
rowwise, colwise);
);
);
}
std::string to_string(const ProcessingMethod method) {
switch (method) {
case ProcessingMethod::CAST_ONLY: return "CAST_ONLY";
case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
case ProcessingMethod::CAST_DACT: return "CAST_DACT";
case ProcessingMethod::CAST_ACT: return "CAST_ACT";
default: return "";
}
}
std::string to_string(const ActivationKind activation) {
switch (activation) {
case ActivationKind::Identity: return "Identity";
case ActivationKind::GeLU: return "GeLU";
case ActivationKind::SiLU: return "SiLU";
case ActivationKind::ReLU: return "ReLU";
case ActivationKind::QGeLU: return "QGeLU";
case ActivationKind::SReLU: return "SReLU";
default: return "";
}
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
GroupedFusedCastMXFP8TestSuite,
::testing::Combine(
::testing::ValuesIn(processing_methods),
::testing::ValuesIn(activation_kinds),
::testing::ValuesIn(scaling_directions),
::testing::ValuesIn(input_config),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)),
[](const testing::TestParamInfo<GroupedFusedCastMXFP8TestSuite::ParamType>& info) {
const ProcessingMethod method = std::get<0>(info.param);
std::string name = to_string(method);
name += "X" + to_string(std::get<1>(info.param));
switch (std::get<2>(info.param)) {
case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break;
case ScalingDirection::COLWISE: name += "_COLWISE_"; break;
case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break;
}
const std::vector<size_t> input = std::get<3>(info.param);
switch(static_cast<ShapeRepresentation>(input[0])) {
case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break;
case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break;
case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break;
case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break;
};
name += "_N_" + std::to_string(input[1]);
name += "_SHAPE_" +
std::to_string(input[2]) +
"X" + std::to_string(input[3]);
name += "_" + test::typeName(std::get<4>(info.param)) +
"_" + test::typeName(std::get<5>(info.param));
return name;
});
@@ -54,12 +54,16 @@ std::vector<InputType> create_transpose(const InputType* const input, const size
}
// Compute the global encode scale factor for a given global amax
-float compute_global_encode_scaling_factor_FP4(const float global_amax) {
+float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) {
constexpr float fp8_max = 448.0f;
constexpr float fp4_max = 6.0f;
float global_encode_scale = fp8_max * fp4_max / global_amax;
-// If scale is infinity, return max value of float32
-global_encode_scale = fminf(global_encode_scale, Numeric_Traits<float>::maxNorm);
+// If scale is infinity, return the max normalized value
+const float max_norm_clamp = use_fast_math
+    ? Numeric_Traits<bf16>::maxNorm
+    : Numeric_Traits<float>::maxNorm;
+global_encode_scale = fminf(global_encode_scale, max_norm_clamp);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
@@ -76,10 +80,11 @@ void quantize_nvfp4_1d(float (*OP)(const float),
const size_t rows,
const size_t cols,
const size_t scales_stride,
-const float global_amax) {
+const float global_amax,
+const bool use_fast_math) {
// Compute a global encoding/decoding scaling factor for all S_dec_b
-const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
+const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
constexpr size_t block_size_X = 16;
const size_t blocks_X = divide_round_up(cols, block_size_X);
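The hunks above thread `use_fast_math` through the NVFP4 reference. The underlying two-level scaling is easy to restate numerically; a sketch that skips the FP8 rounding of the block scale (so all values are the "mathematical" ones):

```python
import numpy as np

FP8_MAX, FP4_MAX = 448.0, 6.0  # E4M3 and E2M1 maxima, as in the code above

def global_encode_scale(global_amax: float) -> float:
    s_enc = FP8_MAX * FP4_MAX / global_amax if global_amax else float("inf")
    s_enc = min(s_enc, float(np.finfo(np.float32).max))  # clamp instead of inf
    return 1.0 if (global_amax == 0.0 or s_enc == 0.0) else s_enc

block = np.array([0.5, -1.25, 3.0, -0.75], dtype=np.float32)
s_enc = global_encode_scale(global_amax=3.0)
s_dec_b = np.abs(block).max() / FP4_MAX  # per-block decode scale
s_enc_b = s_enc / (s_dec_b * s_enc)      # per-block encode scale
scaled = block * s_enc_b                 # elements now fit the FP4 range
assert np.abs(scaled).max() <= FP4_MAX
```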
@@ -114,14 +119,20 @@ void quantize_nvfp4_1d(float (*OP)(const float),
const float S_dec_b = block_amax / 6.0f;
// Scale & Store per-block decoding scaling factor
-const float S_dec_b_fp8 = S_dec_b * S_enc;
+const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
+const float S_dec_b_fp32 = static_cast<float>(S_dec_b_fp8);
// Compute "correct" per-block encoding scaling factor
-const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
+const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
const size_t scale_idx = i * scales_stride + block_X;
-scales[scale_idx] = static_cast<fp8e4m3>(S_dec_b_fp8);
-const float scale_reciprocal = S_enc_b_fp8;
+scales[scale_idx] = S_dec_b_fp8;
+float scale_reciprocal = S_enc_b_fp8;
+if (use_fast_math) {
+// Numerical truncation to match GPU implementation, if mixed precision FMA instruction is used
+scale_reciprocal = static_cast<float>(static_cast<bf16>(scale_reciprocal));
+}
for (size_t j = j_min; j < j_max; j += 2) {
const int idx_pair = (i * cols + j) / 2;
@@ -136,7 +147,7 @@ void quantize_nvfp4_1d(float (*OP)(const float),
fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
output[idx_pair] = casted_to_e2m1_pair;
-// const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
+const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
}
}
}
@@ -149,9 +160,10 @@ void compute_2d_mathematical_scales(float (*OP)(const float),
const size_t rows,
const size_t cols,
const float global_amax,
-std::vector<std::vector<fp8e4m3>>& math_scales) {
+std::vector<std::vector<fp8e4m3>>& math_scales,
+const bool use_fast_math) {
-const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
+const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
@@ -195,13 +207,14 @@ void quantize_nvfp4_2d(float (*OP)(const float),
const size_t rows,
const size_t cols,
const size_t scales_stride,
-const float global_amax) {
+const float global_amax,
+const bool use_fast_math) {
// Step 1: Compute mathematical 8x8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
-compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
+compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
-const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
+const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
@@ -282,11 +295,12 @@ void quantize_nvfp4(float (*OP)(const float),
const size_t cols,
const size_t scales_stride,
const float global_amax,
+const bool use_fast_math,
const bool use_2d_quantization = false) {
if (use_2d_quantization) {
-quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
+quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math);
} else {
-quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
+quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math);
}
}
@@ -302,6 +316,7 @@ void compute_ref(float (*OP)(const float),
const size_t cols,
const size_t scales_stride,
const size_t scales_stride_t,
+const bool use_fast_math,
const bool use_2d_quantization = false)
{
std::vector<InputType> input_t = create_transpose(input, rows, cols);
@@ -309,7 +324,7 @@ void compute_ref(float (*OP)(const float),
if (use_2d_quantization) {
// Step 1: Compute mathematical 8×8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
-compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
+compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
@@ -336,12 +351,16 @@ void compute_ref(float (*OP)(const float),
// Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d
// (This part processes the actual FP4 data using the mathematical scaling factors)
-quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled
-quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled
+quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax,
+use_fast_math); // scales already filled
+quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax,
+use_fast_math); // scales_t already filled
} else {
-quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization);
-quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization);
+quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax,
+use_fast_math, use_2d_quantization);
+quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax,
+use_fast_math, use_2d_quantization);
}
}
...@@ -349,6 +368,8 @@ void compare_nvfp4_tensors(const std::string& name, ...@@ -349,6 +368,8 @@ void compare_nvfp4_tensors(const std::string& name,
const fp4e2m1 *test_data, const fp4e2m1 *ref_data, const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
const int rows, const int cols, const int rows, const int cols,
double atol = 1e-5, double rtol = 1e-8) { double atol = 1e-5, double rtol = 1e-8) {
constexpr int max_mismatches_to_print = 3;
std::vector<std::string> mismatch_messages; std::vector<std::string> mismatch_messages;
size_t total_mismatches = 0; size_t total_mismatches = 0;
...@@ -362,29 +383,16 @@ void compare_nvfp4_tensors(const std::string& name, ...@@ -362,29 +383,16 @@ void compare_nvfp4_tensors(const std::string& name,
const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol);
/* For Float32 the floating point comparison is enough to error out */ if (mismatch) {
bool assertion = false;
if (mismatch && !assertion) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<fp4e2m1>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<fp4e2m1>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
if (assertion) {
total_mismatches++; total_mismatches++;
std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " +
std::to_string(t) + " vs " + std::to_string(r) +
" (abs_diff: " + std::to_string(fabs(t - r)) +
", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
mismatch_messages.push_back(msg);
// Optional: limit number of detailed messages to avoid overwhelming output // Optional: limit number of detailed messages to avoid overwhelming output
if (mismatch_messages.size() <= 100) { if (total_mismatches <= max_mismatches_to_print) {
std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " +
std::to_string(t) + " vs " + std::to_string(r) +
" (abs_diff: " + std::to_string(fabs(t - r)) +
", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
mismatch_messages.push_back(msg);
std::cout << "Error in tensor " << name << ": " << msg << std::endl; std::cout << "Error in tensor " << name << ": " << msg << std::endl;
} }
} }
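// Note: the rewritten mismatch check follows the numpy.isclose convention,
// |t - r| > atol + |r| * rtol -- a single combined bound in place of the old
// separate absolute and relative tests.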
@@ -400,8 +408,9 @@ void compare_nvfp4_tensors(const std::string& name,
     std::cout << "STATUS: FAILED for output" << std::endl;
     std::cout << "Total mismatches found: " << total_mismatches << std::endl;
     std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl;
-    if (mismatch_messages.size() > 100) {
-      std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl;
+    if (mismatch_messages.size() > max_mismatches_to_print) {
+      std::cout << "... and " << (mismatch_messages.size() - max_mismatches_to_print)
+                << " more mismatches (showing first " << max_mismatches_to_print << ")" << std::endl;
     }
     std::cout << "============================" << std::endl;
@@ -519,7 +528,8 @@ void compareResults_nvfp4(const Tensor &test,
 template <typename InputType>
 void performTest(float (*OP)(const float),
-                 const std::vector<size_t>& shape) {
+                 const std::vector<size_t>& shape,
+                 const bool use_fast_math) {
   using namespace test;
   DType itype = TypeInfo<InputType>::dtype;
@@ -580,15 +590,16 @@ void performTest(float (*OP)(const float),
                          cols,
                          scales_stride,
                          scales_stride_t,
+                         use_fast_math,
                          use_2d_quantization);
-  QuantizationConfigWrapper quant_config;
   // Initialize stochastic rounding
   Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
   rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123;  // rng_seed
   rng_state.rowwise_cpu_dptr<int64_t>()[1] = 321;  // rng_sequence
   rng_state.from_cpu();
+  QuantizationConfigWrapper quant_config;
+  quant_config.set_use_fast_math(use_fast_math);
   quant_config.set_stochastic_rounding(false);
   quant_config.set_rng_state(rng_state.data());
@@ -619,8 +630,8 @@ void performTest(float (*OP)(const float),
   }
   ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
-  const double atol = 0.05;
-  const double rtol = 0.1;
+  const double atol = 1.0E-6;
+  const double rtol = 1.0E-6;
   // Set dump_data=true to enable dumping tensor data to files for analysis
   compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false);
@@ -671,7 +682,8 @@ std::vector<ActivationType> Activation_types = {
 class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
     <std::tuple<ActivationType,
                 std::vector<size_t>,
-                transformer_engine::DType>> {};
+                transformer_engine::DType,
+                bool>> {};
 TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
   // Skip tests for pre-Blackwell architectures
@@ -685,6 +697,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
   const ActivationType Act_type = std::get<0>(GetParam());
   const auto tensor_dims = std::get<1>(GetParam());
   const DType input_type = std::get<2>(GetParam());
+  const bool use_fast_math = std::get<3>(GetParam());
   // Skip tests if the input tensor is 1D
   if (tensor_dims.size() < 2) {
@@ -702,7 +715,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
   }
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
-    performTest<InputType>(OP, tensor_dims);
+    performTest<InputType>(OP, tensor_dims, use_fast_math);
   );
 }
@@ -724,7 +737,8 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Combine(
         ::testing::ValuesIn(Activation_types),
         ::testing::ValuesIn(tensor_dims),
-        ::testing::Values(DType::kBFloat16)),
+        ::testing::Values(DType::kBFloat16),
+        ::testing::Values(false)),
     [](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
       std::string name = to_string(std::get<0>(info.param));
       const auto& shape = std::get<1>(info.param);
@@ -732,5 +746,8 @@ INSTANTIATE_TEST_SUITE_P(
        name += "X" + std::to_string(s);
      }
      name += "X" + test::typeName(std::get<2>(info.param));
+      if (std::get<3>(info.param)) {
+        name += "X_FAST_SCALING";
+      }
      return name;
    });
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>
#include <random>
#include <tuple>
#include <vector>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum class InputCase {
kFP8Current,
kBF16,
};
enum class ShapeCase {
kAllSame,
kSameFirst,
kSameLast,
kAllDifferent,
};
size_t grouped_setup_workspace_size(const size_t num_tensors) {
const size_t ptr_bytes = num_tensors * sizeof(void*);
const size_t int_bytes = num_tensors * sizeof(int);
// Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols)
size_t size = 6 * ptr_bytes + 6 * int_bytes;
const size_t alignment = 256;
size = ((size + alignment - 1) / alignment) * alignment;
return size;
}
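// Example: num_tensors == 3 gives ptr_bytes = 24 and int_bytes = 12, so
// size = 6 * 24 + 6 * 12 = 216 bytes, rounded up to the 256-byte alignment
// boundary -> 256 bytes of setup workspace.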
Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);
fillUniform(&input_fp32);
Tensor fp8(name, shape, TypeInfo<fp8e4m3>::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);
nvte_compute_amax(input_fp32.data(), fp8.data(), 0);
QuantizationConfigWrapper config;
nvte_compute_scale_from_amax(fp8.data(), config, 0);
nvte_quantize(input_fp32.data(), fp8.data(), 0);
return fp8;
}
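// Note: this is FP8 "current scaling" -- the amax and scale are derived from
// the actual input data immediately before quantization, rather than carried
// over from earlier iterations as in delayed scaling.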
Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& shape) {
Tensor t(name, shape, DType::kBFloat16);
const size_t numel = shape[0] * shape[1];
std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f));
NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(),
numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice));
return t;
}
struct TestParams {
InputCase input_case;
bool transa;
bool transb;
ShapeCase shape_case;
bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0)
};
// Returns a vector of (M, N, K) tuples for each GEMM in the group.
// M - number of rows in output D
// N - number of columns in output D
// K - reduction dimension shared between A and B
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
switch (scase) {
case ShapeCase::kAllSame:
return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
case ShapeCase::kSameFirst:
// Same M (first dim), varying N and K
return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}};
case ShapeCase::kSameLast:
// Same N (last dim), varying M and K
return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}};
case ShapeCase::kAllDifferent:
default:
return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}};
}
}
void run_grouped_gemm_case(const TestParams& params) {
#if CUBLAS_VERSION < 130200
GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is "
<< CUBLAS_VERSION << ".";
#else
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer.";
}
const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);
const size_t num_gemms = shapes.size();
std::vector<Tensor> A_tensors;
std::vector<Tensor> B_tensors;
std::vector<Tensor> D_multi;
A_tensors.reserve(num_gemms);
B_tensors.reserve(num_gemms);
D_multi.reserve(num_gemms);
for (size_t i = 0; i < num_gemms; ++i) {
const auto [M, N, K] = shapes[i];
const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{M, K}
: std::vector<size_t>{K, M};
const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, N}
: std::vector<size_t>{N, K};
switch (params.input_case) {
case InputCase::kFP8Current: {
A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape));
B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape));
break;
}
case InputCase::kBF16: {
A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape));
B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape));
break;
}
}
D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
std::vector<size_t>{M, N},
DType::kBFloat16));
}
std::vector<NVTETensor> A_ptrs(num_gemms);
std::vector<NVTETensor> B_ptrs(num_gemms);
std::vector<NVTETensor> D_ptrs(num_gemms);
std::vector<Tensor> workspaces(num_gemms);
std::vector<NVTETensor> workspace_ptrs(num_gemms, nullptr);
std::vector<Tensor*> A_views;
std::vector<Tensor*> B_views;
A_views.reserve(num_gemms);
B_views.reserve(num_gemms);
// Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues)
std::vector<NVTETensor> bias_ptrs(num_gemms, nullptr);
std::vector<NVTETensor> gelu_ptrs(num_gemms, nullptr);
const size_t cublas_ws_bytes = 32ull * 1024 * 1024;
for (size_t i = 0; i < num_gemms; ++i) {
A_ptrs[i] = A_tensors[i].data();
B_ptrs[i] = B_tensors[i].data();
D_ptrs[i] = D_multi[i].data();
workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
workspace_ptrs[i] = workspaces[i].data();
A_views.push_back(&A_tensors[i]);
B_views.push_back(&B_tensors[i]);
}
nvte_multi_tensor_gemm(A_ptrs.data(),
B_ptrs.data(),
D_ptrs.data(),
bias_ptrs.data(),
gelu_ptrs.data(),
static_cast<int>(num_gemms),
params.transa,
params.transb,
false, // grad
workspace_ptrs.data(),
false, // accumulate
false, // use_split_accumulator
0, // sm_count
0);
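  // The per-tensor multi GEMM above produces the reference outputs (D_multi);
  // the grouped-tensor path below must match them within the dtype tolerances
  // checked at the end of this function.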
GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode());
GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode());
std::vector<Tensor> C_tensors;
std::vector<Tensor> D_group_tensors;
C_tensors.reserve(num_gemms);
D_group_tensors.reserve(num_gemms);
for (size_t i = 0; i < num_gemms; ++i) {
const auto [M, N, K] = shapes[i];
(void)K;
if (!params.use_null_c) {
C_tensors.emplace_back(Tensor("C" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(M), static_cast<size_t>(N)},
DType::kBFloat16));
}
D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(M), static_cast<size_t>(N)},
DType::kBFloat16));
    NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0,
                               bytes(D_group_tensors.back().rowwise_shape(),
                                     D_group_tensors.back().dtype())));
}
std::vector<Tensor*> C_views, D_views;
for (size_t i = 0; i < num_gemms; ++i) {
if (!params.use_null_c) {
C_views.push_back(&C_tensors[i]);
}
D_views.push_back(&D_group_tensors[i]);
}
std::optional<GroupedBuffers> grouped_C;
if (!params.use_null_c) {
grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING);
}
GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING);
// Per-matrix alpha/beta (all 1.0 and 0.0 respectively)
Tensor alpha_tensor("alpha", std::vector<size_t>{num_gemms}, DType::kFloat32);
Tensor beta_tensor("beta", std::vector<size_t>{num_gemms}, DType::kFloat32);
std::vector<float> alpha_vals(num_gemms, 1.f);
std::vector<float> beta_vals(num_gemms, 0.f);
NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(),
num_gemms * sizeof(float), cudaMemcpyHostToDevice));
NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(),
num_gemms * sizeof(float), cudaMemcpyHostToDevice));
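  // With beta fixed at 0 the C operand never contributes to D, which is what
  // makes the use_null_c variant (passing nullptr for C) a valid configuration.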
const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms);
Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
Tensor cublas_ws("cublas_ws", std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
nvte_grouped_gemm(grouped_A.get_handle(),
params.transa,
grouped_B.get_handle(),
params.transb,
params.use_null_c ? nullptr : grouped_C->get_handle(),
grouped_D.get_handle(),
alpha_tensor.data(),
beta_tensor.data(),
setup_ws.data(),
cublas_ws.data(),
nullptr, // config (use defaults)
0);
for (size_t i = 0; i < num_gemms; ++i) {
Tensor grouped_split("grouped_D" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(std::get<0>(shapes[i])),
static_cast<size_t>(std::get<1>(shapes[i]))},
D_multi[i].dtype());
const size_t offset_bytes = static_cast<size_t>(grouped_D.offsets_host[i]) * grouped_D.elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(),
static_cast<char*>(grouped_D.get_data()) + offset_bytes,
grouped_D.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
grouped_split.to_cpu();
D_multi[i].to_cpu();
auto [atol, rtol] = getTolerances(D_multi[i].dtype());
compareResults("grouped_vs_multi",
grouped_split,
D_multi[i].rowwise_cpu_dptr<bf16>(),
true,
atol,
rtol);
}
#endif // CUBLAS_VERSION >= 130200
}
class GroupedGemmTest : public ::testing::TestWithParam<TestParams> {};
TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) {
run_grouped_gemm_case(GetParam());
}
std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest::ParamType>& info) {
constexpr const char* kInputNames[] = {"FP8Current", "BF16"};
constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"};
const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") +
"tb" + (info.param.transb ? "T" : "N");
const std::string null_c = info.param.use_null_c ? "_NullC" : "";
return std::string(kInputNames[static_cast<int>(info.param.input_case)]) + "_" +
kShapeNames[static_cast<int>(info.param.shape_case)] + "_" + layout + null_c;
}
// TestParams: {input_case, transa, transb, shape_case, use_null_c}
const std::vector<TestParams> kTestParams = {
// Basic tests
{InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
{InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
{InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
// Test NULL C (valid when beta=0)
{InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
};
INSTANTIATE_TEST_SUITE_P(OperatorTest,
GroupedGemmTest,
::testing::ValuesIn(kTestParams),
MakeGroupedGemmTestName);
} // namespace
@@ -9,6 +9,7 @@
 #include <algorithm>
 #include <memory>
+#include <numeric>
 #include <random>
 #include <iostream>
 #include <cassert>
@@ -1116,4 +1117,166 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
   return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X};
 }
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const NVTEScalingMode scaling_mode) {
NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build.");
const NVTEShape shape = tensors[0]->rowwise_shape();
const DType dtype = tensors[0]->dtype();
const size_t num_tensors = tensors.size();
const size_t elem_size = typeToNumBits(dtype) / 8;
GroupedBuffers grouped;
grouped.elem_size = elem_size;
grouped.num_tensors = num_tensors;
grouped.dtype = dtype;
grouped.scaling_mode = scaling_mode;
grouped.tensor_bytes.resize(num_tensors);
grouped.offsets_host.resize(num_tensors, 0);
std::vector<int64_t> first_dims(num_tensors);
std::vector<int64_t> last_dims(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
const auto s = tensors[i]->rowwise_shape();
NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors.");
first_dims[i] = static_cast<int64_t>(s.data[0]);
last_dims[i] = static_cast<int64_t>(s.data[1]);
grouped.tensor_bytes[i] = bytes(s, dtype);
}
const bool same_first = std::all_of(first_dims.begin(), first_dims.end(),
[&](int64_t v) { return v == first_dims[0]; });
const bool same_last = std::all_of(last_dims.begin(), last_dims.end(),
[&](int64_t v) { return v == last_dims[0]; });
std::vector<int64_t> offsets(num_tensors, 0);
auto random_padding = [&]() -> int64_t {
// Random padding ensuring 16-byte alignment regardless of element size
// cuBLAS requires aligned pointers for vectorized loads
static std::mt19937 gen(12345);
std::uniform_int_distribution<int64_t> dist(0, 3);
// Calculate elements needed for 16-byte alignment in bytes, rounded up
const size_t align_elements =
std::max<size_t>(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size
return dist(gen) * static_cast<int64_t>(align_elements);
};
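  // Example: for bf16 operands (elem_size == 2) align_elements is 8, so the
  // inserted padding is 0, 16, 32, or 48 bytes -- always a multiple of the
  // 16-byte alignment cuBLAS expects.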
auto numel = [&](size_t idx) -> int64_t {
return first_dims[idx] * last_dims[idx];
};
const bool need_offsets = !same_first || !same_last;
if (need_offsets) {
offsets[0] = 0;
for (size_t i = 1; i < num_tensors; ++i) {
offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding();
}
} else {
for (size_t i = 0; i < num_tensors; ++i) {
offsets[i] = static_cast<int64_t>(i) * numel(0);
}
}
grouped.offsets_host = offsets;
int64_t logical_first = 0;
int64_t logical_last = 0;
if (same_first && same_last) {
logical_first = first_dims[0] * static_cast<int64_t>(num_tensors);
logical_last = last_dims[0];
} else if (same_first && !same_last) {
logical_first = first_dims[0];
logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0});
} else if (!same_first && same_last) {
logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0});
logical_last = last_dims[0];
} else {
logical_first = 1;
logical_last = 0;
for (size_t i = 0; i < num_tensors; ++i) {
logical_last += first_dims[i] * last_dims[i];
}
}
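  // Logical shape summary:
  //   same first & last dims    -> [num_tensors * M, N]
  //   same first, varying last  -> [M, sum(N_i)]
  //   varying first, same last  -> [sum(M_i), N]
  //   both varying              -> [1, sum(M_i * N_i)] (flat element count)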
size_t logical_data[2] = {static_cast<size_t>(logical_first),
static_cast<size_t>(logical_last)};
grouped.logical_shape = nvte_make_shape(logical_data, 2);
grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape));
const int64_t last_idx = static_cast<int64_t>(num_tensors - 1);
const int64_t total_elems = need_offsets
? (offsets[last_idx] + numel(last_idx))
: (logical_first * logical_last);
const size_t total_bytes = static_cast<size_t>(total_elems) * elem_size;
grouped.data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
tensors[i]->rowwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}
NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
NVTEGroupedTensor h = grouped.handle.get();
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor);
const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype);
if (include_columnwise) {
grouped.columnwise_data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.columnwise_data.get()) + offset_bytes,
tensors[i]->columnwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}
NVTEBasicTensor col_tensor{grouped.columnwise_data.get(),
static_cast<NVTEDType>(dtype),
grouped.logical_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor);
}
if (!same_first) {
grouped.first_dims_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor);
}
if (!same_last) {
grouped.last_dims_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor);
}
if (!same_first || !same_last) {
grouped.offsets_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape off_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor);
}
if (isFp8Type(dtype)) {
std::vector<float> scale_inv_cpu(num_tensors, 1.f);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
}
grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors);
NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(),
sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor);
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor);
}
return grouped;
}
} // namespace test
@@ -446,10 +446,14 @@ inline fp8e8m0 float_to_e8m0(float val) {
 }
 inline float exp2f_rcp(fp8e8m0 biased_exp) {
-  if (biased_exp == 0) {
-    return 1.0f;
+  int32_t int_val = 0;
+  if (biased_exp == 255) {
+    int_val = 0x7fffffff;
+  } else if (biased_exp == 254) {
+    int_val = 0x00400000;
+  } else {
+    int_val = (254 - biased_exp) << FP32_MANTISSA_BITS;  // 127 - (biased_exp - 127)
   }
-  int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS;  // 127 - (biased_exp - 127)
   float fp32_val = *reinterpret_cast<float*>(&int_val);
   return fp32_val;
 }
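// e8m0 stores only a biased exponent, so exp2f_rcp(e) computes 2^(127 - e).
// The two special cases are values the bare shift cannot represent: e == 254
// maps to 2^-127, which is subnormal in FP32 (bit pattern 0x00400000), and
// e == 255 is the e8m0 NaN encoding, returned as an FP32 NaN (0x7fffffff).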
@@ -525,6 +529,60 @@ int32_t getDeviceComputeCapability();
 constexpr int32_t hopperComputeCapability = 90;
 constexpr int32_t blackwellComputeCapability = 100;
// Custom deleters for RAII
struct CudaDeleter {
void operator()(void* p) const { if (p) cudaFree(p); }
};
struct GroupedTensorDeleter {
void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); }
};
template <typename T = void>
using CudaPtr = std::unique_ptr<T, CudaDeleter>;
using GroupedTensorHandle = std::unique_ptr<std::remove_pointer_t<NVTEGroupedTensor>, GroupedTensorDeleter>;
// Helper to allocate CUDA memory into a CudaPtr
template <typename T = void>
CudaPtr<T> cuda_alloc(size_t bytes) {
void* ptr = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes));
return CudaPtr<T>(static_cast<T*>(ptr));
}
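// Example: a device buffer of n floats that frees itself on scope exit:
//   CudaPtr<float> buf = cuda_alloc<float>(n * sizeof(float));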
// Helper owning GPU buffers that back NVTEGroupedTensor.
// NVTEGroupedTensor does not own memory; data/offsets/scales
// must be allocated and freed by the test.
struct GroupedBuffers {
GroupedTensorHandle handle;
CudaPtr<> data;
CudaPtr<> scale_inv;
CudaPtr<int64_t> first_dims_dev;
CudaPtr<int64_t> last_dims_dev;
CudaPtr<int64_t> offsets_dev;
CudaPtr<> columnwise_data;
NVTEShape logical_shape{};
std::vector<int64_t> offsets_host;
std::vector<size_t> tensor_bytes;
size_t num_tensors{0};
size_t elem_size{0};
DType dtype{DType::kFloat32};
NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING};
GroupedBuffers() = default;
GroupedBuffers(const GroupedBuffers&) = delete;
GroupedBuffers& operator=(const GroupedBuffers&) = delete;
GroupedBuffers(GroupedBuffers&&) = default;
GroupedBuffers& operator=(GroupedBuffers&&) = default;
~GroupedBuffers() = default;
// Convenience accessors for raw pointers
NVTEGroupedTensor get_handle() const { return handle.get(); }
void* get_data() const { return data.get(); }
};
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const NVTEScalingMode scaling_mode);
} // namespace test
#if FP4_TYPE_SUPPORTED
@@ -1921,3 +1921,37 @@ class TestGroupedDense:
        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
class TestDebugInspectFFI:
@pytest_parametrize_wrapper("shape", [(256, 128)])
@pytest_parametrize_wrapper(
"dtype",
[
jnp.float32,
jnp.bfloat16,
jnp.float16,
# Note: fp4 currently doesn't work
# jnp.float4_e2m1fn
]
+ ([jnp.float8_e4m3fn, jnp.float8_e5m2] if is_fp8_supported else []),
)
def test_debug_inspect_ffi(self, shape, dtype):
from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump
def f(x):
x = x + 1
            x = inspect_array(x, "my_tensor")
x = x + 1
return x
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, shape, jnp.float32)
x = x.astype(dtype)
_ = jax.jit(f)(x)
expected = x + 1
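        # Assumes inspect_array dumps to "<name>_gpu<device>.bin" in the working directory.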
actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype)
assert_allclose(actual, expected, dtype=dtype)
@@ -2,6 +2,7 @@
 #
 # See LICENSE for license information.
 """Tests for fused attention"""
+import os
 from enum import Enum, auto
 from dataclasses import dataclass, field
 from functools import partial
@@ -49,6 +50,9 @@ from transformer_engine_jax import (
 from distributed_test_base import assert_equal_collectives
 from utils import assert_allclose, print_debug_tensor_stats
+# Get determinism
+_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
 @pytest.fixture(autouse=True, scope="module")
 def init():
@@ -413,15 +417,25 @@ class FusedAttnRunner:
             pytest.skip(
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )
-        # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
-        if (
-            get_device_compute_capability(0) >= 100
-            and self.dropout_prob == 0.1
-            and self.attn_bias_type is not AttnBiasType.NO_BIAS
-        ):
-            pytest.skip(
-                "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
-            )
+        if get_device_compute_capability(0) >= 100 and self.is_training:
+            if FusedAttnHelper.is_non_deterministic_allowed() and (
+                (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
+                or get_cudnn_version() < 90700
+            ):
+                pytest.skip(
+                    "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with"
+                    " dropout"
+                )
+            if not FusedAttnHelper.is_non_deterministic_allowed() and (
+                self.dropout_prob != 0.0
+                or self.attn_bias_type != AttnBiasType.NO_BIAS
+                or get_cudnn_version() < 91801
+            ):
+                pytest.skip(
+                    "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or"
+                    " dropout"
+                )
         # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
         # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
         if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
@@ -1269,6 +1283,7 @@ class FusedAttnRunner:
         pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
     ],
 )
+@pytest.mark.skipif(_deterministic, reason="Test non-determinism only")
 class TestFusedAttn:
     """
     Fused attention tester
@@ -1392,3 +1407,182 @@ class TestFusedAttn:
             seq_desc_format,
         )
         runner.test_backward()
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
pytest.param(
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
[
# large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.bfloat16,
QKVLayout.BSHD_BSHD_BSHD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-SEPARATE",
),
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.bfloat16,
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
),
],
)
@pytest.mark.parametrize(
"dropout_prob",
[
pytest.param(0.0, id="DROP_0.0"),
],
)
@pytest.mark.parametrize(
"swa",
[
pytest.param(False, id="NO_SWA"),
],
)
@pytest.mark.parametrize(
"seq_desc_format",
[
pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
],
)
@pytest.mark.skipif(not _deterministic, reason="Test determinism only")
class TestFusedAttnWithDeterminism:
"""
Fused attention tester with determinism
"""
@staticmethod
@pytest.mark.parametrize(
"is_training",
[
pytest.param(True, id="TRAINING"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def _test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test forward with parameterized configs
This test is not intended to run automatically during CI as it is time-consuming
It is kept for development and debugging
"""
TestFusedAttn._test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
)
@staticmethod
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test backward with parameterized configs
"""
TestFusedAttn.test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
)