Commit 87e3e56e authored by yuguo's avatar yuguo
Browse files

Merge commit '734bcedd' of...

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
...@@ -8,4 +8,4 @@ set -xe ...@@ -8,4 +8,4 @@ set -xe
: ${XML_LOG_DIR:=/logs} : ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR" mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
...@@ -21,14 +21,20 @@ FAILED_CASES="" ...@@ -21,14 +21,20 @@ FAILED_CASES=""
mkdir -p "$XML_LOG_DIR" mkdir -p "$XML_LOG_DIR"
# It is not installed as a requirement,
# because it is not available on PyPI.
pip uninstall -y nvdlfw-inspect
pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -xe
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
...@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" ...@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls # Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" NVTE_JAX_CUSTOM_CALLS="false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
if [ $RET -ne 0 ]; then if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES" echo "Error: some sub-tests failed: $FAILED_CASES"
......
...@@ -41,6 +41,6 @@ do ...@@ -41,6 +41,6 @@ do
fi fi
# Run tests # Run tests
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py
done done
...@@ -28,6 +28,7 @@ list(APPEND test_cuda_sources ...@@ -28,6 +28,7 @@ list(APPEND test_cuda_sources
test_multi_unpadding.cu test_multi_unpadding.cu
test_causal_softmax.cu test_causal_softmax.cu
test_swizzle.cu test_swizzle.cu
test_swap_first_dims.cu
../test_common.cu) ../test_common.cu)
if(USE_ROCM) if(USE_ROCM)
list(APPEND test_cuda_sources list(APPEND test_cuda_sources
......
...@@ -36,25 +36,54 @@ enum ActivationType { ...@@ -36,25 +36,54 @@ enum ActivationType {
SReLU SReLU
}; };
template <typename InputType, typename OutputType, float (*OP)(const float)> template <typename InputType, typename OutputType>
void scale_block(const ProcessingMethod processing_method, void compute_ref(const ProcessingMethod processing_method,
float (*OP)(const float),
const bool rowwise,
const bool colwise,
const InputType* input, const InputType* input,
const InputType* grad, const InputType* grad,
OutputType* output_c, OutputType* output_rowwise,
float* dbias, OutputType* output_colwise,
fp8e8m0* output_scales, fp8e8m0* output_scales_rowwise,
const size_t scale_idx, fp8e8m0* output_scales_colwise,
const size_t i_min, InputType* output_dbias,
const size_t i_max, const size_t rows,
const size_t j_min, const size_t cols,
const size_t j_max, const size_t scales_stride_rowwise,
const size_t cols) { const size_t scales_stride_colwise)
float amax = 0.0f; {
const size_t tile_size_Y = 32;
// Find the absolute maximum value in the block const size_t tile_size_X = 32;
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
std::vector<float> output_dbias_fp32(cols, 0);
#pragma omp parallel proc_bind(spread)
{
// Buffers to cache intermediate computations
std::vector<float> cache_buffer(tile_size_Y * tile_size_X);
std::vector<float> thread_dbias(cols, 0);
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
const size_t i_min = tile_offset_Y;
const size_t i_max = std::min(i_min + tile_size_Y, rows);
const size_t j_min = tile_offset_X;
const size_t j_max = std::min(j_min + tile_size_X, cols);
// Cache computations
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j; const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]); float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) { if (processing_method == ProcessingMethod::CAST_DBIAS) {
// grad is the input // grad is the input
...@@ -68,89 +97,58 @@ void scale_block(const ProcessingMethod processing_method, ...@@ -68,89 +97,58 @@ void scale_block(const ProcessingMethod processing_method,
processing_method == ProcessingMethod::CAST_DBIAS_DACT) { processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]); elt *= static_cast<float>(grad[idx]);
} }
dbias[j] += elt; thread_dbias[j] += elt;
if (std::isinf(elt) || std::isnan(elt)) {
// Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
elt = static_cast<float>(static_cast<InputType>(elt));
cache_buffer[cache_idx] = elt;
if (isinf(elt) || isnan(elt)) {
continue; continue;
} }
amax = std::max(amax, std::abs(elt));
} }
} }
const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits<OutputType>::max_reciprocal()); if (rowwise) {
for (size_t i = i_min; i < i_max; ++i) {
float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) {
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
}
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t scale_idx = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent); const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j; const size_t idx = i * cols + j;
float elt = static_cast<float>(input[idx]); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
if (processing_method == ProcessingMethod::CAST_DBIAS) { output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
// grad is the input
elt = static_cast<float>(grad[idx]);
} }
if (processing_method != ProcessingMethod::CAST_ONLY
&& processing_method != ProcessingMethod::CAST_DBIAS) {
elt = OP(elt);
} }
if (processing_method == ProcessingMethod::CAST_DACT ||
processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
elt *= static_cast<float>(grad[idx]);
} }
output_c[idx] = static_cast<OutputType>(elt * scale_reciprocal); if (colwise) {
} for (size_t j = j_min; j < j_max; ++j) {
} float block_amax = 0.0f;
}
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x1(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_c,
fp8e8m0* output_scales,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride)
{
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
std::vector<float> output_dbias_fp32(cols, 0);
#pragma omp parallel proc_bind(spread)
{
std::vector<float> thread_dbias(cols, 0);
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { for (size_t i = i_min; i < i_max; ++i) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t block_offset_Y = ii * block_size_Y; block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
const size_t i_min = tile_offset_Y + block_offset_Y; }
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; const size_t scale_idx = tile_Y * scales_stride_colwise + j;
const size_t block_offset_X = jj * block_size_X; output_scales_colwise[scale_idx] = biased_exponent;
const size_t j_min = tile_offset_X + block_offset_X; const float scale_reciprocal = exp2f_rcp(biased_exponent);
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X; for (size_t i = i_min; i < i_max; ++i) {
scale_block<InputType, OutputType, OP>( const size_t idx = i * cols + j;
processing_method, input, grad, output_c, thread_dbias.data(), const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_scales, scale_idx, i_min, i_max, j_min, j_max, cols); output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
}
} }
} }
} }
...@@ -166,29 +164,6 @@ void compute_ref_x1(const ProcessingMethod processing_method, ...@@ -166,29 +164,6 @@ void compute_ref_x1(const ProcessingMethod processing_method,
} }
} }
template <typename InputType, typename OutputType, float (*OP)(const float)>
void compute_ref_x2(const ProcessingMethod processing_method,
const InputType* input,
const InputType* grad,
OutputType* output_rowwise,
OutputType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
InputType* output_dbias,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias,
rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<InputType, OutputType, OP>(
processing_method, input, grad, output_colwise, scales_colwise, output_dbias,
rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/** /**
* Scaling along single dimension (either rows or columns) * Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias): * Produces one set of output data and the corresponding data of the fused operation (dbias):
...@@ -197,8 +172,9 @@ void compute_ref_x2(const ProcessingMethod processing_method, ...@@ -197,8 +172,9 @@ void compute_ref_x2(const ProcessingMethod processing_method,
* 2) Scaled columns + column-wise scaling factors * 2) Scaled columns + column-wise scaling factors
*/ */
template <typename InputType, typename OutputType, float (*OP)(const float)> template <typename InputType, typename OutputType>
void performTest_x1(const ProcessingMethod processing_method, void performTest_x1(const ProcessingMethod processing_method,
float (*OP)(const float),
const std::vector<size_t>& shape, const std::vector<size_t>& shape,
const bool rowwise, const bool rowwise,
const bool colwise, const bool colwise,
...@@ -261,7 +237,13 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -261,7 +237,13 @@ void performTest_x1(const ProcessingMethod processing_method,
break; break;
} }
case ProcessingMethod::CAST_DBIAS_DACT: { case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(), auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu;
if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; }
else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; }
else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; }
else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; }
nvte_quantize_dbias_dact(grad.data(),
input.data(), input.data(),
output_c.data(), output_c.data(),
output_dbias.data(), output_dbias.data(),
...@@ -269,7 +251,7 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -269,7 +251,7 @@ void performTest_x1(const ProcessingMethod processing_method,
0); 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(), nvte_quantize_dbias_dact(grad.data(),
input.data(), input.data(),
output_c.data(), output_c.data(),
output_dbias.data(), output_dbias.data(),
...@@ -278,11 +260,23 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -278,11 +260,23 @@ void performTest_x1(const ProcessingMethod processing_method,
break; break;
} }
case ProcessingMethod::CAST_DACT: { case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output_c.data(), 0); auto nvte_dact = &nvte_dgelu;
if (OP == &dsilu) { nvte_dact = &nvte_dsilu; }
else if (OP == &drelu) { nvte_dact = &nvte_drelu; }
else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; }
else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; }
nvte_dact(grad.data(), input.data(), output_c.data(), 0);
break; break;
} }
case ProcessingMethod::CAST_ACT: { case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output_c.data(), 0); auto nvte_act = &nvte_gelu;
if (OP == &silu) { nvte_act = &nvte_silu; }
else if (OP == &relu) { nvte_act = &nvte_relu; }
else if (OP == &qgelu) { nvte_act = &nvte_qgelu; }
else if (OP == &srelu) { nvte_act = &nvte_srelu; }
nvte_act(input.data(), output_c.data(), 0);
break; break;
} }
} }
...@@ -291,29 +285,45 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -291,29 +285,45 @@ void performTest_x1(const ProcessingMethod processing_method,
auto err = cudaGetLastError(); auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x1<InputType, OutputType, OP>(processing_method, compute_ref<InputType, OutputType>(processing_method,
OP,
rowwise,
colwise,
input.rowwise_cpu_dptr<InputType>(), input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(), grad.rowwise_cpu_dptr<InputType>(),
ref_output_c.get(), ref_output_c.get(),
ref_output_c.get(),
ref_output_scales.get(),
ref_output_scales.get(), ref_output_scales.get(),
ref_output_dbias.get(), ref_output_dbias.get(),
rows, rows,
cols, cols,
block_size_rows, scales_stride,
block_size_cols,
scales_stride); scales_stride);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
const uint8_t * const gpu_scales_ptr = rowwise const uint8_t * const gpu_scales_ptr = rowwise
? output_c.rowwise_cpu_scale_inv_ptr<fp8e8m0>() ? output_c.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output_c.columnwise_cpu_scale_inv_ptr<fp8e8m0>(); : output_c.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 0.0;
const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales = 0;
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride); unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts = 32 * mismatches_scales;
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts);
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { if (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT)
{
auto [atol_dbias, rtol_dbias] = getTolerances(itype); auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) { if (itype == DType::kFloat32) {
atol_dbias = 1e-4; atol_dbias = 1e-4;
...@@ -332,8 +342,9 @@ void performTest_x1(const ProcessingMethod processing_method, ...@@ -332,8 +342,9 @@ void performTest_x1(const ProcessingMethod processing_method,
* AND * AND
* 2) Scaled columns + column-wise scaling factors * 2) Scaled columns + column-wise scaling factors
*/ */
template <typename InputType, typename OutputType, float (*OP)(const float)> template <typename InputType, typename OutputType>
void performTest_x2(const ProcessingMethod processing_method, void performTest_x2(const ProcessingMethod processing_method,
float (*OP)(const float),
const std::vector<size_t>& shape, const std::vector<size_t>& shape,
const size_t block_size_rows, const size_t block_size_rows,
const size_t block_size_cols, const size_t block_size_cols,
...@@ -401,7 +412,13 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -401,7 +412,13 @@ void performTest_x2(const ProcessingMethod processing_method,
break; break;
} }
case ProcessingMethod::CAST_DBIAS_DACT: { case ProcessingMethod::CAST_DBIAS_DACT: {
nvte_quantize_dbias_dgelu(grad.data(), auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu;
if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; }
else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; }
else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; }
else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; }
nvte_quantize_dbias_dact(grad.data(),
input.data(), input.data(),
output.data(), output.data(),
output_dbias.data(), output_dbias.data(),
...@@ -409,7 +426,7 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -409,7 +426,7 @@ void performTest_x2(const ProcessingMethod processing_method,
0); 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_quantize_dbias_dgelu(grad.data(), nvte_quantize_dbias_dact(grad.data(),
input.data(), input.data(),
output.data(), output.data(),
output_dbias.data(), output_dbias.data(),
...@@ -418,11 +435,23 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -418,11 +435,23 @@ void performTest_x2(const ProcessingMethod processing_method,
break; break;
} }
case ProcessingMethod::CAST_DACT: { case ProcessingMethod::CAST_DACT: {
nvte_dgelu(grad.data(), input.data(), output.data(), 0); auto nvte_dact = &nvte_dgelu;
if (OP == &dsilu) { nvte_dact = &nvte_dsilu; }
else if (OP == &drelu) { nvte_dact = &nvte_drelu; }
else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; }
else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; }
nvte_dact(grad.data(), input.data(), output.data(), 0);
break; break;
} }
case ProcessingMethod::CAST_ACT: { case ProcessingMethod::CAST_ACT: {
nvte_gelu(input.data(), output.data(), 0); auto nvte_act = &nvte_gelu;
if (OP == &silu) { nvte_act = &nvte_silu; }
else if (OP == &relu) { nvte_act = &nvte_relu; }
else if (OP == &qgelu) { nvte_act = &nvte_qgelu; }
else if (OP == &srelu) { nvte_act = &nvte_srelu; }
nvte_act(input.data(), output.data(), 0);
break; break;
} }
} }
...@@ -431,7 +460,10 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -431,7 +460,10 @@ void performTest_x2(const ProcessingMethod processing_method,
auto err = cudaGetLastError(); auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
compute_ref_x2<InputType, OutputType, OP>(processing_method, compute_ref<InputType, OutputType>(processing_method,
OP,
true,
true,
input.rowwise_cpu_dptr<InputType>(), input.rowwise_cpu_dptr<InputType>(),
grad.rowwise_cpu_dptr<InputType>(), grad.rowwise_cpu_dptr<InputType>(),
ref_output_c_rowwise.get(), ref_output_c_rowwise.get(),
...@@ -441,22 +473,41 @@ void performTest_x2(const ProcessingMethod processing_method, ...@@ -441,22 +473,41 @@ void performTest_x2(const ProcessingMethod processing_method,
ref_output_dbias.get(), ref_output_dbias.get(),
rows, rows,
cols, cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise, scales_stride_rowwise,
scales_stride_colwise); scales_stride_colwise);
auto [atol, rtol] = getTolerances(otype); const size_t scale_diff_abs_tolerance = 0;
compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); const double abs_tolerable_mismatches_limit = 0.0;
compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); const double rel_tolerable_mismatches_limit = 0.0;
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise); unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise, ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise); unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { auto [atol, rtol] = getTolerances(otype);
compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise);
compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise);
if (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT)
{
auto [atol_dbias, rtol_dbias] = getTolerances(itype); auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) { if (itype == DType::kFloat32) {
atol_dbias = 1e-4; atol_dbias = 1e-4;
...@@ -475,11 +526,10 @@ std::vector<std::vector<size_t>> matrix_sizes = { ...@@ -475,11 +526,10 @@ std::vector<std::vector<size_t>> matrix_sizes = {
{128, 128}, {128, 128},
{256, 256}, {256, 256},
{993, 512}, {993, 512},
{256, 65536}, {511, 6144},
{2048, 6144}, {8192, 128},
{16384, 128}, {2048, 160},
{32768, 160}, {577, 1632},
{4096, 1632},
{1024}, {1024},
{8, 32, 1024}, {8, 32, 1024},
{16, 8, 4, 512}, {16, 8, 4, 512},
...@@ -528,26 +578,6 @@ class FusedCastMXFP8TestSuite : public ::testing::TestWithParam ...@@ -528,26 +578,6 @@ class FusedCastMXFP8TestSuite : public ::testing::TestWithParam
transformer_engine::DType, transformer_engine::DType,
InputsFillCase>> {}; InputsFillCase>> {};
#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \
}
#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \
case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \
case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \
case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \
case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \
case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \
}
TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
// Skip tests for pre-Blackwell architectures // Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) { if (getDeviceComputeCapability() < blackwellComputeCapability) {
...@@ -581,37 +611,50 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { ...@@ -581,37 +611,50 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
const bool colwise = block_size.first != 1; const bool colwise = block_size.first != 1;
if (processing_method == ProcessingMethod::CAST_ACT) { if (processing_method == ProcessingMethod::CAST_ACT) {
// Forward activations // Forward activations
ACT_FUNC_SWITCH(Act_type, OP, auto OP = &identity;
switch (Act_type) {
case ActivationType::GeLU: OP = &gelu; break;
case ActivationType::SiLU: OP = &silu; break;
case ActivationType::ReLU: OP = &relu; break;
case ActivationType::QGeLU: OP = &qgelu; break;
case ActivationType::SReLU: OP = &srelu; break;
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) { if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>( performTest_x1<InputType, OutputType>(
processing_method, matrix_size, processing_method, OP, matrix_size,
rowwise, colwise, fill_case); rowwise, colwise, fill_case);
} else { } else {
performTest_x2<InputType, OutputType, OP>( performTest_x2<InputType, OutputType>(
processing_method, matrix_size, processing_method, OP, matrix_size,
block_size.first, block_size.second, fill_case); block_size.first, block_size.second, fill_case);
} }
); );
); );
);
} else { } else {
DACT_FUNC_SWITCH(Act_type, OP, auto OP = &identity;
switch (Act_type) {
case ActivationType::GeLU: OP = &dgelu; break;
case ActivationType::SiLU: OP = &dsilu; break;
case ActivationType::ReLU: OP = &drelu; break;
case ActivationType::QGeLU: OP = &dqgelu; break;
case ActivationType::SReLU: OP = &dsrelu; break;
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
if (block_size.first == 1 || block_size.second == 1) { if (block_size.first == 1 || block_size.second == 1) {
performTest_x1<InputType, OutputType, OP>( performTest_x1<InputType, OutputType>(
processing_method, matrix_size, processing_method, OP, matrix_size,
rowwise, colwise, fill_case); rowwise, colwise, fill_case);
} else { } else {
performTest_x2<InputType, OutputType, OP>( performTest_x2<InputType, OutputType>(
processing_method, matrix_size, processing_method, OP, matrix_size,
block_size.first, block_size.second, fill_case); block_size.first, block_size.second, fill_case);
} }
); );
); );
);
} }
} }
......
...@@ -18,134 +18,157 @@ using namespace test; ...@@ -18,134 +18,157 @@ using namespace test;
namespace { namespace {
template <bool IS_DGATED, typename IType, typename OType> template <typename IType, typename OType>
void scale_block(const IType* grad, void compute_ref(const IType* grad,
const IType* input, const IType* input,
OType* output, OType* output_rowwise,
fp8e8m0* output_scales, OType* output_colwise,
const size_t scale_idx, fp8e8m0* output_scales_rowwise,
const size_t scale_idx_gate, fp8e8m0* output_scales_colwise,
float& thread_amax, float& ref_amax,
const size_t i_min, const bool IS_DGATED,
const size_t i_max, const size_t rows,
const size_t j_min, const size_t cols,
const size_t j_max, const size_t scales_stride_rowwise,
const size_t cols) { const size_t scales_stride_colwise,
const bool is_rowwise,
float block_amax = 0.0f; const bool is_colwise) {
float block_amax_gate = 0.0f; constexpr size_t tile_size_Y = 32;
constexpr size_t tile_size_X = 32;
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
float amax = 0;
#pragma omp parallel reduction(max: amax) proc_bind(spread)
{
// Buffers to cache intermediate computations
std::vector<float> cache_buffer_act(tile_size_Y * tile_size_X);
std::vector<float> cache_buffer_gate(tile_size_Y * tile_size_X);
float thread_amax = 0.0f;
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
const size_t stride = cols * 2; const size_t stride = cols * 2;
// Find the absolute maximum value in the block const size_t i_min = tile_offset_Y;
const size_t i_max = std::min(rows, tile_offset_Y + tile_size_Y);
const size_t j_min = tile_offset_X;
const size_t j_max = std::min(cols, tile_offset_X + tile_size_X);
// Compute and cache activations for the entire tile
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]); float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]); float gate_elt = static_cast<float>(input[i * stride + cols + j]);
float gated_amax_act = 0;
float gated_amax_gate = 0;
if constexpr (IS_DGATED) { const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
if (IS_DGATED) {
const float x = silu_elt;
const float s = sigmoid(x);
const float act_x = x * s;
const float dact_x = x * s * (1 - s) + s;
const float grad_elt = static_cast<float>(grad[i * cols + j]); const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; float after_dsilu = dact_x * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt; float after_dgate = act_x * grad_elt;
gated_amax_act = abs(after_dsilu);
gated_amax_gate = abs(after_dgate); // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32
after_dsilu = static_cast<float>(static_cast<IType>(after_dsilu));
after_dgate = static_cast<float>(static_cast<IType>(after_dgate));
cache_buffer_act[cached_idx] = after_dsilu;
cache_buffer_gate[cached_idx] = after_dgate;
thread_amax = std::max(thread_amax, std::abs(after_dsilu));
thread_amax = std::max(thread_amax, std::abs(after_dgate));
} else { } else {
const float after_silu = silu(silu_elt) * gate_elt; float after_silu = silu(silu_elt) * gate_elt;
gated_amax_act = abs(after_silu);
} // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32
after_silu = static_cast<float>(static_cast<IType>(after_silu));
if (gated_amax_act > block_amax) { block_amax = gated_amax_act; } cache_buffer_act[cached_idx] = after_silu;
if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; } thread_amax = std::max(thread_amax, std::abs(after_silu));
} }
} }
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax *
Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
float scale_reciprocal_gate = 1;
if constexpr (IS_DGATED) {
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate *
Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent);
output_scales[scale_idx_gate] = biased_exponent;
} }
if (is_rowwise) {
// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
float block_amax_act = 0.0f;
float block_amax_gate = 0.0f;
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
float gate_elt = static_cast<float>(input[i * stride + cols + j]); block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
}
}
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const size_t scale_idx_act = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx_act] = biased_exponent_act;
if constexpr (IS_DGATED) { float scale_reciprocal_gate;
const float grad_elt = static_cast<float>(grad[i * cols + j]); if (IS_DGATED) {
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
const float after_dgate = silu(silu_elt) * grad_elt; scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
output[i * stride + j] = static_cast<OType>(after_dsilu * scale_reciprocal); const size_t scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32;
output[i * stride + cols + j] = static_cast<OType>(after_dgate * output_scales_rowwise[scale_idx_gate] = biased_exponent_gate;
scale_reciprocal_gate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
output[i * cols + j] = static_cast<OType>(after_silu * scale_reciprocal);
} }
for (size_t j = j_min; j < j_max; ++j) {
const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) {
const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate;
output_rowwise[i * stride + j] = static_cast<OType>(after_act);
output_rowwise[i * stride + cols + j] = static_cast<OType>(after_gate);
} else {
output_rowwise[i * cols + j] = static_cast<OType>(after_act);
}
}
} }
} }
thread_amax = std::max(thread_amax, block_amax);
thread_amax = std::max(thread_amax, block_amax_gate);
}
template <bool IS_DGATED, typename IType, typename OType> if (is_colwise) {
void compute_ref_x1(const IType* grad, for (size_t j = j_min; j < j_max; ++j) {
const IType* input, float block_amax_act = 0.0f;
OType* output, float block_amax_gate = 0.0f;
fp8e8m0* output_scales, for (size_t i = i_min; i < i_max; ++i) {
float& ref_amax, const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t rows, block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
const size_t cols, if (IS_DGATED) {
const size_t block_size_Y, block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
const size_t block_size_X, }
const size_t scales_stride) { }
const size_t tile_size_Y = std::max(32lu, block_size_Y); const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const size_t tile_size_X = std::max(64lu, block_size_X); const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; const size_t scale_idx_act = tile_Y * scales_stride_colwise + j;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; output_scales_colwise[scale_idx_act] = biased_exponent_act;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
float amax = 0; float scale_reciprocal_gate;
#pragma omp parallel reduction(max: amax) proc_bind(spread) if (IS_DGATED) {
{ const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
float thread_amax = 0; const size_t scale_idx_gate = scale_idx_act + cols;
#pragma omp for schedule(static) scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { output_scales_colwise[scale_idx_gate] = biased_exponent_gate;
const size_t tile_Y = t / tiles_num_X; }
const size_t tile_X = t % tiles_num_X; for (size_t i = i_min; i < i_max; ++i) {
const size_t tile_offset_Y = tile_Y * tile_size_Y; const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const size_t tile_offset_X = tile_X * tile_size_X; const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { if (IS_DGATED) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate;
const size_t block_offset_Y = ii * block_size_Y; output_colwise[i * stride + j] = static_cast<OType>(after_act);
const size_t i_min = tile_offset_Y + block_offset_Y; output_colwise[i * stride + cols + j] = static_cast<OType>(after_gate);
if (i_min >= rows) continue; } else {
const size_t i_max = std::min(i_min + block_size_Y, rows); output_colwise[i * cols + j] = static_cast<OType>(after_act);
}
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { }
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X;
const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X +
cols / block_size_X;
scale_block<IS_DGATED, IType, OType>(
grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate,
thread_amax, i_min, i_max, j_min, j_max, cols);
} }
} }
} }
...@@ -156,26 +179,6 @@ void compute_ref_x1(const IType* grad, ...@@ -156,26 +179,6 @@ void compute_ref_x1(const IType* grad,
ref_amax = amax; ref_amax = amax;
} }
template <bool IS_DGATED, typename IType, typename OType>
void compute_ref_x2(const IType* grad,
const IType* input,
OType* output_rowwise,
OType* output_colwise,
fp8e8m0* scales_rowwise,
fp8e8m0* scales_colwise,
float& ref_amax,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise);
}
/** /**
* Scaling along single dimension (either rows or columns) * Scaling along single dimension (either rows or columns)
* Produces one set of output data and the corresponding data of the fused operation (dbias): * Produces one set of output data and the corresponding data of the fused operation (dbias):
...@@ -183,12 +186,13 @@ void compute_ref_x2(const IType* grad, ...@@ -183,12 +186,13 @@ void compute_ref_x2(const IType* grad,
* OR * OR
* 2) Scaled columns + column-wise scaling factors * 2) Scaled columns + column-wise scaling factors
*/ */
template <bool IS_DGATED, typename IType, typename OType> template <typename IType, typename OType>
void performTest_x1(const size_t rows, void performTest_x1(const size_t rows,
const size_t cols, const size_t cols,
const size_t block_size_rows, const size_t block_size_rows,
const size_t block_size_cols, const size_t block_size_cols,
InputsFillCase fill_case) { InputsFillCase fill_case,
const bool IS_DGATED) {
using namespace test; using namespace test;
using EncodingType = fp32; using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype; DType itype = TypeInfo<IType>::dtype;
...@@ -198,12 +202,6 @@ void performTest_x1(const size_t rows, ...@@ -198,12 +202,6 @@ void performTest_x1(const size_t rows,
const bool colwise = (block_size_rows == 32) && (block_size_cols == 1); const bool colwise = (block_size_rows == 32) && (block_size_cols == 1);
NVTE_CHECK(rowwise || colwise); NVTE_CHECK(rowwise || colwise);
// std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl;
// std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl;
// std::cout << "blocks_Y: " << blocks_Y << std::endl;
// std::cout << "blocks_X: " << blocks_X << std::endl;
// std::cout << "scales_stride: " << scales_stride << std::endl;
Tensor grad("grad", std::vector<size_t>{ rows, cols }, itype); Tensor grad("grad", std::vector<size_t>{ rows, cols }, itype);
Tensor input("input", std::vector<size_t>{ rows, cols * 2 }, itype); Tensor input("input", std::vector<size_t>{ rows, cols * 2 }, itype);
...@@ -229,12 +227,12 @@ void performTest_x1(const size_t rows, ...@@ -229,12 +227,12 @@ void performTest_x1(const size_t rows,
} }
// fillCase<EncodingType>(&grad, fill_case); // fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) { if (IS_DGATED) {
fillUniform(&grad); fillUniform(&grad);
} }
fillUniform(&input); fillUniform(&input);
if constexpr (IS_DGATED) { if (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0); nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else { } else {
nvte_swiglu(input.data(), output.data(), 0); nvte_swiglu(input.data(), output.data(), 0);
...@@ -245,30 +243,48 @@ void performTest_x1(const size_t rows, ...@@ -245,30 +243,48 @@ void performTest_x1(const size_t rows,
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0; float ref_amax = 0;
compute_ref_x1<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(), compute_ref<IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(), input.rowwise_cpu_dptr<IType>(),
ref_output.get(), ref_output.get(),
ref_output.get(),
ref_output_scales.get(),
ref_output_scales.get(), ref_output_scales.get(),
ref_amax, ref_amax,
IS_DGATED,
rows, rows,
cols, cols,
block_size_rows, scales_stride,
block_size_cols, scales_stride,
scales_stride); rowwise,
colwise);
auto [atol, rtol] = getTolerances(otype); size_t mismatches_scales = 0;
compareResults("output", output, ref_output.get(), rowwise, atol, rtol); const size_t scale_diff_abs_tolerance = 0;
const double abs_tolerable_mismatches_limit = 1.0;
const double rel_tolerable_mismatches_limit = 1.0e-4;
const uint8_t * const gpu_scales_ptr = rowwise const uint8_t * const gpu_scales_ptr = rowwise
? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>() ? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(); : output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
if (rowwise) { if (rowwise) {
compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride); unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
} else { } else {
compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride); unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
} }
const size_t mismatches_elts = 32 * mismatches_scales;
auto [atol, rtol] = getTolerances(otype);
compareResults("output", output, ref_output.get(), rowwise, atol, rtol, true, mismatches_elts);
} }
/** /**
...@@ -278,12 +294,13 @@ void performTest_x1(const size_t rows, ...@@ -278,12 +294,13 @@ void performTest_x1(const size_t rows,
* AND * AND
* 2) Scaled columns + column-wise scaling factors * 2) Scaled columns + column-wise scaling factors
*/ */
template <bool IS_DGATED, typename IType, typename OType> template <typename IType, typename OType>
void performTest_x2(const size_t rows, void performTest_x2(const size_t rows,
const size_t cols, const size_t cols,
const size_t block_size_rows, const size_t block_size_rows,
const size_t block_size_cols, const size_t block_size_cols,
InputsFillCase fill_case) { InputsFillCase fill_case,
const bool IS_DGATED) {
using namespace test; using namespace test;
using EncodingType = fp32; using EncodingType = fp32;
DType itype = TypeInfo<IType>::dtype; DType itype = TypeInfo<IType>::dtype;
...@@ -325,12 +342,12 @@ void performTest_x2(const size_t rows, ...@@ -325,12 +342,12 @@ void performTest_x2(const size_t rows,
} }
// fillCase<EncodingType>(&grad, fill_case); // fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) { if (IS_DGATED) {
fillUniform(&grad); fillUniform(&grad);
} }
fillUniform(&input); fillUniform(&input);
if constexpr (IS_DGATED) { if (IS_DGATED) {
nvte_dswiglu(grad.data(), input.data(), output.data(), 0); nvte_dswiglu(grad.data(), input.data(), output.data(), 0);
} else { } else {
nvte_swiglu(input.data(), output.data(), 0); nvte_swiglu(input.data(), output.data(), 0);
...@@ -341,30 +358,49 @@ void performTest_x2(const size_t rows, ...@@ -341,30 +358,49 @@ void performTest_x2(const size_t rows,
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
float ref_amax = 0; float ref_amax = 0;
compute_ref_x2<IS_DGATED, IType, OType>(grad.rowwise_cpu_dptr<IType>(), compute_ref<IType, OType>(grad.rowwise_cpu_dptr<IType>(),
input.rowwise_cpu_dptr<IType>(), input.rowwise_cpu_dptr<IType>(),
ref_output_rowwise.get(), ref_output_rowwise.get(),
ref_output_colwise.get(), ref_output_colwise.get(),
ref_scales_rowwise.get(), ref_scales_rowwise.get(),
ref_scales_colwise.get(), ref_scales_colwise.get(),
ref_amax, ref_amax,
IS_DGATED,
rows, rows,
cols, cols,
block_size_rows,
block_size_cols,
scales_stride_rowwise, scales_stride_rowwise,
scales_stride_colwise); scales_stride_colwise,
true,
true);
auto [atol, rtol] = getTolerances(otype); const size_t scale_diff_abs_tolerance = 0;
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); const double abs_tolerable_mismatches_limit = 1.0;
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); const double rel_tolerable_mismatches_limit = 1.0e-4;
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
size_t mismatches_scales_rowwise = 0;
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise); unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_rowwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
size_t mismatches_scales_colwise = 0;
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(), compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise, ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise); unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_colwise,
scale_diff_abs_tolerance,
abs_tolerable_mismatches_limit,
rel_tolerable_mismatches_limit);
const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;
auto [atol, rtol] = getTolerances(otype);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise);
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise);
} }
std::vector<std::pair<size_t, size_t>> matrix_sizes = { std::vector<std::pair<size_t, size_t>> matrix_sizes = {
...@@ -375,8 +411,8 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = { ...@@ -375,8 +411,8 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{256, 256}, {256, 256},
{993, 512}, {993, 512},
{768, 1024}, {768, 1024},
{65504, 128}, {8192, 128},
{16384, 1632}, {577, 1632},
}; };
std::vector<std::pair<size_t, size_t>> block_sizes = { std::vector<std::pair<size_t, size_t>> block_sizes = {
...@@ -393,9 +429,9 @@ std::vector<InputsFillCase> input_scenarios = { ...@@ -393,9 +429,9 @@ std::vector<InputsFillCase> input_scenarios = {
// InputsFillCase::maxNorm_to_inf // InputsFillCase::maxNorm_to_inf
}; };
std::vector<bool> is_dgated_op = { std::vector<bool> is_bwd_op = {
true, false,
false true
}; };
} // namespace } // namespace
...@@ -427,21 +463,11 @@ TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { ...@@ -427,21 +463,11 @@ TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType,
if (block_size.first == 1 || block_size.second == 1) { if (block_size.first == 1 || block_size.second == 1) {
if (IS_DGATED) { performTest_x1<IType, OType>(matrix_size.first, matrix_size.second,
performTest_x1<true, IType, OType>(matrix_size.first, matrix_size.second, block_size.first, block_size.second, fill_case, IS_DGATED);
block_size.first, block_size.second, fill_case);
} else {
performTest_x1<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
} else { } else {
if (IS_DGATED) { performTest_x2<IType, OType>(matrix_size.first, matrix_size.second,
performTest_x2<true, IType, OType>(matrix_size.first, matrix_size.second, block_size.first, block_size.second, fill_case, IS_DGATED);
block_size.first, block_size.second, fill_case);
} else {
performTest_x2<false, IType, OType>(matrix_size.first, matrix_size.second,
block_size.first, block_size.second, fill_case);
}
} }
); );
); );
...@@ -456,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P( ...@@ -456,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios), ::testing::ValuesIn(input_scenarios),
::testing::ValuesIn(is_dgated_op)), ::testing::ValuesIn(is_bwd_op)),
[](const testing::TestParamInfo<CastMXFP8_GatedActTestSuite::ParamType>& info) { [](const testing::TestParamInfo<CastMXFP8_GatedActTestSuite::ParamType>& info) {
std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" + std::to_string(std::get<0>(info.param).second) + "X" +
...@@ -465,6 +491,6 @@ INSTANTIATE_TEST_SUITE_P( ...@@ -465,6 +491,6 @@ INSTANTIATE_TEST_SUITE_P(
test::typeName(std::get<2>(info.param)) + "X" + test::typeName(std::get<2>(info.param)) + "X" +
test::typeName(std::get<3>(info.param)) + "X" + test::typeName(std::get<3>(info.param)) + "X" +
test::caseName(std::get<4>(info.param)) + "X" + test::caseName(std::get<4>(info.param)) + "X" +
(std::get<5>(info.param) ? "DGATED" : "GATED"); (std::get<5>(info.param) ? "BWD" : "FWD");
return name; return name;
}); });
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/transpose.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename Type>
void compute_ref(const Type *input, Type *output,
const std::vector<size_t> &shape) {
const size_t dim0 = shape[0];
const size_t dim1 = shape[1];
size_t dim2 = 1;
for (size_t i = 2; i < shape.size(); ++i) {
dim2 *= shape[i];
}
for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) {
for (size_t k = 0; k < dim2; ++k) {
const size_t in_offset = i * dim1 * dim2 + j * dim2 + k;
const size_t out_offset = j * dim0 * dim2 + i * dim2 + k;
output[out_offset] = input[in_offset];
}
}
}
}
template <typename Type>
void performTest(const std::vector<size_t> &in_shape) {
using namespace test;
DType dtype = TypeInfo<Type>::dtype;
// Tensor dimensions
std::vector<size_t> out_shape = in_shape;
out_shape[0] = in_shape[1];
out_shape[1] = in_shape[0];
size_t numel = 1;
for (const auto& dim : in_shape) {
numel *= dim;
}
// Transformer engine implementation
Tensor input("input", in_shape, dtype);
Tensor output("output", out_shape, dtype);
fillUniform(&input);
nvte_swap_first_dims(input.data(), output.data(), 0);
// Reference implementation
std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(numel);
compute_ref<Type>(input.rowwise_cpu_dptr<Type>(), ref_output.get(), in_shape);
// Check for CUDA failure
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
// Check for exact numerics
compareResults("output", output, ref_output.get(), true, 0, 0);
}
std::vector<std::vector<size_t>> test_cases = {{4, 64, 1280},
{48, 8, 128, 16},
{229, 173}, // Primes 50, 40
{113, 71, 1, 1, 1, 29, 1, 1}}; // Primes 30, 20, 10
} // namespace
class SwapFirstDimsTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
std::vector<size_t>>> {};
TEST_P(SwapFirstDimsTestSuite, TestSwapFirstDims) {
using namespace transformer_engine;
using namespace test;
const DType type = std::get<0>(GetParam());
const auto shape = std::get<1>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
performTest<T>(shape);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
SwapFirstDimsTestSuite,
::testing::Combine(
::testing::ValuesIn(test::all_fp_types),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<SwapFirstDimsTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param));
for (const auto& dim : std::get<1>(info.param)) {
name += "X";
name += std::to_string(dim);
}
return name;
});
...@@ -523,10 +523,13 @@ std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) { ...@@ -523,10 +523,13 @@ std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
void compareResults_sequential(const std::string &name, const Tensor &test, void compareResults_sequential(const std::string &name, const Tensor &test,
const void *ref, const bool rowwise, const void *ref, const bool rowwise,
double atol, double rtol, bool if_on_gpus) { double atol, double rtol, bool if_on_gpus,
const size_t tolerable_mismatches_limit) {
if (if_on_gpus) test.to_cpu(); if (if_on_gpus) test.to_cpu();
const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
const size_t N = product(shape); const size_t N = product(shape);
size_t mismatches_num = 0;
int first_mismatch_idx = -1;
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>(); const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
const T *ref_data = reinterpret_cast<const T*>(ref); const T *ref_data = reinterpret_cast<const T*>(ref);
...@@ -562,27 +565,39 @@ void compareResults_sequential(const std::string &name, const Tensor &test, ...@@ -562,27 +565,39 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
#endif #endif
} }
std::string direction = rowwise ? "rowwise" : "columnwise"; std::string direction = rowwise ? "rowwise" : "columnwise";
ASSERT_FALSE(assertion) << "Error in tensor " << name << " in " if (assertion) {
mismatches_num++;
if (first_mismatch_idx == -1) {
first_mismatch_idx = i;
}
}
if (mismatches_num > tolerable_mismatches_limit) {
const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);
GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
<< tolerable_mismatches_limit << "." << std::endl
<< "Error in tensor " << name << " in "
<< direction << " direction." << std::endl << direction << " direction." << std::endl
<< "Mismatch at place " << to_string(unravel(i, shape)) << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape))
<< " (" << std::to_string(i) << "): " << t << " vs " << r; << " (" << std::to_string(first_mismatch_idx) << "): "
<< first_mismatch_t << " vs " << first_mismatch_r;
}
} }
); );
} }
template <typename T> template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
const size_t N, const double atol, const double rtol) { const size_t N, const double atol, const double rtol,
size_t& mismatches) {
int first_mismatch_idx = N; int first_mismatch_idx = N;
bool is_mismatch_found = false; #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread)
#pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \ {
reduction(min: first_mismatch_idx) proc_bind(spread) size_t thread_mismatches = 0;
#pragma omp for schedule(static)
for (size_t i = 0; i < N; ++i) { for (size_t i = 0; i < N; ++i) {
if (is_mismatch_found) { // early escape of the omp thread
continue;
}
#ifndef __HIP_PLATFORM_AMD__ #ifndef __HIP_PLATFORM_AMD__
double t = static_cast<double>(test_data[i]); double t = static_cast<double>(test_data[i]);
double r = static_cast<double>(ref_data[i]); double r = static_cast<double>(ref_data[i]);
...@@ -591,7 +606,6 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con ...@@ -591,7 +606,6 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con
double r = static_cast<double>(static_cast<float>(ref_data[i])); double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif #endif
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */ /* For Float32 the floating point comparison is enough to error out */
bool assertion = mismatch && (data_type == DType::kFloat32); bool assertion = mismatch && (data_type == DType::kFloat32);
...@@ -608,32 +622,37 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con ...@@ -608,32 +622,37 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con
const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p)))); const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m)))); const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
#endif #endif
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
assertion = !(cast_mean_m == std::min<double>(t,r) && cast_mean_p == std::max<double>(t,r)); assertion = !(cast_mean_m == std::min<double>(t,r) && cast_mean_p == std::max<double>(t,r));
#else #else
assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
#endif #endif
} }
if (assertion && i < first_mismatch_idx) { if (assertion) {
if (i < first_mismatch_idx) {
first_mismatch_idx = i; first_mismatch_idx = i;
is_mismatch_found = true;
} }
thread_mismatches++;
}
}
mismatches += thread_mismatches;
} }
return first_mismatch_idx; return first_mismatch_idx;
} }
void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
const bool rowwise, double atol, double rtol, bool if_on_gpus) { const bool rowwise, double atol, double rtol, bool if_on_gpus,
const size_t tolerable_mismatches_limit) {
if (if_on_gpus) test.to_cpu(); if (if_on_gpus) test.to_cpu();
const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
const size_t N = product(shape); const size_t N = product(shape);
size_t mismatches = 0;
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>(); const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
const T *ref_data = reinterpret_cast<const T*>(ref); const T *ref_data = reinterpret_cast<const T*>(ref);
const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol); const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches);
if (i != N) { if ((i != N) && (mismatches > tolerable_mismatches_limit)) {
#ifndef __HIP_PLATFORM_AMD__ #ifndef __HIP_PLATFORM_AMD__
const double t = static_cast<double>(test_data[i]); const double t = static_cast<double>(test_data[i]);
const double r = static_cast<double>(ref_data[i]); const double r = static_cast<double>(ref_data[i]);
...@@ -642,7 +661,10 @@ void compareResults_parallel(const std::string &name, const Tensor &test, const ...@@ -642,7 +661,10 @@ void compareResults_parallel(const std::string &name, const Tensor &test, const
const double r = static_cast<double>(static_cast<float>(ref_data[i])); const double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif #endif
std::string direction = rowwise ? "rowwise" : "columnwise"; std::string direction = rowwise ? "rowwise" : "columnwise";
ASSERT_FALSE(true) << "Error in tensor " << name << " in "
GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of "
<< tolerable_mismatches_limit << "." << std::endl
<< "Error in tensor " << name << " in "
<< direction << " direction." << std::endl << direction << " direction." << std::endl
<< "Mismatch at place " << to_string(unravel(i, shape)) << "Mismatch at place " << to_string(unravel(i, shape))
<< " (" << std::to_string(i) << "): " << t << " vs " << r; << " (" << std::to_string(i) << "): " << t << " vs " << r;
...@@ -651,12 +673,13 @@ void compareResults_parallel(const std::string &name, const Tensor &test, const ...@@ -651,12 +673,13 @@ void compareResults_parallel(const std::string &name, const Tensor &test, const
} }
void compareResults(const std::string &name, const Tensor &test, const void *ref, void compareResults(const std::string &name, const Tensor &test, const void *ref,
const bool rowwise, double atol, double rtol, bool if_on_gpus) { const bool rowwise, double atol, double rtol, bool if_on_gpus,
const size_t tolerable_mismatches_limit) {
constexpr bool sequential = false; constexpr bool sequential = false;
if constexpr (sequential) { if constexpr (sequential) {
compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus); compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
} else { } else {
compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus); compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
} }
} }
...@@ -698,25 +721,39 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t ...@@ -698,25 +721,39 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
} }
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride) const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num, const size_t atol,
const double abs_tolerable_mismatches_limit,
const double rel_tolerable_mismatches_limit)
{ {
const size_t N = row_blocks * col_blocks;
const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
std::floor(N * rel_tolerable_mismatches_limit));
mismatches_num = 0;
std::vector<int> mismatch_indices;
for (int i = 0; i < row_blocks; ++i) { for (int i = 0; i < row_blocks; ++i) {
for (int j = 0; j < col_blocks; ++j) { for (int j = 0; j < col_blocks; ++j) {
const int idx = i * stride + j; const int idx = i * stride + j;
ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl const int test_val = static_cast<int>(test[idx]);
<< "Mismatch: " << static_cast<int>(test[idx]) << " vs " const int ref_val = static_cast<int>(ref[idx]);
<< static_cast<int>(ref[idx]) << " at index " << idx; const int abs_delta = std::abs(test_val - ref_val);
if (abs_delta > atol) {
mismatches_num++;
mismatch_indices.push_back(idx);
}
if (mismatches_num > tolerable_mismatches_limit) {
std::cout << "Error in " << name << std::endl;
for (const int index : mismatch_indices) {
std::cout << "Mismatch at (" << index << "):"
<< static_cast<int>(test[index]) << " vs "
<< static_cast<int>(ref[index]) << std::endl;
}
GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
<< tolerable_mismatches_limit << ".";
} }
} }
}
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t N)
{
for (int i = 0; i < N; i++) {
ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl
<< "Mismatch: " << static_cast<int>(test[i]) << " vs "
<< static_cast<int>(ref[i]) << " at index " << i;
} }
} }
......
...@@ -430,7 +430,12 @@ inline fp8e8m0 float_to_e8m0(float val) { ...@@ -430,7 +430,12 @@ inline fp8e8m0 float_to_e8m0(float val) {
} }
inline float exp2f_rcp(fp8e8m0 biased_exp) { inline float exp2f_rcp(fp8e8m0 biased_exp) {
return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp)); if (biased_exp == 0) {
return 1.0f;
}
int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127)
float fp32_val = *reinterpret_cast<float*>(&int_val);
return fp32_val;
} }
inline float identity(const float x) { return x; } inline float identity(const float x) { return x; }
...@@ -462,15 +467,18 @@ size_t last_dimension(const std::vector<size_t> &shape); ...@@ -462,15 +467,18 @@ size_t last_dimension(const std::vector<size_t> &shape);
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);
void compareResults(const std::string &name, const Tensor &test, const void *ref, void compareResults(const std::string &name, const Tensor &test, const void *ref,
bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true); bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true,
const size_t tolerable_mismatches_limit = 0);
void compareResults(const std::string &name, const float test, const float ref, void compareResults(const std::string &name, const float test, const float ref,
double atol = 1e-5, double rtol = 1e-8); double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
size_t N, float mismatch_rate_tol = 0.); size_t N, float mismatch_rate_tol = 0.);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride); const size_t row_blocks, const size_t col_blocks, const size_t stride,
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t& mismatches_num,
const size_t N); const size_t scale_diff_abs_tolerance = 0,
const double abs_tolerable_mismatches_limit = 0,
const double rel_tolerable_mismatches_limit = 0);
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols, std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
const size_t block_size_rows, const size_t block_size_cols); const size_t block_size_rows, const size_t block_size_cols);
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
import os import os
import jax import jax
import pytest import pytest
from collections import defaultdict
import time
import transformer_engine.jax import transformer_engine.jax
...@@ -32,3 +34,54 @@ def enable_fused_attn_after_hopper(): ...@@ -32,3 +34,54 @@ def enable_fused_attn_after_hopper():
yield yield
if "NVTE_FUSED_ATTN" in os.environ: if "NVTE_FUSED_ATTN" in os.environ:
del os.environ["NVTE_FUSED_ATTN"] del os.environ["NVTE_FUSED_ATTN"]
class TestTimingPlugin:
"""
Plugin to measure test execution time. Enable test timing by setting NVTE_JAX_TEST_TIMING=1
in the environment.
"""
def __init__(self):
self.test_timings = defaultdict(list)
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(self, item):
item._timing_start = time.time()
@pytest.hookimpl(trylast=True)
def pytest_runtest_teardown(self, item, nextitem):
if hasattr(item, "_timing_start"):
duration = time.time() - item._timing_start
# Extract base function name without parameters
test_name = item.name
if "[" in test_name:
base_name = test_name.split("[")[0]
else:
base_name = test_name
self.test_timings[base_name].append(duration)
def pytest_sessionfinish(self, session, exitstatus):
print("\n" + "=" * 80)
print("TEST RUNTIME SUMMARY (grouped by function)")
print("=" * 80)
total_overall = 0
for test_name, durations in sorted(self.test_timings.items()):
total_time = sum(durations)
count = len(durations)
avg_time = total_time / count if count > 0 else 0
total_overall += total_time
print(f"{test_name:<60} | {count:3}x | {total_time:7.2f}s | avg: {avg_time:6.2f}s")
print("=" * 80)
print(f"{'TOTAL RUNTIME':<60} | {'':>3} | {total_overall:7.2f}s |")
print("=" * 80)
def pytest_configure(config):
if os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1":
config.pluginmanager.register(TestTimingPlugin(), "test_timing")
...@@ -39,8 +39,10 @@ def generate_configs(): ...@@ -39,8 +39,10 @@ def generate_configs():
return configs return configs
def generate_context_parallel_configs(): def generate_context_parallel_configs_for_attn():
configs = [] """Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only"""
configsL1 = []
configsL2 = []
mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp") mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
axes = ("dp", "cp", "tp") axes = ("dp", "cp", "tp")
DP_sizes = (1, 2) DP_sizes = (1, 2)
...@@ -49,10 +51,16 @@ def generate_context_parallel_configs(): ...@@ -49,10 +51,16 @@ def generate_context_parallel_configs():
for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
ndev = cp * tp * dp ndev = cp * tp * dp
if is_devices_enough(ndev): if is_devices_enough(ndev):
configs.append( # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations)
if cp != 1:
configsL1.append(
pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
) )
else:
configsL2.append(
pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
)
configs = {"L0": [], "L1": configsL1, "L2": configsL2}
return configs return configs
......
...@@ -78,8 +78,14 @@ def is_shape_supported_by_mxfp8(input_shape): ...@@ -78,8 +78,14 @@ def is_shape_supported_by_mxfp8(input_shape):
return False return False
def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): def assert_bitwise_scaled_tensors(
a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x): if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
if not precise_comparison:
assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
return
assert a.scaling_mode == b.scaling_mode assert a.scaling_mode == b.scaling_mode
assert a.scale_inv.dtype == b.scale_inv.dtype assert a.scale_inv.dtype == b.scale_inv.dtype
if a.scaling_mode.is_tensor_scaling(): if a.scaling_mode.is_tensor_scaling():
...@@ -94,8 +100,12 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): ...@@ -94,8 +100,12 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
assert_allclose(a.data, b.data) assert_allclose(a.data, b.data)
elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x): elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor) assert_bitwise_scaled_tensors(
assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor) a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison
)
assert_bitwise_scaled_tensors(
a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison
)
else: else:
pytest.fail("Unsupported input types") pytest.fail("Unsupported input types")
...@@ -481,24 +491,7 @@ class TestNorm: ...@@ -481,24 +491,7 @@ class TestNorm:
# if the input dtype is not float32 # if the input dtype is not float32
precise_comparison = False precise_comparison = False
if precise_comparison: assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison)
assert_bitwise_scaled_tensors(output, ref_out)
else:
if isinstance(ref_out, ScaledTensor1x):
assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
elif isinstance(ref_out, ScaledTensor2x):
assert_allclose(
output.rowwise_tensor.dequantize(),
ref_out.rowwise_tensor.dequantize(),
dtype=out_dtype,
)
assert_allclose(
output.colwise_tensor.dequantize(),
ref_out.colwise_tensor.dequantize(),
dtype=out_dtype,
)
else:
pytest.fail("Unsupported output type")
assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype) assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
if norm_type == "layernorm": if norm_type == "layernorm":
...@@ -680,10 +673,6 @@ class TestGroupedQuantize: ...@@ -680,10 +673,6 @@ class TestGroupedQuantize:
n_groups=n_groups, n_groups=n_groups,
) )
# grouped_quantize does not work with cudaGraph yet, so the jitting will breaks
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
# disable cudaGraph, then use the following jitted function
scaled_tensor = tex.grouped_quantize( scaled_tensor = tex.grouped_quantize(
x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
) )
...@@ -768,12 +757,24 @@ class TestFusedQuantize: ...@@ -768,12 +757,24 @@ class TestFusedQuantize:
)(dz, x) )(dz, x)
if is_casted_output: if is_casted_output:
assert_bitwise_scaled_tensors(te_output, jax_output) # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation
precise_comparison = not (
in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling()
)
assert_bitwise_scaled_tensors(
te_output, jax_output, precise_comparison=precise_comparison
)
else: else:
assert_allclose(te_output, jax_output) assert_allclose(te_output, jax_output)
if is_dbias: if is_dbias:
assert_allclose(te_dbias, jax_dbias) # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
precise_comparison = not (
in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()
)
assert_allclose(
te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
...@@ -858,15 +859,6 @@ valid_fp8_gemm_operand_types = [ ...@@ -858,15 +859,6 @@ valid_fp8_gemm_operand_types = [
] ]
def _use_jax_fp8_gemm(enabled=False):
import os
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
class TestDense: class TestDense:
def _ref_gemm_with_jnp_dot(self, a, b, data_layout): def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T": if data_layout[0] == "T":
...@@ -1316,16 +1308,14 @@ class TestGroupedDense: ...@@ -1316,16 +1308,14 @@ class TestGroupedDense:
) )
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# grouped_gemm does not work with cudaGraph yet, so the jitting will breaks
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
# disable cudaGraph, then use the following jitted function
# jitting grouped_gemm # jitting grouped_gemm
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
# lhs, rhs, group_sizes, contracting_dims, lhs,
# ) rhs,
group_sizes,
contracting_dims,
)
prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
...@@ -1354,12 +1344,7 @@ class TestGroupedDense: ...@@ -1354,12 +1344,7 @@ class TestGroupedDense:
) )
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# jitting grouped_gemm prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
# lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
# )
prim_out = tex.grouped_gemm(
lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
) )
...@@ -1395,9 +1380,9 @@ class TestGroupedDense: ...@@ -1395,9 +1380,9 @@ class TestGroupedDense:
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
# jitting the grouped_dense # jitting the grouped_dense
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), value_n_grad_prim_func = jit(
# static_argnums=(4,)) value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) )
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
x, kernel, bias, group_sizes, contracting_dims x, kernel, bias, group_sizes, contracting_dims
...@@ -1436,9 +1421,9 @@ class TestGroupedDense: ...@@ -1436,9 +1421,9 @@ class TestGroupedDense:
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
# jitting the grouped_dense # jitting the grouped_dense
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), value_n_grad_prim_func = jit(
# static_argnums=(4,)) value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) )
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
x, x,
......
...@@ -9,10 +9,11 @@ import jax.numpy as jnp ...@@ -9,10 +9,11 @@ import jax.numpy as jnp
from jax import random from jax import random
from distributed_test_base import ( from distributed_test_base import (
generate_configs, generate_configs,
generate_context_parallel_configs, generate_context_parallel_configs_for_attn,
generate_collectives_count, generate_collectives_count,
) )
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from utils import pytest_parametrize_wrapper
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
is_fused_attn_kernel_available, is_fused_attn_kernel_available,
AttnBiasType, AttnBiasType,
...@@ -28,6 +29,12 @@ from transformer_engine.jax.attention import ( ...@@ -28,6 +29,12 @@ from transformer_engine.jax.attention import (
DTYPES = [jnp.bfloat16] DTYPES = [jnp.bfloat16]
DISTRIBUTED_SELF_ATTN_DATA_SHAPES = {
"L0": [()],
"L1": [(32, 1024, 16, 128)],
"L2": [(32, 512, 12, 64)],
}
class TestDistributedSelfAttn: class TestDistributedSelfAttn:
...@@ -64,7 +71,6 @@ class TestDistributedSelfAttn: ...@@ -64,7 +71,6 @@ class TestDistributedSelfAttn:
jax.config.update("jax_use_shardy_partitioner", use_shardy) jax.config.update("jax_use_shardy_partitioner", use_shardy)
dropout_prob = 0.0 dropout_prob = 0.0
is_training = True is_training = True
batch, seqlen, num_head, hidden = data_shape batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available( if not is_fused_attn_kernel_available(
...@@ -119,13 +125,7 @@ class TestDistributedSelfAttn: ...@@ -119,13 +125,7 @@ class TestDistributedSelfAttn:
runner.test_backward() runner.test_backward()
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize( @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SELF_ATTN_DATA_SHAPES)
"data_shape",
[
pytest.param((32, 512, 12, 64), id="32-512-12-64"),
pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attn_bias_type, bias_shape", "attn_bias_type, bias_shape",
[ [
...@@ -193,6 +193,13 @@ class TestDistributedSelfAttn: ...@@ -193,6 +193,13 @@ class TestDistributedSelfAttn:
) )
DISTRIBUTED_CROSS_ATTN_DATA_SHAPES = {
"L0": [()],
"L1": [[32, 512, 16, 64]],
"L2": [[32, 128, 12, 64]],
}
class TestDistributedCrossAttn: class TestDistributedCrossAttn:
def generate_collectives_count_ref(self): def generate_collectives_count_ref(self):
...@@ -201,7 +208,7 @@ class TestDistributedCrossAttn: ...@@ -201,7 +208,7 @@ class TestDistributedCrossAttn:
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]]) @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_CROSS_ATTN_DATA_SHAPES)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK] "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
) )
...@@ -390,8 +397,9 @@ class TestDistributedContextParallelSelfAttn: ...@@ -390,8 +397,9 @@ class TestDistributedContextParallelSelfAttn:
runner.test_backward() runner.test_backward()
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
@pytest.mark.parametrize( @pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() "device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
) )
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
...@@ -426,8 +434,9 @@ class TestDistributedContextParallelSelfAttn: ...@@ -426,8 +434,9 @@ class TestDistributedContextParallelSelfAttn:
use_shardy=True, use_shardy=True,
) )
@pytest.mark.parametrize( @pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() "device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
) )
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8]) @pytest.mark.parametrize("kv_groups", [1, 8])
...@@ -468,8 +477,9 @@ class TestDistributedContextParallelSelfAttn: ...@@ -468,8 +477,9 @@ class TestDistributedContextParallelSelfAttn:
use_shardy=False, use_shardy=False,
) )
@pytest.mark.parametrize( @pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() "device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
) )
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8]) @pytest.mark.parametrize("kv_groups", [1, 8])
...@@ -532,8 +542,9 @@ class TestDistributedContextParallelSelfAttn: ...@@ -532,8 +542,9 @@ class TestDistributedContextParallelSelfAttn:
window_size=window_size, window_size=window_size,
) )
@pytest.mark.parametrize( @pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() "device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
) )
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
...@@ -570,16 +581,16 @@ class TestDistributedContextParallelSelfAttn: ...@@ -570,16 +581,16 @@ class TestDistributedContextParallelSelfAttn:
) )
REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
"L0": [[]],
"L1": [[3, 32, 8, 64]],
"L2": [[4, 32, 12, 32], [1, 16, 1, 1]],
}
class TestReorderCausalLoadBalancing: class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8]) @pytest.mark.parametrize("cp_size", [2, 4, 8])
@pytest.mark.parametrize( @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
"shape",
[
pytest.param([1, 16, 1, 1], id="1-16-1-1"),
pytest.param([4, 32, 12, 32], id="4-32-12-32"),
pytest.param([3, 32, 8, 64], id="3-32-8-64"),
],
)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"reorder_strategy", "reorder_strategy",
......
...@@ -25,6 +25,7 @@ DTYPES = [jnp.bfloat16, jnp.float32] ...@@ -25,6 +25,7 @@ DTYPES = [jnp.bfloat16, jnp.float32]
NORM_INPUT_SHAPES = { NORM_INPUT_SHAPES = {
"L0": [[64, 64]], "L0": [[64, 64]],
"L1": [[64, 64]],
"L2": [[64, 64]], "L2": [[64, 64]],
} }
......
...@@ -333,7 +333,6 @@ class TestDistributedLayernormMLP: ...@@ -333,7 +333,6 @@ class TestDistributedLayernormMLP:
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
ln_mlp_single = LayerNormMLP( ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
use_bias=use_bias, use_bias=use_bias,
...@@ -352,7 +351,6 @@ class TestDistributedLayernormMLP: ...@@ -352,7 +351,6 @@ class TestDistributedLayernormMLP:
): ):
ln_mlp_sharded = LayerNormMLP( ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
scale_axes=LN_SCALE_AXES, scale_axes=LN_SCALE_AXES,
......
...@@ -135,7 +135,7 @@ class TestDistributedSoftmax: ...@@ -135,7 +135,7 @@ class TestDistributedSoftmax:
) )
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]]) @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"softmax_type", "softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED], [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
...@@ -168,14 +168,14 @@ class TestDistributedSoftmax: ...@@ -168,14 +168,14 @@ class TestDistributedSoftmax:
dtype, dtype,
bad_sharding, bad_sharding,
broadcast_batch_mask, broadcast_batch_mask,
use_shardy=False, use_shardy=True,
) )
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED]) @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
@pytest.mark.parametrize("bad_sharding", [False, True]) @pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True]) @pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax_shardy( def test_softmax_gspmd(
self, self,
device_count, device_count,
mesh_shape, mesh_shape,
...@@ -196,5 +196,5 @@ class TestDistributedSoftmax: ...@@ -196,5 +196,5 @@ class TestDistributedSoftmax:
dtype=DTYPES[0], dtype=DTYPES[0],
bad_sharding=bad_sharding, bad_sharding=bad_sharding,
broadcast_batch_mask=broadcast_batch_mask, broadcast_batch_mask=broadcast_batch_mask,
use_shardy=True, use_shardy=False,
) )
...@@ -372,7 +372,7 @@ class FusedAttnRunner: ...@@ -372,7 +372,7 @@ class FusedAttnRunner:
self.head_dim_v, self.head_dim_v,
(-1, -1) if self.window_size is None else self.window_size, (-1, -1) if self.window_size is None else self.window_size,
).get_fused_attn_backend() ).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
pytest.skip("Unsupported inputs combination or device compute capability.") pytest.skip("Unsupported inputs combination or device compute capability.")
if ( if (
......
...@@ -58,7 +58,6 @@ class TestFP8Functions(unittest.TestCase): ...@@ -58,7 +58,6 @@ class TestFP8Functions(unittest.TestCase):
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_current_scaling(self, test): def _compare_current_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING) self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
...@@ -91,7 +90,7 @@ class TestFP8Functions(unittest.TestCase): ...@@ -91,7 +90,7 @@ class TestFP8Functions(unittest.TestCase):
self._check_default_state() self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_current_scaling(self): def test_fp8_autocast_current_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state() self._check_default_state()
...@@ -101,14 +100,14 @@ class TestFP8Functions(unittest.TestCase): ...@@ -101,14 +100,14 @@ class TestFP8Functions(unittest.TestCase):
self._check_default_state() self._check_default_state()
cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3) cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs): with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs) self._compare_current_scaling(cs)
self._check_default_state() self._check_default_state()
cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID) cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs): with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs) self._compare_current_scaling(cs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment