Commit 9df0c4a3 authored by wenjh

Merge branch 'nv_main'

parents 0d874a4e f122b07d
......@@ -23,8 +23,6 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr
- **FP8 Weight Caching**: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency (see the sketch after this list).
- [Introduction to FP8](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/fp8_primer.ipynb)
- Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage.
- [TE Quickstart](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb)
- Introduction to TE, building a Transformer Layer using PyTorch, and instructions on integrating TE modules like Linear and LayerNorm.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/mnist)
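
The weight caching above is driven per microbatch via the `is_first_microbatch` argument of the module call; a minimal PyTorch sketch (the layer size, recipe, and loop are illustrative, not taken from the tutorials):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(layer.parameters(), lr=1e-3)
recipe = DelayedScaling()  # delayed-scaling FP8 recipe, as covered in the FP8 primer

for micro_step in range(4):  # gradient accumulation; weights do not change between steps
    x = torch.randn(32, 1024, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        # After the first microbatch, the cached FP8 weight cast is reused.
        y = layer(x, is_first_microbatch=(micro_step == 0))
    y.sum().backward()
optimizer.step()
```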
# JAX
......@@ -34,6 +32,8 @@ Additionally, we offer [Jupyter notebook tutorials](https://github.com/NVIDIA/Tr
- Model Parallelism: Divide a model across multiple GPUs for parallel training.
- Multiprocessing with Model Parallelism: Multiprocessing for model parallelism, including multi-node support and hardware affinity setup.
- [Basic MNIST Example](https://github.com/NVIDIA/TransformerEngine/tree/main/examples/jax/mnist)
- [TE JAX Integration Tutorial](https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_jax_integration.ipynb)
- Introduction to integrating TE into an existing JAX model framework, building a Transformer Layer, and instructions on using TE modules like Linear and LayerNorm.
# Third party
- [Hugging Face Accelerate + TE](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8/transformer_engine)
......
# Datasets used by TE encoder tests. Pull these to pre-emptively cache datasets
ylecun/mnist
nyu-mll/glue
\ No newline at end of file
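With the canonical repo IDs above, a CI image can warm the Hugging Face cache before the tests run; a minimal sketch, assuming the standard datasets package and network access at build time:

from datasets import load_dataset

# Download once so later (possibly offline) test runs hit the local cache.
load_dataset("ylecun/mnist")
load_dataset("nyu-mll/glue", "cola")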
......@@ -219,11 +219,11 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......
......@@ -197,11 +197,11 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......
......@@ -307,11 +307,11 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......
......@@ -195,11 +195,11 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......
......@@ -146,7 +146,7 @@ def eval_model(state, test_ds, batch_size, var_collect):
def get_datasets():
"""Load MNIST train and test datasets into memory."""
train_ds = load_dataset("mnist", split="train", trust_remote_code=True)
train_ds = load_dataset("ylecun/mnist", split="train", trust_remote_code=True)
train_ds.set_format(type="np")
batch_size = train_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
......@@ -154,7 +154,7 @@ def get_datasets():
"image": train_ds["image"].astype(np.float32).reshape(shape) / 255.0,
"label": train_ds["label"],
}
test_ds = load_dataset("mnist", split="test", trust_remote_code=True)
test_ds = load_dataset("ylecun/mnist", split="test", trust_remote_code=True)
test_ds.set_format(type="np")
batch_size = test_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
......
......@@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
......
......@@ -51,11 +51,14 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
mkdir -p $TE_PATH/artifacts/tests/pytorch/test_checkpoint && python $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all --checkpoint-dir $TE_PATH/artifacts/tests/pytorch/test_checkpoint/
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_batched_linear.xml $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint
if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then
python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all || error_exit "Failed to generate checkpoint files"
fi
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py"
......
......@@ -28,11 +28,11 @@ WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation -vvv --wheel-dir ./dist . || error_exit "Failed to setup bdist_wheel"
wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl"
python3 -m wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl"
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}"
python3 -m wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}"
rm dist/*.whl || error_exit "Failed to remove dist/*.whl"
mv *.whl dist/ || error_exit "Failed to move *.whl to dist/"
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup metapackage"
......
......@@ -6,4 +6,5 @@
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
# NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available
NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
......@@ -106,11 +106,6 @@ def setup_common_extension() -> CMakeExtension:
f"nvidia-cublasmp-cu{cuda_version()[0]}"
).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
f"nvidia-nvshmem-cu{cuda_version()[0]}"
).locate_file("nvidia/nvshmem")
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
print("CMAKE_FLAGS:", cmake_flags[-2:])
# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
......
......@@ -11,6 +11,7 @@ list(APPEND test_cuda_sources
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
test_cast_mxfp8.cu
test_cast_mxfp8_grouped.cu
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_transpose.cu
......@@ -29,6 +30,7 @@ list(APPEND test_cuda_sources
test_causal_softmax.cu
test_swizzle.cu
test_swap_first_dims.cu
test_grouped_gemm.cu
../test_common.cu)
if(USE_ROCM)
list(APPEND test_cuda_sources
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum ProcessingMethod {
CAST_ONLY,
CAST_DBIAS,
CAST_DBIAS_DACT,
CAST_DACT,
CAST_ACT
};
enum ActivationKind {
Identity,
GeLU,
SiLU,
ReLU,
QGeLU,
SReLU
};
enum ShapeRepresentation {
SAME_BOTH_DIMS = 0,
VARYING_FIRST_DIM = 1,
VARYING_LAST_DIM = 2,
VARYING_BOTH_DIMS = 3
};
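// How each member tensor relates to the logical shape [M, K] (see the TEST_P body):
// SAME_BOTH_DIMS - every tensor is (M / num_tensors) x K
// VARYING_FIRST_DIM - per-tensor first dims, shared last dim K
// VARYING_LAST_DIM - shared first dim M, per-tensor last dims
// VARYING_BOTH_DIMS - per-tensor first and last dims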
template <typename InputType, typename OutputType>
void compute_ref(const ProcessingMethod processing_method,
float (*OP)(const float),
const bool rowwise,
const bool colwise,
const InputType* input,
const InputType* grad,
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* output_scales_rowwise,
fp8e8m0* output_scales_colwise,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise,
const bool is_single_tensor)
{
const size_t tile_size_Y = 32;
const size_t tile_size_X = 32;
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
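// The reference walks 32x32 tiles: rowwise scaling reduces each 1x32 segment of a row,
// colwise scaling each 32x1 segment of a column, matching the MXFP8 block size of 32.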
std::vector<float> output_dbias_fp32(cols, 0);
#pragma omp parallel proc_bind(spread)
{
// Buffers to cache intermediate computations
std::vector<float> cache_buffer(tile_size_Y * tile_size_X);
std::vector<float> thread_dbias(cols, 0);
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
const size_t i_min = tile_offset_Y;
const size_t i_max = std::min(i_min + tile_size_Y, rows);
const size_t j_min = tile_offset_X;
const size_t j_max = std::min(j_min + tile_size_X, cols);
// Cache computations
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input
elt = static_cast<float>(grad[idx]);
}
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
}
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
}
thread_dbias[j] += elt;
// Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
elt = static_cast<float>(static_cast<InputType>(elt));
cache_buffer[cache_idx] = elt;
if (isinf(elt) || isnan(elt)) {
continue;
}
}
}
if (rowwise) {
for (size_t i = i_min; i < i_max; ++i) {
float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t scale_idx = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
}
if (colwise) {
for (size_t j = j_min; j < j_max; ++j) {
float block_amax = 0.0f;
for (size_t i = i_min; i < i_max; ++i) {
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t scale_idx = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t i = i_min; i < i_max; ++i) {
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
}
}
}
#pragma omp critical
{
for (size_t j = 0; j < cols; ++j) {
output_dbias_fp32[j] += thread_dbias[j];
}
}
}
if (is_single_tensor) {
for (size_t j = 0; j < cols; ++j) {
output_dbias[j] = static_cast<InputType>(output_dbias_fp32[j]);
}
}
}
template <typename T>
void compare_scaled_elts(const std::string &name,
const T* ref_data,
const T* test_data,
const size_t rows,
const size_t cols,
const bool rowwise,
const size_t tolerable_mismatches_limit = 0,
const double atol = 1e-5,
const double rtol = 1e-8) {
size_t mismatches_num = 0;
int first_mismatch_idx = -1;
for (size_t i = 0; i < rows * cols; ++i) {
double t = static_cast<double>(test_data[i]);
double r = static_cast<double>(ref_data[i]);
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* A raw mismatch may only be round-to-nearest resolving to the other side of the real value */
bool assertion = false;
if (mismatch) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
std::string direction = rowwise ? "rowwise" : "columnwise";
if (assertion) {
mismatches_num++;
if (first_mismatch_idx == -1) {
first_mismatch_idx = i;
}
}
if (mismatches_num > tolerable_mismatches_limit) {
const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);
GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
<< tolerable_mismatches_limit << "." << std::endl
<< "Error in tensor " << name << " in "
<< direction << " direction." << std::endl
<< "First mismatch at index " << first_mismatch_idx << ": "
<< first_mismatch_t << " vs " << first_mismatch_r;
}
}
}
/**
* Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias):
* 1) Scaled rows + row-wise scaling factors
* OR
* 2) Scaled columns + column-wise scaling factors
*/
template <typename InputType, typename OutputType>
void performTest(const ProcessingMethod processing_method,
float (*OP)(const float),
const ShapeRepresentation shape_rep,
const size_t num_tensors,
const std::vector<size_t>& logical_shape_vec,
const std::vector<size_t>& first_dims_h,
const std::vector<size_t>& last_dims_h,
const std::vector<size_t>& offsets_h,
const bool rowwise,
const bool colwise) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
const size_t rows = logical_shape_vec[0];
const size_t cols = logical_shape_vec[1];
size_t elts_num = 0;
size_t rowwise_sfs_num = 0;
size_t colwise_sfs_num = 0;
std::vector<size_t> rowwise_scales_first_dim(num_tensors, 0);
std::vector<size_t> rowwise_scales_last_dim(num_tensors, 0);
std::vector<size_t> rowwise_scales_offset(num_tensors + 1, 0);
std::vector<size_t> colwise_scales_first_dim(num_tensors, 0);
std::vector<size_t> colwise_scales_last_dim(num_tensors, 0);
std::vector<size_t> colwise_scales_offset(num_tensors + 1, 0);
for (size_t t = 0; t < num_tensors; ++t) {
const size_t M = first_dims_h[t];
const size_t K = last_dims_h[t];
const size_t elts = M * K;
elts_num += elts;
const size_t unpadded_rowwise_blocks_Y = M;
const size_t unpadded_rowwise_blocks_X = divide_round_up(K, 32);
const size_t unpadded_colwise_blocks_Y = divide_round_up(M, 32);
const size_t unpadded_colwise_blocks_X = K;
rowwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_Y, 128);
rowwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4);
colwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_Y, 4);
colwise_scales_last_dim[t] = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128);
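// Scale-factor shapes are padded (rowwise to multiples of 128x4, colwise to 4x128),
// so the per-tensor scale counts and offsets below include that padding.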
const size_t rowwise_sfs = rowwise_scales_first_dim[t] * rowwise_scales_last_dim[t];
const size_t colwise_sfs = colwise_scales_first_dim[t] * colwise_scales_last_dim[t];
rowwise_sfs_num += rowwise_sfs;
colwise_sfs_num += colwise_sfs;
rowwise_scales_offset[t+1] = rowwise_sfs_num;
colwise_scales_offset[t+1] = colwise_sfs_num;
}
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM);
std::vector<size_t> scales_rowwise_shape = {rowwise_sfs_num};
std::vector<size_t> scales_colwise_shape = {colwise_sfs_num};
std::mt19937 gen;
std::uniform_real_distribution<> dis(-2.0, 1.0);
std::vector<InputType> in_data(elts_num);
std::vector<InputType> grad_data(elts_num);
std::vector<OutputType> out_data_rowwise_h(rowwise ? elts_num : 0);
std::vector<OutputType> out_data_colwise_h(colwise ? elts_num : 0);
std::vector<fp8e8m0> out_scales_rowwise_h(rowwise ? rowwise_sfs_num : 0);
std::vector<fp8e8m0> out_scales_colwise_h(colwise ? colwise_sfs_num : 0);
std::vector<OutputType> out_data_rowwise_ref(rowwise ? elts_num : 0);
std::vector<OutputType> out_data_colwise_ref(colwise ? elts_num : 0);
std::vector<fp8e8m0> out_scales_rowwise_ref(rowwise ? rowwise_sfs_num : 0);
std::vector<fp8e8m0> out_scales_colwise_ref(colwise ? colwise_sfs_num : 0);
std::vector<InputType> ref_output_dbias(is_single_tensor ? cols : 0);
for (size_t i = 0; i < elts_num; ++i) {
const float val = dis(gen);
grad_data[i] = static_cast<InputType>(val);
in_data[i] = static_cast<InputType>(val);
}
const OutputType zero_elt = static_cast<OutputType>(0.0f);
const fp8e8m0 zero_SF = static_cast<fp8e8m0>(0.0f);
if (rowwise) {
std::fill(out_data_rowwise_h.begin(), out_data_rowwise_h.end(), zero_elt);
std::fill(out_data_rowwise_ref.begin(), out_data_rowwise_ref.end(), zero_elt);
std::fill(out_scales_rowwise_h.begin(), out_scales_rowwise_h.end(), zero_SF);
std::fill(out_scales_rowwise_ref.begin(), out_scales_rowwise_ref.end(), zero_SF);
}
if (colwise) {
std::fill(out_data_colwise_h.begin(), out_data_colwise_h.end(), zero_elt);
std::fill(out_data_colwise_ref.begin(), out_data_colwise_ref.end(), zero_elt);
std::fill(out_scales_colwise_h.begin(), out_scales_colwise_h.end(), zero_SF);
std::fill(out_scales_colwise_ref.begin(), out_scales_colwise_ref.end(), zero_SF);
}
const size_t in_data_size = elts_num * sizeof(InputType);
const size_t out_data_size = elts_num * sizeof(OutputType);
const size_t rowwise_scales_size = rowwise_sfs_num * sizeof(fp8e8m0);
const size_t colwise_scales_size = colwise_sfs_num * sizeof(fp8e8m0);
const size_t first_dims_size = num_tensors * sizeof(size_t);
const size_t last_dims_size = num_tensors * sizeof(size_t);
const size_t offsets_size = (num_tensors + 1) * sizeof(size_t);
InputType* grad_data_d;
InputType* in_data_d;
OutputType* out_data_rowwise_d;
OutputType* out_data_colwise_d;
fp8e8m0* out_scales_rowwise_d;
fp8e8m0* out_scales_colwise_d;
size_t* first_dims_d;
size_t* last_dims_d;
size_t* offsets_d;
cudaMalloc((void**)&grad_data_d, in_data_size);
cudaMalloc((void**)&in_data_d, in_data_size);
cudaMalloc((void**)&first_dims_d, first_dims_size);
cudaMalloc((void**)&last_dims_d, last_dims_size);
cudaMalloc((void**)&offsets_d, offsets_size);
cudaMemcpy(grad_data_d, grad_data.data(), in_data_size, cudaMemcpyHostToDevice);
cudaMemcpy(in_data_d, in_data.data(), in_data_size, cudaMemcpyHostToDevice);
cudaMemcpy(first_dims_d, first_dims_h.data(), first_dims_size, cudaMemcpyHostToDevice);
cudaMemcpy(last_dims_d, last_dims_h.data(), last_dims_size, cudaMemcpyHostToDevice);
cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice);
NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size());
NVTEShape first_dims_shape_;
NVTEShape last_dims_shape_;
NVTEShape offsets_shape_;
first_dims_shape_.ndim = 1;
last_dims_shape_.ndim = 1;
offsets_shape_.ndim = 1;
first_dims_shape_.data[0] = num_tensors;
last_dims_shape_.data[0] = num_tensors;
offsets_shape_.data[0] = num_tensors + 1;
NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_);
NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_);
NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_);
NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast<NVTEDType>(itype), logical_shape_};
NVTEBasicTensor in_data_tensor = {in_data_d, static_cast<NVTEDType>(itype), logical_shape_};
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor);
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor);
if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
}
if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
}
if (shape_rep != SAME_BOTH_DIMS) {
NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
}
if (rowwise) {
cudaMalloc((void**)&out_data_rowwise_d, out_data_size);
cudaMalloc((void**)&out_scales_rowwise_d, rowwise_scales_size);
cudaMemset(out_data_rowwise_d, 0, out_data_size);
cudaMemset(out_scales_rowwise_d, 0, rowwise_scales_size);
NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast<NVTEDType>(otype), logical_shape_};
NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_rowwise_shape.data(), scales_rowwise_shape.size());
NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_};
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_rowwise_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scales_rowwise_tensor);
}
if (colwise) {
cudaMalloc((void**)&out_data_colwise_d, out_data_size);
cudaMalloc((void**)&out_scales_colwise_d, colwise_scales_size);
cudaMemset(out_data_colwise_d, 0, out_data_size);
cudaMemset(out_scales_colwise_d, 0, colwise_scales_size);
NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast<NVTEDType>(otype), logical_shape_};
NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_colwise_shape.data(), scales_colwise_shape.size());
NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_};
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, &out_data_colwise_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor);
}
Tensor output_dbias("output_dbias", std::vector<size_t>{ cols }, itype);
// Reference (CPU)
if (is_single_tensor) {
const size_t unpadded_rowwise_blocks_X = divide_round_up(cols, 32);
const size_t unpadded_colwise_blocks_X = cols;
const size_t scales_stride_rowwise = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4);
const size_t scales_stride_colwise = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128);
compute_ref<InputType, OutputType>(
processing_method, OP, rowwise, colwise, in_data.data(), grad_data.data(),
out_data_rowwise_ref.data(), out_data_colwise_ref.data(),
out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(),
ref_output_dbias.data(), rows, cols,
scales_stride_rowwise,
scales_stride_colwise,
is_single_tensor);
} else {
for (size_t t = 0; t < num_tensors; ++t) {
const size_t M = first_dims_h[t];
const size_t K = last_dims_h[t];
const size_t scales_stride_rowwise = rowwise_scales_last_dim[t];
const size_t scales_stride_colwise = colwise_scales_last_dim[t];
const size_t data_offset = offsets_h[t];
const size_t rowwise_sfs_offset = rowwise_scales_offset[t];
const size_t colwise_sfs_offset = colwise_scales_offset[t];
const InputType* const grad_ptr = grad_data.data() + data_offset;
const InputType* const in_ptr = in_data.data() + data_offset;
OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset;
OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset;
fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset;
fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset;
compute_ref<InputType, OutputType>(
processing_method, OP, rowwise, colwise, in_ptr, grad_ptr,
out_data_rowwise_ptr, out_data_colwise_ptr,
out_scales_rowwise_ptr, out_scales_colwise_ptr,
ref_output_dbias.data(), M, K,
scales_stride_rowwise,
scales_stride_colwise,
is_single_tensor);
}
}
// GPU
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_group_quantize(in_group_tensor, out_group_tensor, 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
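// The first call only queries the required workspace shape; allocate it, then
// call again to actually run the fused quantize + dbias reduction.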
nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0);
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
auto nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dgelu;
if (OP == &dsilu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsilu; }
else if (OP == &drelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_drelu; }
else if (OP == &dqgelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dqgelu; }
else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; }
nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor,
output_dbias.data(), workspace.data(), 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor,
output_dbias.data(), workspace.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
auto nvte_group_act = &nvte_group_gelu;
if (OP == &silu) { nvte_group_act = &nvte_group_silu; }
else if (OP == &relu) { nvte_group_act = &nvte_group_relu; }
else if (OP == &qgelu) { nvte_group_act = &nvte_group_qgelu; }
else if (OP == &srelu) { nvte_group_act = &nvte_group_srelu; }
nvte_group_act(in_group_tensor, out_group_tensor, 0);
break;
}
case ProcessingMethod::CAST_DACT: {
auto nvte_group_dact = &nvte_group_dgelu;
if (OP == &dsilu) { nvte_group_dact = &nvte_group_dsilu; }
else if (OP == &drelu) { nvte_group_dact = &nvte_group_drelu; }
else if (OP == &dqgelu) { nvte_group_dact = &nvte_group_dqgelu; }
else if (OP == &dsrelu) { nvte_group_dact = &nvte_group_dsrelu; }
nvte_group_dact(grad_group_tensor, in_group_tensor, out_group_tensor, 0);
break;
}
}
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(otype);
const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 0.0;
const double rel_tolerable_mismatches_limit = 0.0;
if (rowwise) {
cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost);
cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost);
size_t mismatches_scales = 0;
compare_scaling_factors("rowwise_scales", out_scales_rowwise_h.data(), out_scales_rowwise_ref.data(),
1, rowwise_sfs_num, rowwise_sfs_num, mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);
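// Each mismatched e8m0 scale can shift all 32 elements of its block, so widen
// the tolerable element-mismatch budget accordingly.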
const size_t mismatches_elts = 32 * mismatches_scales;
compare_scaled_elts<OutputType>("rowwise_output", out_data_rowwise_ref.data(),
out_data_rowwise_h.data(), rows, cols, true, mismatches_elts);
}
if (colwise) {
cudaMemcpy(out_data_colwise_h.data(), out_data_colwise_d, out_data_size, cudaMemcpyDeviceToHost);
cudaMemcpy(out_scales_colwise_h.data(), out_scales_colwise_d, colwise_scales_size, cudaMemcpyDeviceToHost);
size_t mismatches_scales = 0;
compare_scaling_factors("colwise_scales", out_scales_colwise_h.data(), out_scales_colwise_ref.data(),
1, colwise_sfs_num, colwise_sfs_num, mismatches_scales, scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit);
const size_t mismatches_elts = 32 * mismatches_scales;
compare_scaled_elts<OutputType>("colwise_output", out_data_colwise_ref.data(),
out_data_colwise_h.data(), rows, cols, false, mismatches_elts);
}
if (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT)
{
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
rtol_dbias *= sqrt(static_cast<double>(rows));
} else {
rtol_dbias *= 4;
}
compareResults("output_dbias", output_dbias, ref_output_dbias.data(), true, atol_dbias, rtol_dbias);
}
cudaFree(grad_data_d);
cudaFree(in_data_d);
cudaFree(first_dims_d);
cudaFree(last_dims_d);
cudaFree(offsets_d);
if (rowwise) {
cudaFree(out_data_rowwise_d);
cudaFree(out_scales_rowwise_d);
}
if (colwise) {
cudaFree(out_data_colwise_d);
cudaFree(out_scales_colwise_d);
}
}
std::vector<ProcessingMethod> processing_methods = {
ProcessingMethod::CAST_ONLY,
ProcessingMethod::CAST_DBIAS,
ProcessingMethod::CAST_DBIAS_DACT,
ProcessingMethod::CAST_DACT,
ProcessingMethod::CAST_ACT,
};
std::vector<ActivationKind> activation_kinds = {
ActivationKind::Identity,
ActivationKind::GeLU,
// ActivationKind::SiLU,
// ActivationKind::ReLU,
// ActivationKind::QGeLU,
// ActivationKind::SReLU,
};
enum ScalingDirection {
ROWWISE = 0,
COLWISE = 1,
BOTH = 2
};
std::vector<ScalingDirection> scaling_directions = {
ScalingDirection::ROWWISE,
ScalingDirection::COLWISE,
ScalingDirection::BOTH,
};
// {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]}
std::vector<std::vector<size_t>> input_config = {
{SAME_BOTH_DIMS, 1, 128,128},
{SAME_BOTH_DIMS, 2, 256,128},
{VARYING_FIRST_DIM, 2, 512,128, 128,384},
{VARYING_FIRST_DIM, 2, 384,160, 128,256},
{VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304},
{VARYING_LAST_DIM, 3, 256,896, 128,256,512},
{VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
{VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640},
};
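// Example: {VARYING_FIRST_DIM, 2, 512,128, 128,384} describes two tensors with a shared
// last dim of 128 and first dims 128 and 384, which sum to the logical first dim 512.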
} // namespace
class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam
<std::tuple<ProcessingMethod,
ActivationKind,
ScalingDirection,
std::vector<size_t>, // Config
transformer_engine::DType, // InputType
transformer_engine::DType // OutputType
>> {};
TEST_P(GroupedFusedCastMXFP8TestSuite, Test) {
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
using namespace transformer_engine;
using namespace test;
const ProcessingMethod processing_method = std::get<0>(GetParam());
const ActivationKind activation = std::get<1>(GetParam());
const ScalingDirection scaling_direction = std::get<2>(GetParam());
const std::vector<size_t> input_config = std::get<3>(GetParam());
const DType input_type = std::get<4>(GetParam());
const DType output_type = std::get<5>(GetParam());
const ShapeRepresentation shape_rep = static_cast<ShapeRepresentation>(input_config[0]);
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM);
const size_t num_tensors = input_config[1];
const std::vector<size_t> logical_shape = {input_config[2], input_config[3]};
std::vector<size_t> first_dims(num_tensors);
std::vector<size_t> last_dims(num_tensors);
std::vector<size_t> offsets(num_tensors + 1, 0);
for (size_t t = 0; t < num_tensors; ++t) {
switch (shape_rep) {
case SAME_BOTH_DIMS: {
first_dims[t] = logical_shape[0] / num_tensors;
last_dims[t] = logical_shape[1];
break;
}
case VARYING_FIRST_DIM: {
first_dims[t] = input_config[t + 4];
last_dims[t] = logical_shape[1];
break;
}
case VARYING_LAST_DIM: {
first_dims[t] = logical_shape[0];
last_dims[t] = input_config[t + 4];
break;
}
case VARYING_BOTH_DIMS: {
first_dims[t] = input_config[t + 4];
last_dims[t] = input_config[t + (4 + num_tensors)];
break;
}
}
offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t];
// Skips tests if tensor shape is not as required by the kernel
if ((first_dims[t] % 128 != 0) || (last_dims[t] % 32 != 0)) {
GTEST_SKIP();
}
}
// Skips DBias tests if the last dimension of the tensors varies
if ((processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT)
&& !is_single_tensor) {
GTEST_SKIP();
}
// Skips non-Act tests if the activation type is not an identity
if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
&& activation != ActivationKind::Identity) {
GTEST_SKIP();
}
// Skips Act tests if the Activation is an identity
if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
|| processing_method == ProcessingMethod::CAST_DACT
|| processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) {
GTEST_SKIP();
}
bool rowwise = false;
bool colwise = false;
switch (scaling_direction) {
case ScalingDirection::ROWWISE: rowwise = true; break;
case ScalingDirection::COLWISE: colwise = true; break;
case ScalingDirection::BOTH: rowwise = true; colwise = true; break;
}
auto OP = &identity;
if (processing_method == ProcessingMethod::CAST_ACT) {
switch (activation) {
case ActivationKind::GeLU: OP = &gelu; break;
case ActivationKind::SiLU: OP = &silu; break;
case ActivationKind::ReLU: OP = &relu; break;
case ActivationKind::QGeLU: OP = &qgelu; break;
case ActivationKind::SReLU: OP = &srelu; break;
}
} else if (processing_method == ProcessingMethod::CAST_DACT
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
switch (activation) {
case ActivationKind::GeLU: OP = &dgelu; break;
case ActivationKind::SiLU: OP = &dsilu; break;
case ActivationKind::ReLU: OP = &drelu; break;
case ActivationKind::QGeLU: OP = &dqgelu; break;
case ActivationKind::SReLU: OP = &dsrelu; break;
}
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
performTest<InputType, OutputType>(processing_method, OP, shape_rep, num_tensors,
logical_shape, first_dims, last_dims, offsets,
rowwise, colwise);
);
);
}
std::string to_string(const ProcessingMethod method) {
switch (method) {
case ProcessingMethod::CAST_ONLY: return "CAST_ONLY";
case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
case ProcessingMethod::CAST_DACT: return "CAST_DACT";
case ProcessingMethod::CAST_ACT: return "CAST_ACT";
default: return "";
}
}
std::string to_string(const ActivationKind activation) {
switch (activation) {
case ActivationKind::Identity: return "Identity";
case ActivationKind::GeLU: return "GeLU";
case ActivationKind::SiLU: return "SiLU";
case ActivationKind::ReLU: return "ReLU";
case ActivationKind::QGeLU: return "QGeLU";
case ActivationKind::SReLU: return "SReLU";
default: return "";
}
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
GroupedFusedCastMXFP8TestSuite,
::testing::Combine(
::testing::ValuesIn(processing_methods),
::testing::ValuesIn(activation_kinds),
::testing::ValuesIn(scaling_directions),
::testing::ValuesIn(input_config),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)),
[](const testing::TestParamInfo<GroupedFusedCastMXFP8TestSuite::ParamType>& info) {
const ProcessingMethod method = std::get<0>(info.param);
std::string name = to_string(method);
name += "X" + to_string(std::get<1>(info.param));
switch (std::get<2>(info.param)) {
case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break;
case ScalingDirection::COLWISE: name += "_COLWISE_"; break;
case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break;
}
const std::vector<size_t> input = std::get<3>(info.param);
switch(static_cast<ShapeRepresentation>(input[0])) {
case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break;
case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break;
case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break;
case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break;
};
name += "_N_" + std::to_string(input[1]);
name += "_SHAPE_" +
std::to_string(input[2]) +
"X" + std::to_string(input[3]);
name += "_" + test::typeName(std::get<4>(info.param)) +
"_" + test::typeName(std::get<5>(info.param));
return name;
});
......@@ -54,12 +54,16 @@ std::vector<InputType> create_transpose(const InputType* const input, const size
}
// Compute the global encode scale factor for a given global amax
float compute_global_encode_scaling_factor_FP4(const float global_amax) {
float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) {
constexpr float fp8_max = 448.0f;  // max representable FP8 E4M3 value
constexpr float fp4_max = 6.0f;    // max representable FP4 E2M1 value
float global_encode_scale = fp8_max * fp4_max / global_amax;
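// e.g. global_amax = 672.0f yields S_enc = 448 * 6 / 672 = 4.0f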
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, Numeric_Traits<float>::maxNorm);
// If scale is infinity, return the max normalized value
const float max_norm_clamp = use_fast_math
? Numeric_Traits<bf16>::maxNorm
: Numeric_Traits<float>::maxNorm;
global_encode_scale = fminf(global_encode_scale, max_norm_clamp);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
......@@ -76,10 +80,11 @@ void quantize_nvfp4_1d(float (*OP)(const float),
const size_t rows,
const size_t cols,
const size_t scales_stride,
const float global_amax) {
const float global_amax,
const bool use_fast_math) {
// Compute a global encoding/decoding scaling factor for all S_dec_b
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
constexpr size_t block_size_X = 16;
const size_t blocks_X = divide_round_up(cols, block_size_X);
......@@ -114,14 +119,20 @@ void quantize_nvfp4_1d(float (*OP)(const float),
const float S_dec_b = block_amax / 6.0f;
// Scale & Store per-block decoding scaling factor
const float S_dec_b_fp8 = S_dec_b * S_enc;
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
const float S_dec_b_fp32 = static_cast<float>(S_dec_b_fp8);
// Compute "correct" per-block encoding scaling factor
const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
const size_t scale_idx = i * scales_stride + block_X;
scales[scale_idx] = static_cast<fp8e4m3>(S_dec_b_fp8);
const float scale_reciprocal = S_enc_b_fp8;
scales[scale_idx] = S_dec_b_fp8;
float scale_reciprocal = S_enc_b_fp8;
if (use_fast_math) {
// Numerical truncation to match the GPU implementation when the mixed-precision FMA instruction is used
scale_reciprocal = static_cast<float>(static_cast<bf16>(scale_reciprocal));
}
for (size_t j = j_min; j < j_max; j += 2) {
const int idx_pair = (i * cols + j) / 2;
......@@ -136,7 +147,7 @@ void quantize_nvfp4_1d(float (*OP)(const float),
fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
output[idx_pair] = casted_to_e2m1_pair;
// const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
}
}
}
......@@ -149,9 +160,10 @@ void compute_2d_mathematical_scales(float (*OP)(const float),
const size_t rows,
const size_t cols,
const float global_amax,
std::vector<std::vector<fp8e4m3>>& math_scales) {
std::vector<std::vector<fp8e4m3>>& math_scales,
const bool use_fast_math) {
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
......@@ -195,13 +207,14 @@ void quantize_nvfp4_2d(float (*OP)(const float),
const size_t rows,
const size_t cols,
const size_t scales_stride,
const float global_amax) {
const float global_amax,
const bool use_fast_math) {
// Step 1: Compute mathematical 8x8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
const size_t blocks_Y = divide_round_up(rows, block_size_Y);
......@@ -282,11 +295,12 @@ void quantize_nvfp4(float (*OP)(const float),
const size_t cols,
const size_t scales_stride,
const float global_amax,
const bool use_fast_math,
const bool use_2d_quantization = false) {
if (use_2d_quantization) {
quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math);
} else {
quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math);
}
}
......@@ -302,6 +316,7 @@ void compute_ref(float (*OP)(const float),
const size_t cols,
const size_t scales_stride,
const size_t scales_stride_t,
const bool use_fast_math,
const bool use_2d_quantization = false)
{
std::vector<InputType> input_t = create_transpose(input, rows, cols);
......@@ -309,7 +324,7 @@ void compute_ref(float (*OP)(const float),
if (use_2d_quantization) {
// Step 1: Compute mathematical 8×8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
constexpr size_t block_size_Y = 16;
constexpr size_t block_size_X = 16;
......@@ -336,12 +351,16 @@ void compute_ref(float (*OP)(const float),
// Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d
// (This part processes the actual FP4 data using the mathematical scaling factors)
quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled
quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled
quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax,
use_fast_math); // scales already filled
quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax,
use_fast_math); // scales_t already filled
} else {
quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization);
quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization);
quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax,
use_fast_math, use_2d_quantization);
quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax,
use_fast_math, use_2d_quantization);
}
}
......@@ -349,6 +368,8 @@ void compare_nvfp4_tensors(const std::string& name,
const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
const int rows, const int cols,
double atol = 1e-5, double rtol = 1e-8) {
constexpr int max_mismatches_to_print = 3;
std::vector<std::string> mismatch_messages;
size_t total_mismatches = 0;
......@@ -362,29 +383,16 @@ void compare_nvfp4_tensors(const std::string& name,
const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */
bool assertion = false;
if (mismatch && !assertion) {
/* Check if it is just a failure of round to nearest choosing different
side of the real value */
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
const double cast_mean_p = static_cast<double>(static_cast<fp4e2m1>(mean_p));
const double cast_mean_m = static_cast<double>(static_cast<fp4e2m1>(mean_m));
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
}
if (assertion) {
const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol);
if (mismatch) {
total_mismatches++;
// Optional: limit number of detailed messages to avoid overwhelming output
if (total_mismatches <= max_mismatches_to_print) {
std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " +
std::to_string(t) + " vs " + std::to_string(r) +
" (abs_diff: " + std::to_string(fabs(t - r)) +
", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
mismatch_messages.push_back(msg);
// Optional: limit number of detailed messages to avoid overwhelming output
if (mismatch_messages.size() <= 100) {
std::cout << "Error in tensor " << name << ": " << msg << std::endl;
}
}
......@@ -400,8 +408,9 @@ void compare_nvfp4_tensors(const std::string& name,
std::cout << "STATUS: FAILED for output" << std::endl;
std::cout << "Total mismatches found: " << total_mismatches << std::endl;
std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl;
if (mismatch_messages.size() > 100) {
std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl;
if (mismatch_messages.size() > max_mismatches_to_print) {
std::cout << "... and " << (mismatch_messages.size() - max_mismatches_to_print)
<< " more mismatches (showing first " << max_mismatches_to_print << ")" << std::endl;
}
std::cout << "============================" << std::endl;
......@@ -519,7 +528,8 @@ void compareResults_nvfp4(const Tensor &test,
template <typename InputType>
void performTest(float (*OP)(const float),
const std::vector<size_t>& shape) {
const std::vector<size_t>& shape,
const bool use_fast_math) {
using namespace test;
DType itype = TypeInfo<InputType>::dtype;
......@@ -580,15 +590,16 @@ void performTest(float (*OP)(const float),
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization);
QuantizationConfigWrapper quant_config;
// Initialize stochastic rounding
Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123; // rng_seed
rng_state.rowwise_cpu_dptr<int64_t>()[1] = 321; // rng_sequence
rng_state.from_cpu();
QuantizationConfigWrapper quant_config;
quant_config.set_use_fast_math(use_fast_math);
quant_config.set_stochastic_rounding(false);
quant_config.set_rng_state(rng_state.data());
......@@ -619,8 +630,8 @@ void performTest(float (*OP)(const float),
}
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
const double atol = 0.05;
const double rtol = 0.1;
const double atol = 1.0E-6;
const double rtol = 1.0E-6;
// Set dump_data=true to enable dumping tensor data to files for analysis
compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false);
......@@ -671,7 +682,8 @@ std::vector<ActivationType> Activation_types = {
class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
<std::tuple<ActivationType,
std::vector<size_t>,
transformer_engine::DType>> {};
transformer_engine::DType,
bool>> {};
TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
// Skip tests for pre-Blackwell architectures
......@@ -685,6 +697,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
const ActivationType Act_type = std::get<0>(GetParam());
const auto tensor_dims = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const bool use_fast_math = std::get<3>(GetParam());
// Skip tests if the input tensor is 1D
if (tensor_dims.size() < 2) {
......@@ -702,7 +715,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
performTest<InputType>(OP, tensor_dims);
performTest<InputType>(OP, tensor_dims, use_fast_math);
);
}
......@@ -724,7 +737,8 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(tensor_dims),
::testing::Values(DType::kBFloat16)),
::testing::Values(DType::kBFloat16),
::testing::Values(false)),
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
const auto& shape = std::get<1>(info.param);
......@@ -732,5 +746,8 @@ INSTANTIATE_TEST_SUITE_P(
name += "X" + std::to_string(s);
}
name += "X" + test::typeName(std::get<2>(info.param));
if (std::get<3>(info.param)) {
name += "X_FAST_SCALING";
}
return name;
});
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>
#include <random>
#include <tuple>
#include <vector>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
using namespace transformer_engine;
using namespace test;
namespace {
enum class InputCase {
kFP8Current,
kBF16,
};
enum class ShapeCase {
kAllSame,
kSameFirst,
kSameLast,
kAllDifferent,
};
size_t grouped_setup_workspace_size(const size_t num_tensors) {
const size_t ptr_bytes = num_tensors * sizeof(void*);
const size_t int_bytes = num_tensors * sizeof(int);
// Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols)
size_t size = 6 * ptr_bytes + 6 * int_bytes;
const size_t alignment = 256;
size = ((size + alignment - 1) / alignment) * alignment;
return size;
}
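// Builds an FP8 (E4M3) operand with current scaling: compute the amax of random
// FP32 data, derive the scale from it, then quantize.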
Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);
fillUniform(&input_fp32);
Tensor fp8(name, shape, TypeInfo<fp8e4m3>::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);
nvte_compute_amax(input_fp32.data(), fp8.data(), 0);
QuantizationConfigWrapper config;
nvte_compute_scale_from_amax(fp8.data(), config, 0);
nvte_quantize(input_fp32.data(), fp8.data(), 0);
return fp8;
}
Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& shape) {
Tensor t(name, shape, DType::kBFloat16);
const size_t numel = shape[0] * shape[1];
std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f));
NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(),
numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice));
return t;
}
struct TestParams {
InputCase input_case;
bool transa;
bool transb;
ShapeCase shape_case;
bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0)
};
// Returns a vector of (M, N, K) tuples for each GEMM in the group.
// M - number of rows in output D
// N - number of columns in output D
// K - reduction dimension shared between A and B
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
switch (scase) {
case ShapeCase::kAllSame:
return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
case ShapeCase::kSameFirst:
// Same M (first dim), varying N and K
return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}};
case ShapeCase::kSameLast:
// Same N (last dim), varying M and K
return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}};
case ShapeCase::kAllDifferent:
default:
return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}};
}
}
void run_grouped_gemm_case(const TestParams& params) {
#if CUBLAS_VERSION < 130200
GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is "
<< CUBLAS_VERSION << ".";
#else
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer.";
}
const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);
const size_t num_gemms = shapes.size();
std::vector<Tensor> A_tensors;
std::vector<Tensor> B_tensors;
std::vector<Tensor> D_multi;
A_tensors.reserve(num_gemms);
B_tensors.reserve(num_gemms);
D_multi.reserve(num_gemms);
for (size_t i = 0; i < num_gemms; ++i) {
const auto [M, N, K] = shapes[i];
const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{M, K}
: std::vector<size_t>{K, M};
const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, N}
: std::vector<size_t>{N, K};
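      // Row-major operand shapes: A is (M, K) when transa is set, (K, M) otherwise;
      // B is (K, N) when transb is set, (N, K) otherwise, so K always lines up.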
switch (params.input_case) {
case InputCase::kFP8Current: {
A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape));
B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape));
break;
}
case InputCase::kBF16: {
A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape));
B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape));
break;
}
}
D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
std::vector<size_t>{M, N},
DType::kBFloat16));
}
std::vector<NVTETensor> A_ptrs(num_gemms);
std::vector<NVTETensor> B_ptrs(num_gemms);
std::vector<NVTETensor> D_ptrs(num_gemms);
std::vector<Tensor> workspaces(num_gemms);
std::vector<NVTETensor> workspace_ptrs(num_gemms, nullptr);
std::vector<Tensor*> A_views;
std::vector<Tensor*> B_views;
A_views.reserve(num_gemms);
B_views.reserve(num_gemms);
// Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues)
std::vector<NVTETensor> bias_ptrs(num_gemms, nullptr);
std::vector<NVTETensor> gelu_ptrs(num_gemms, nullptr);
const size_t cublas_ws_bytes = 32ull * 1024 * 1024;
for (size_t i = 0; i < num_gemms; ++i) {
A_ptrs[i] = A_tensors[i].data();
B_ptrs[i] = B_tensors[i].data();
D_ptrs[i] = D_multi[i].data();
workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
workspace_ptrs[i] = workspaces[i].data();
A_views.push_back(&A_tensors[i]);
B_views.push_back(&B_tensors[i]);
}
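  // Reference path: run each GEMM separately through nvte_multi_tensor_gemm
  // (no bias/gelu epilogues, no accumulation).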
nvte_multi_tensor_gemm(A_ptrs.data(),
B_ptrs.data(),
D_ptrs.data(),
bias_ptrs.data(),
gelu_ptrs.data(),
static_cast<int>(num_gemms),
params.transa,
params.transb,
false, // grad
workspace_ptrs.data(),
false, // accumulate
false, // use_split_accumulator
0, // sm_count
0);
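  // Pack the per-GEMM operands into single grouped tensors (one shared buffer
  // plus per-tensor offsets) for the grouped-GEMM path under test.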
GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode());
GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode());
std::vector<Tensor> C_tensors;
std::vector<Tensor> D_group_tensors;
C_tensors.reserve(num_gemms);
D_group_tensors.reserve(num_gemms);
for (size_t i = 0; i < num_gemms; ++i) {
const auto [M, N, K] = shapes[i];
(void)K;
if (!params.use_null_c) {
C_tensors.emplace_back(Tensor("C" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(M), static_cast<size_t>(N)},
DType::kBFloat16));
}
D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(M), static_cast<size_t>(N)},
DType::kBFloat16));
    NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0,
                               bytes(D_group_tensors.back().rowwise_shape(),
                                     D_group_tensors.back().dtype())));
}
std::vector<Tensor*> C_views, D_views;
for (size_t i = 0; i < num_gemms; ++i) {
if (!params.use_null_c) {
C_views.push_back(&C_tensors[i]);
}
D_views.push_back(&D_group_tensors[i]);
}
std::optional<GroupedBuffers> grouped_C;
if (!params.use_null_c) {
grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING);
}
GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING);
// Per-matrix alpha/beta (all 1.0 and 0.0 respectively)
Tensor alpha_tensor("alpha", std::vector<size_t>{num_gemms}, DType::kFloat32);
Tensor beta_tensor("beta", std::vector<size_t>{num_gemms}, DType::kFloat32);
std::vector<float> alpha_vals(num_gemms, 1.f);
std::vector<float> beta_vals(num_gemms, 0.f);
NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(),
num_gemms * sizeof(float), cudaMemcpyHostToDevice));
NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(),
num_gemms * sizeof(float), cudaMemcpyHostToDevice));
const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms);
Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
Tensor cublas_ws("cublas_ws", std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
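  // Path under test: one grouped-GEMM call over the packed operands, with
  // per-matrix alpha = 1 and beta = 0 supplied as device tensors.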
nvte_grouped_gemm(grouped_A.get_handle(),
params.transa,
grouped_B.get_handle(),
params.transb,
params.use_null_c ? nullptr : grouped_C->get_handle(),
grouped_D.get_handle(),
alpha_tensor.data(),
beta_tensor.data(),
setup_ws.data(),
cublas_ws.data(),
nullptr, // config (use defaults)
0);
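  // Unpack each GEMM's slice of the grouped output and compare it against the
  // corresponding multi-tensor result.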
for (size_t i = 0; i < num_gemms; ++i) {
Tensor grouped_split("grouped_D" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(std::get<0>(shapes[i])),
static_cast<size_t>(std::get<1>(shapes[i]))},
D_multi[i].dtype());
const size_t offset_bytes = static_cast<size_t>(grouped_D.offsets_host[i]) * grouped_D.elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(),
static_cast<char*>(grouped_D.get_data()) + offset_bytes,
grouped_D.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
grouped_split.to_cpu();
D_multi[i].to_cpu();
auto [atol, rtol] = getTolerances(D_multi[i].dtype());
compareResults("grouped_vs_multi",
grouped_split,
D_multi[i].rowwise_cpu_dptr<bf16>(),
true,
atol,
rtol);
}
#endif // CUBLAS_VERSION >= 130200
}
class GroupedGemmTest : public ::testing::TestWithParam<TestParams> {};
TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) {
run_grouped_gemm_case(GetParam());
}
std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest::ParamType>& info) {
constexpr const char* kInputNames[] = {"FP8Current", "BF16"};
constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"};
const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") +
"tb" + (info.param.transb ? "T" : "N");
const std::string null_c = info.param.use_null_c ? "_NullC" : "";
return std::string(kInputNames[static_cast<int>(info.param.input_case)]) + "_" +
kShapeNames[static_cast<int>(info.param.shape_case)] + "_" + layout + null_c;
}
// TestParams: {input_case, transa, transb, shape_case, use_null_c}
const std::vector<TestParams> kTestParams = {
// Basic tests
{InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
{InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
{InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
// Test NULL C (valid when beta=0)
{InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
};
INSTANTIATE_TEST_SUITE_P(OperatorTest,
GroupedGemmTest,
::testing::ValuesIn(kTestParams),
MakeGroupedGemmTestName);
} // namespace
@@ -9,6 +9,7 @@
#include <algorithm>
#include <memory>
#include <numeric>
#include <random>
#include <iostream>
#include <cassert>
@@ -1116,4 +1117,166 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X};
}
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const NVTEScalingMode scaling_mode) {
NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build.");
  const DType dtype = tensors[0]->dtype();
const size_t num_tensors = tensors.size();
const size_t elem_size = typeToNumBits(dtype) / 8;
GroupedBuffers grouped;
grouped.elem_size = elem_size;
grouped.num_tensors = num_tensors;
grouped.dtype = dtype;
grouped.scaling_mode = scaling_mode;
grouped.tensor_bytes.resize(num_tensors);
grouped.offsets_host.resize(num_tensors, 0);
std::vector<int64_t> first_dims(num_tensors);
std::vector<int64_t> last_dims(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
const auto s = tensors[i]->rowwise_shape();
NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors.");
first_dims[i] = static_cast<int64_t>(s.data[0]);
last_dims[i] = static_cast<int64_t>(s.data[1]);
grouped.tensor_bytes[i] = bytes(s, dtype);
}
const bool same_first = std::all_of(first_dims.begin(), first_dims.end(),
[&](int64_t v) { return v == first_dims[0]; });
const bool same_last = std::all_of(last_dims.begin(), last_dims.end(),
[&](int64_t v) { return v == last_dims[0]; });
std::vector<int64_t> offsets(num_tensors, 0);
auto random_padding = [&]() -> int64_t {
// Random padding ensuring 16-byte alignment regardless of element size
// cuBLAS requires aligned pointers for vectorized loads
static std::mt19937 gen(12345);
std::uniform_int_distribution<int64_t> dist(0, 3);
    // Elements per 16-byte step: ceil(16 / elem_size), at least 1
    const size_t align_elements =
        std::max<size_t>(1, (16 + elem_size - 1) / elem_size);
return dist(gen) * static_cast<int64_t>(align_elements);
};
auto numel = [&](size_t idx) -> int64_t {
return first_dims[idx] * last_dims[idx];
};
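  // If any dimension varies across the group, pack tensors back-to-back with
  // explicit element offsets (plus random 16-byte-aligned padding); otherwise
  // lay them out contiguously at uniform strides.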
const bool need_offsets = !same_first || !same_last;
if (need_offsets) {
offsets[0] = 0;
for (size_t i = 1; i < num_tensors; ++i) {
offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding();
}
} else {
for (size_t i = 0; i < num_tensors; ++i) {
offsets[i] = static_cast<int64_t>(i) * numel(0);
}
}
grouped.offsets_host = offsets;
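  // Logical 2D shape of the packed group: stack along the first dim when all
  // shapes match, concatenate along the varying dim when only one dim differs,
  // and fall back to a flat (1, total_elements) shape when both dims vary.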
int64_t logical_first = 0;
int64_t logical_last = 0;
if (same_first && same_last) {
logical_first = first_dims[0] * static_cast<int64_t>(num_tensors);
logical_last = last_dims[0];
} else if (same_first && !same_last) {
logical_first = first_dims[0];
logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0});
} else if (!same_first && same_last) {
logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0});
logical_last = last_dims[0];
} else {
logical_first = 1;
logical_last = 0;
for (size_t i = 0; i < num_tensors; ++i) {
logical_last += first_dims[i] * last_dims[i];
}
}
size_t logical_data[2] = {static_cast<size_t>(logical_first),
static_cast<size_t>(logical_last)};
grouped.logical_shape = nvte_make_shape(logical_data, 2);
grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape));
const int64_t last_idx = static_cast<int64_t>(num_tensors - 1);
const int64_t total_elems = need_offsets
? (offsets[last_idx] + numel(last_idx))
: (logical_first * logical_last);
const size_t total_bytes = static_cast<size_t>(total_elems) * elem_size;
grouped.data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
tensors[i]->rowwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}
NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
NVTEGroupedTensor h = grouped.handle.get();
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor);
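  // FP8/FP4 tensors also carry a columnwise (transpose-friendly) copy; pack it
  // with the same per-tensor offsets as the rowwise data.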
const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype);
if (include_columnwise) {
grouped.columnwise_data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.columnwise_data.get()) + offset_bytes,
tensors[i]->columnwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
}
NVTEBasicTensor col_tensor{grouped.columnwise_data.get(),
static_cast<NVTEDType>(dtype),
grouped.logical_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor);
}
if (!same_first) {
grouped.first_dims_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor);
}
if (!same_last) {
grouped.last_dims_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor);
}
if (!same_first || !same_last) {
grouped.offsets_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape off_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor);
}
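  // For FP8 inputs, gather each member's dequantization scale-inv on the host;
  // the same per-tensor array is registered for both rowwise and columnwise views.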
if (isFp8Type(dtype)) {
std::vector<float> scale_inv_cpu(num_tensors, 1.f);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr<float>()[0];
}
grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors);
NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(),
sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor);
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor);
}
return grouped;
}
} // namespace test
@@ -446,10 +446,14 @@ inline fp8e8m0 float_to_e8m0(float val) {
}
inline float exp2f_rcp(fp8e8m0 biased_exp) {
  if (biased_exp == 0) {
    return 1.0f;
  }
  int32_t int_val = 0;
  if (biased_exp == 255) {
    int_val = 0x7fffffff;  // e8m0 255 encodes NaN
  } else if (biased_exp == 254) {
    int_val = 0x00400000;  // 2^-127 is subnormal in FP32
  } else {
    int_val = (254 - biased_exp) << FP32_MANTISSA_BITS;  // 127 - (biased_exp - 127)
  }
float fp32_val = *reinterpret_cast<float*>(&int_val);
return fp32_val;
}
@@ -525,6 +529,60 @@ int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
constexpr int32_t blackwellComputeCapability = 100;
// Custom deleters for RAII
struct CudaDeleter {
void operator()(void* p) const { if (p) cudaFree(p); }
};
struct GroupedTensorDeleter {
void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); }
};
template <typename T = void>
using CudaPtr = std::unique_ptr<T, CudaDeleter>;
using GroupedTensorHandle = std::unique_ptr<std::remove_pointer_t<NVTEGroupedTensor>, GroupedTensorDeleter>;
// Helper to allocate CUDA memory into a CudaPtr
template <typename T = void>
CudaPtr<T> cuda_alloc(size_t bytes) {
void* ptr = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes));
return CudaPtr<T>(static_cast<T*>(ptr));
}
// Helper owning GPU buffers that back NVTEGroupedTensor.
// NVTEGroupedTensor does not own memory; data/offsets/scales
// must be allocated and freed by the test.
struct GroupedBuffers {
GroupedTensorHandle handle;
CudaPtr<> data;
CudaPtr<> scale_inv;
CudaPtr<int64_t> first_dims_dev;
CudaPtr<int64_t> last_dims_dev;
CudaPtr<int64_t> offsets_dev;
CudaPtr<> columnwise_data;
NVTEShape logical_shape{};
std::vector<int64_t> offsets_host;
std::vector<size_t> tensor_bytes;
size_t num_tensors{0};
size_t elem_size{0};
DType dtype{DType::kFloat32};
NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING};
GroupedBuffers() = default;
GroupedBuffers(const GroupedBuffers&) = delete;
GroupedBuffers& operator=(const GroupedBuffers&) = delete;
GroupedBuffers(GroupedBuffers&&) = default;
GroupedBuffers& operator=(GroupedBuffers&&) = default;
~GroupedBuffers() = default;
// Convenience accessors for raw pointers
NVTEGroupedTensor get_handle() const { return handle.get(); }
void* get_data() const { return data.get(); }
};
GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const NVTEScalingMode scaling_mode);
} // namespace test
#if FP4_TYPE_SUPPORTED
@@ -1921,3 +1921,37 @@ class TestGroupedDense:
assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
class TestDebugInspectFFI:
@pytest_parametrize_wrapper("shape", [(256, 128)])
@pytest_parametrize_wrapper(
"dtype",
[
jnp.float32,
jnp.bfloat16,
jnp.float16,
# Note: fp4 currently doesn't work
# jnp.float4_e2m1fn
]
+ ([jnp.float8_e4m3fn, jnp.float8_e5m2] if is_fp8_supported else []),
)
def test_debug_inspect_ffi(self, shape, dtype):
from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump
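        # Note: the dump-file name used below assumes the convention
        # "<name>_gpu<device>.bin" for the tag given to inspect_array.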
def f(x):
x = x + 1
x = inspect_array(x, "my_array")
x = x + 1
return x
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, shape, jnp.float32)
x = x.astype(dtype)
_ = jax.jit(f)(x)
expected = x + 1
        actual = load_array_dump("my_array_gpu0.bin", shape, dtype)
assert_allclose(actual, expected, dtype=dtype)
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
import os
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
@@ -49,6 +50,9 @@ from transformer_engine_jax import (
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
# Deterministic mode is enforced when NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
@pytest.fixture(autouse=True, scope="module")
def init():
@@ -413,14 +417,24 @@ class FusedAttnRunner:
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
        if get_device_compute_capability(0) >= 100 and self.is_training:
            if FusedAttnHelper.is_non_deterministic_allowed() and (
                (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
                or get_cudnn_version() < 90700
            ):
                pytest.skip(
                    "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with"
                    " dropout"
                )
            if not FusedAttnHelper.is_non_deterministic_allowed() and (
                self.dropout_prob != 0.0
                or self.attn_bias_type != AttnBiasType.NO_BIAS
                or get_cudnn_version() < 91801
            ):
                pytest.skip(
                    "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or"
                    " dropout"
                )
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
@@ -1269,6 +1283,7 @@ class FusedAttnRunner:
pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
],
)
@pytest.mark.skipif(_deterministic, reason="Test non-determinism only")
class TestFusedAttn:
"""
Fused attention tester
@@ -1392,3 +1407,182 @@ class TestFusedAttn:
seq_desc_format,
)
runner.test_backward()
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
pytest.param(
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
[
# large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.bfloat16,
QKVLayout.BSHD_BSHD_BSHD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-SEPARATE",
),
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.bfloat16,
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
),
],
)
@pytest.mark.parametrize(
"dropout_prob",
[
pytest.param(0.0, id="DROP_0.0"),
],
)
@pytest.mark.parametrize(
"swa",
[
pytest.param(False, id="NO_SWA"),
],
)
@pytest.mark.parametrize(
"seq_desc_format",
[
pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
],
)
@pytest.mark.skipif(not _deterministic, reason="Test determinism only")
class TestFusedAttnWithDeterminism:
"""
Fused attention tester with determinism
"""
@staticmethod
@pytest.mark.parametrize(
"is_training",
[
pytest.param(True, id="TRAINING"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def _test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test forward with parameterized configs
This test is not intended to run automatically during CI as it is time-consuming
It is kept for development and debugging
"""
TestFusedAttn._test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
)
@staticmethod
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test backward with parameterized configs
"""
TestFusedAttn.test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
)