Commit 44740c6c authored by yuguo

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include <cstdio>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/padding.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename InputType, typename OutputType>
void compute_unpadding_ref(const std::vector<std::vector<InputType>>& input_list,
std::vector<std::vector<OutputType>>& output_list,
const std::vector<size_t>& height_list,
const std::vector<size_t>& width_list,
const std::vector<int>& padded_height_list) {
using compute_t = float;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
const auto& input = input_list[tensor_id];
auto& output = output_list[tensor_id];
const size_t height = height_list[tensor_id];
const size_t width = width_list[tensor_id];
const size_t padded_height = padded_height_list[tensor_id];
(void)padded_height;  // Rows beyond `height` are padding and are intentionally skipped below.
// Only copy the valid (unpadded) portion
for (size_t i = 0; i < height; ++i) {
for (size_t j = 0; j < width; ++j) {
const compute_t x = static_cast<compute_t>(input[i * width + j]);
const OutputType y = static_cast<OutputType>(x);
output[i * width + j] = y;
}
}
}
}
template <typename InputType, typename OutputType>
void performUnpaddingTest() {
using namespace test;
const DType itype = TypeInfo<InputType>::dtype;
const DType otype = TypeInfo<OutputType>::dtype;
const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
{1,768},
{768,1},
{768,768},
{43,43},
{43,256},
{256,43},
{256,256}};
const size_t num_tensors = tensor_dims.size();
constexpr int align = 16;
// Buffers for Transformer Engine implementation
std::vector<Tensor> padded_input_list, unpadded_output_list;
// Buffers for reference implementation
std::vector<std::vector<InputType>> ref_padded_input_list;
std::vector<std::vector<OutputType>> ref_unpadded_output_list;
std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
std::vector<int> ref_padded_height_list(num_tensors);
// Initialize buffers
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
const size_t original_height = tensor_dims[tensor_id].first;
const size_t width = tensor_dims[tensor_id].second;
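// Round the original height up to the next multiple of `align` (16).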
const size_t padded_height = (original_height + align - 1) / align * align;
// Input is padded tensor (padded_height x width)
padded_input_list.emplace_back(
Tensor("padded_input_" + std::to_string(tensor_id),
std::vector<size_t>{padded_height, width}, itype));
// Output is unpadded tensor (original_height x width)
unpadded_output_list.emplace_back(
Tensor("unpadded_output_" + std::to_string(tensor_id),
std::vector<size_t>{original_height, width}, otype));
auto& padded_input = padded_input_list.back();
auto& unpadded_output = unpadded_output_list.back();
// Fill padded input with random data (including padding area)
fillUniform(&padded_input);
setRandomScale(&unpadded_output);
// Initialize reference buffers
ref_padded_input_list.emplace_back(padded_height * width);
ref_unpadded_output_list.emplace_back(original_height * width);
// Copy data to reference buffers
std::copy(padded_input.rowwise_cpu_dptr<InputType>(),
padded_input.rowwise_cpu_dptr<InputType>() + padded_height * width,
ref_padded_input_list.back().begin());
ref_height_list[tensor_id] = original_height;
ref_width_list[tensor_id] = width;
ref_padded_height_list[tensor_id] = padded_height;
}
// Transformer Engine implementation
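// Helper that collects raw NVTETensor handles so the lists can be passed to the C API.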
auto make_nvte_vector = [](std::vector<Tensor>& tensor_list)
-> std::vector<NVTETensor> {
std::vector<NVTETensor> nvte_tensor_list;
for (auto& tensor : tensor_list) {
nvte_tensor_list.emplace_back(tensor.data());
}
return nvte_tensor_list;
};
// Convert height_list to int for the API
std::vector<int> original_height_list_int(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
original_height_list_int[i] = static_cast<int>(ref_height_list[i]);
}
// Call unpadding API
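// The trailing 0 selects the default CUDA stream, as in other nvte_* APIs.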
nvte_multi_unpadding(num_tensors,
make_nvte_vector(padded_input_list).data(),
make_nvte_vector(unpadded_output_list).data(),
original_height_list_int.data(),
0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
// Reference implementation
compute_unpadding_ref<InputType, OutputType>(ref_padded_input_list,
ref_unpadded_output_list,
ref_height_list,
ref_width_list,
ref_padded_height_list);
// Check correctness
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
auto [atol, rtol] = getTolerances(otype);
compareResults("unpadded_output",
unpadded_output_list[tensor_id],
ref_unpadded_output_list[tensor_id].data(),
true,
atol, rtol);
}
}
} // namespace
class MultiUnpaddingTestSuite
: public ::testing::TestWithParam<transformer_engine::DType> {};
TEST_P(MultiUnpaddingTestSuite, TestMultiUnpadding) {
using namespace transformer_engine;
using namespace test;
const DType input_type = GetParam();
const DType output_type = input_type;
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performUnpaddingTest<InputType, OutputType>();
);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MultiUnpaddingTestSuite,
::testing::ValuesIn(test::all_fp_types),
[](const testing::TestParamInfo<MultiUnpaddingTestSuite::ParamType>& info) {
std::string name = test::typeName(info.param);
return name;
});
......@@ -34,8 +34,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
return;
}
if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Hopper GPUs yet!";
}
using WeightType = InputType;
......
......@@ -38,7 +38,7 @@ TEST(UtilTest, ToStringLike) { // to_string_like
// Non-zero integer types
EXPECT_EQ(to_string_like(static_cast<char>(1)), "1");
EXPECT_EQ(to_string_like(static_cast<char>(-1)), "-1");
EXPECT_EQ(to_string_like(static_cast<signed char>(-1)), "-1");
EXPECT_EQ(to_string_like(static_cast<unsigned char>(2)), "2");
EXPECT_EQ(to_string_like(static_cast<short>(3)), "3");
EXPECT_EQ(to_string_like(static_cast<short>(-5)), "-5");
......
......@@ -13,6 +13,7 @@ import operator
from utils import (
assert_allclose,
pytest_parametrize_wrapper,
use_jax_gemm,
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
......@@ -30,7 +31,6 @@ from transformer_engine.jax.cpp_extensions.quantization import (
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
DelayedScaleQuantizer,
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
......@@ -109,8 +109,8 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
else:
assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
elif isinstance(a, ScaledTensor2x):
assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b)
assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b)
assert_dequantized_scaled_tensor(a.rowwise_tensor, b)
assert_dequantized_scaled_tensor(a.colwise_tensor, b)
else:
pytest.fail("a must be a ScaledTensor object")
......@@ -139,10 +139,10 @@ def assert_dequantized_grouped_scaled_tensor(
dq_a_i = dq_a_i.reshape(b_i.shape)
assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
elif isinstance(a, ScaledTensor2x):
assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x)
assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x)
assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b)
assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b)
assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
else:
pytest.fail("a must be a GroupedScaledTensor object")
......@@ -851,6 +851,22 @@ class TestFusedQuantize:
)
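# (E5M2, E5M2) is deliberately absent: FP8 GEMM requires at least one E4M3
# operand ("E5M2 * E5M2 is not supported", as noted later in this file).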
valid_fp8_gemm_operand_types = [
(jnp.float8_e4m3fn, jnp.float8_e4m3fn),
(jnp.float8_e5m2, jnp.float8_e4m3fn),
(jnp.float8_e4m3fn, jnp.float8_e5m2),
]
def _use_jax_fp8_gemm(enabled=False):
import os
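# Same negative-lookahead pattern as use_jax_gemm in utils: keep every TE
# custom call except GemmPrimitive so GEMMs run through native JAX.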
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
class TestDense:
def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T":
......@@ -883,27 +899,47 @@ class TestDense:
def test_gemm_bf16(self, m, n, k, data_layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
primitive_out = tex.gemm(x, w, contracting_dims)
primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
if (
not with_jax_gemm
and scaling_mode.is_1d_block_scaling()
and jnp.float8_e5m2 in (x_qtype, w_qtype)
):
pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2,
is_2x2x=False,
)
with use_jax_gemm(enabled=with_jax_gemm):
primitive_out = tex.gemm(
x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
x,
w,
contracting_dims=contracting_dims,
lhs_quantizer=(
quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
),
rhs_quantizer=(
quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
),
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=q_dtype)
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
def test_dense_grad_bf16(self, m, n, k):
......@@ -932,9 +968,9 @@ class TestDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
......@@ -956,10 +992,14 @@ class TestDense:
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
)
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
......@@ -969,10 +1009,10 @@ class TestDense:
x, w, bias, data_layout
)
assert_allclose(primitive_out, ref_out, dtype=q_dtype)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype)
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)
@pytest.fixture(name="random_inputs")
......@@ -996,20 +1036,13 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
"""
Test layernorm_dense VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
zero_centered_gamma = False
eps = 1e-6
......@@ -1025,8 +1058,8 @@ class TestFusedDense:
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode,
fwd_dtype=q_dtype,
bwd_dtype=q_dtype,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
)
......@@ -1064,6 +1097,7 @@ class TestFusedDense:
)
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
prim_out, (
prim_x_grad,
......@@ -1072,33 +1106,26 @@ class TestFusedDense:
prim_beta_grad,
) = value_n_grad_prim_func(x, w, gamma, beta)
assert_allclose(prim_out, ref_out, dtype=q_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
if beta is not None:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad(
self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias
self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
):
"""
Test layernorm_mlp VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
zero_centered_gamma = False
eps = 1e-6
......@@ -1123,8 +1150,8 @@ class TestFusedDense:
quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2,
scaling_mode=scaling_mode,
fwd_dtype=q_dtype,
bwd_dtype=q_dtype,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
)
......@@ -1153,14 +1180,13 @@ class TestFusedDense:
ln_out = _ref_jax_norm_impl(
x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
)
# TODO: replace gemm with jnp.dot
linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
if use_bias:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = _jax_act_lu(linear_1_out, activation_type)
linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,)))
linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
if use_bias:
bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
linear_2_out += jnp.reshape(bias_2, bias_2_shape)
......@@ -1174,6 +1200,7 @@ class TestFusedDense:
value_n_grad_ref_func = value_and_grad(ref_func, range(6))
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
prim_out, (
prim_x_grad,
......@@ -1193,18 +1220,18 @@ class TestFusedDense:
ref_bias_2_grad,
) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
assert_allclose(prim_out, ref_out, dtype=q_dtype)
assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
if use_bias:
assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype)
assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
if use_bias:
assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype)
assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
# E5M2 * E5M2 is not supported
......@@ -1238,7 +1265,9 @@ class TestGroupedDense:
ref_out = []
dim_num = (contracting_dims, ((), ()))
for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
out_i = jax.lax.dot_general(
lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST
) + jnp.expand_dims(bias_i, axis=0)
ref_out.append(jnp.squeeze(out_i))
return ref_out
......@@ -1250,6 +1279,9 @@ class TestGroupedDense:
group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
group_sizes = jnp.diff(group_sizes)
# Make one empty input lhs to test empty GEMM handling
group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
group_sizes = group_sizes.at[1].set(0)
assert group_sizes.sum() == m
# *32 to make sure that input shape works for MXFP8
......@@ -1301,9 +1333,6 @@ class TestGroupedDense:
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_gemm yet")
fwd_dtype, bwd_dtype = fwd_bwd_dtype
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode,
......@@ -1343,9 +1372,10 @@ class TestGroupedDense:
def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
# Note: we use jnp.sum instead of jnp.mean to make the gradient larger
# and prevent them from being clamped to zero
# and prevent them from being clamped to zero in FP8. / sqrt(x.size) is used to
# normalize the output and prevent the gradient from being too large for FP8.
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)
def _primitive_sum_grouped_dense(
self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
......@@ -1353,7 +1383,7 @@ class TestGroupedDense:
out = grouped_dense(
x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
)
return jnp.sum(jnp.asarray(out))
return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
def test_grouped_dense_grad_fp16(self, dtype, input_shape):
......@@ -1388,9 +1418,6 @@ class TestGroupedDense:
)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_dense yet")
fwd_dtype, bwd_dtype = fwd_bwd_dtype
dtype = jnp.bfloat16
x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
......
......@@ -75,8 +75,6 @@ class TestDistributedLayernorm:
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
other_bytes = 0
if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
other_bytes = 384 # required for small scale shapes that require padding
if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
return generate_collectives_count(
......
......@@ -13,6 +13,7 @@ from utils import (
assert_tree_like_allclose,
is_devices_enough,
pytest_parametrize_wrapper,
use_jax_gemm,
)
from transformer_engine.common import recipe
......@@ -33,6 +34,7 @@ from transformer_engine.jax.sharding import (
)
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability
is_fp8_supported, reason = is_fp8_available()
......@@ -146,7 +148,15 @@ class TestDistributedLayernormMLP:
)
def _test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
use_shardy,
with_jax_gemm,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
......@@ -156,6 +166,8 @@ class TestDistributedLayernormMLP:
input_shape, activation_type, use_bias, dtype
)
static_inputs = [layernorm_type, activation_type]
with use_jax_gemm(enabled=with_jax_gemm):
value_and_grad_func = jax.value_and_grad(
self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
)
......@@ -171,7 +183,9 @@ class TestDistributedLayernormMLP:
# Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
with mesh, fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
......@@ -203,25 +217,32 @@ class TestDistributedLayernormMLP:
value_and_grad_func,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
static_argnums=range(
len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
),
) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
assert_allclose(multi_fwd, single_fwd, dtype=dtype)
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
for i in range(len(inputs)):
if multi_grads[i] is not None:
if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
assert_allclose(
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
m_grad,
s_grad,
dtype=bwd_test_type,
err_msg=f"multi_grads[{i}] is not close",
)
else:
assert_allclose(
multi_grads[i],
single_grads[i],
dtype=dtype,
dtype=bwd_test_type,
err_msg=f"multi_grads[{i}] is not close",
)
......@@ -232,8 +253,16 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
with_jax_gemm,
):
self._test_layernorm_mlp_grad(
mesh_config,
......@@ -243,6 +272,7 @@ class TestDistributedLayernormMLP:
dtype,
fp8_recipe,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -251,19 +281,29 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
with_jax_gemm,
):
# We don't test block scaling with Shardy because at the time of writing,
# it is not supported in JAX's scaled_matmul_stablehlo.
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe=recipe.DelayedScaling(),
fp8_recipe=fp8_recipe,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
def _test_layernorm_mlp(
......@@ -276,6 +316,7 @@ class TestDistributedLayernormMLP:
use_fp8,
fp8_recipe,
use_shardy,
with_jax_gemm,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
batch, seqlen, hidden_in = input_shape
......@@ -287,6 +328,7 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {"params": subkeys[1]}
with use_jax_gemm(enabled=with_jax_gemm):
# Single GPUs
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
ln_mlp_single = LayerNormMLP(
......@@ -333,16 +375,48 @@ class TestDistributedLayernormMLP:
# Make sure params values are the same
assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
atol = None
rtol = None
l40_tolerance_update = (
get_min_device_compute_capability() == 89
and fp8_recipe == recipe.DelayedScaling()
and use_fp8
and dtype == jnp.float16
and activation_type == ("gelu",)
)
if l40_tolerance_update:
atol = 0.04
rtol = 11
# JAX's FP8 GEMM, jax.lax.dot_general, now uses the Triton backend by
# default. The error of the Triton FP8 GEMM has been verified to be less
# than or equal to that of the cuDNN FP8 GEMM w.r.t. a float32 ground
# truth. However, Triton can auto-tune different kernels for the
# single-GPU and multi-GPU runs in this test, so the diff between the two
# can be larger in some cases, even though both are within tolerance of
# the float32 ground truth.
jax_triton_gemm_precision_tolerance_update = (
with_jax_gemm
and isinstance(fp8_recipe, recipe.Float8CurrentScaling)
and dtype == jnp.bfloat16
and activation_type == ("gelu", "linear")
)
if jax_triton_gemm_precision_tolerance_update:
atol = 0.08
rtol = 15
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("use_shardy", [False, True])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy
self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
):
self._test_layernorm_mlp(
mesh_config,
......@@ -352,7 +426,8 @@ class TestDistributedLayernormMLP:
dtype,
use_fp8=False,
fp8_recipe=None,
use_shardy=use_shardy,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -362,8 +437,9 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
):
self._test_layernorm_mlp(
mesh_config,
......@@ -374,4 +450,51 @@ class TestDistributedLayernormMLP:
use_fp8=True,
fp8_recipe=fp8_recipe,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
):
self._test_layernorm_mlp(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
):
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
self._test_layernorm_mlp(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=True,
fp8_recipe=fp8_recipe,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
......@@ -92,7 +92,7 @@ class TestFP8Functions(unittest.TestCase):
self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self):
def test_fp8_autocast_current_scaling(self):
QuantizeConfig.finalize()  # Ensure the test is not affected by previous tests.
self._check_default_state()
......@@ -116,7 +116,7 @@ class TestFP8Functions(unittest.TestCase):
self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self):
def test_fp8_autocast_mxfp8_block_scaling(self):
QuantizeConfig.finalize()  # Ensure the test is not affected by previous tests.
self._check_default_state()
......
......@@ -3,11 +3,12 @@
# See LICENSE for license information.
"""Utility for the TE layer tests"""
import os
import functools
import math
import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
import os
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType
from contextlib import contextmanager
import jax
import jax.numpy as jnp
......@@ -20,7 +21,6 @@ from jax import random as jax_random
import pytest
from transformer_engine.jax.attention import (
AttnMaskType,
canonicalize_attn_mask_type,
make_swa_mask,
)
......@@ -28,8 +28,8 @@ from transformer_engine.jax.quantize.helper import DType as TEDType
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
......@@ -1519,7 +1519,7 @@ def dtype_tols(
TEDType.kFloat8E5M2: jnp.float8_e5m2,
}[dtype]
elif isinstance(dtype, np.dtype):
dtype = jnp.dtype(dtype)
dtype = DType(dtype)
# Expect bit-wise accuracy for integer dtypes
if not jnp.issubdtype(dtype, jnp.floating):
......@@ -1600,3 +1600,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
fmt = fmt + "\n {}\n {}"
jax.debug.print(fmt, *args)
@contextmanager
def use_jax_gemm(enabled=False):
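# NVTE_JAX_CUSTOM_CALLS_RE selects which TE custom calls stay enabled, by name.
# The negative-lookahead pattern matches every name except GemmPrimitive, so
# GEMMs fall back to native JAX while all other TE custom calls remain active.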
orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
try:
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
yield
finally:
if enabled:
if orig_custom_calls_filter is None:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
else:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
......@@ -16,7 +16,7 @@ import transformer_engine
import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from test_numerics import (
_emulate_linear,
......@@ -45,6 +45,8 @@ FEATURE_DIRS = None
all_boolean = [True, False]
TEST_NR = 0
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
if tp_size is None:
......@@ -221,7 +223,7 @@ def run_debug_test(func):
return wrapper
CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
CONFIG_LOG_TEST_DISTRIBUTED_FP8 = """log_distributed:
layers:
layer_types: [linear]
enabled:
......@@ -241,11 +243,27 @@ CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
end_step: 1
"""
CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8 = """log_distributed:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation, gradient, weight, output, wgrad, dgrad]
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
"""
def _prepare_config_test_log_distributed(config_file):
if WORLD_RANK != 0:
return
config_file.write(CONFIG_LOG_TEST_DISTRIBUTED)
config_file.write(
CONFIG_LOG_TEST_DISTRIBUTED_FP8 if fp8_available else CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8
)
config_file.flush()
......@@ -361,13 +379,13 @@ def test_log_expert_parallel(**kwargs):
) # data parallel
model = _init_model(weight, parallel_mode=None, name="linear1")
model1 = _init_model(weight, parallel_mode=None, name="linear2")
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
y1 = model(x)
y2 = model1(x)
y = y1 + y2
y.sum().backward()
debug_api.step()
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
y = model(x)
if WORLD_RANK != 0:
y = y + model1(x)
......@@ -620,6 +638,7 @@ if __name__ == "__main__":
for gather_weight in [True, False]:
test_log_distributed(parallel_mode, gather_weight)
if fp8_available:
for parallel_mode in ["row", "column"]:
test_disable_fp8_layer(parallel_mode)
......
......@@ -5,7 +5,6 @@
import os
import subprocess
from pathlib import Path
import pytest
import torch
......@@ -21,7 +20,6 @@ import torch
"""
if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.")
......@@ -34,6 +32,6 @@ def test_debug_distributed(feature_dirs):
test_path = TEST_ROOT / "run_distributed.py"
test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"]
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
result = subprocess.run(test_cmd, env=os.environ, check=False, text=True)
if result.returncode != 0:
raise AssertionError(result.stderr.decode())
raise AssertionError(f"torchrun exited with {result.returncode}")
......@@ -27,6 +27,9 @@ from transformer_engine.pytorch.module.base import (
_2X_ACC_FPROP,
_2X_ACC_WGRAD,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
all_boolean = [True, False]
FP8_FORMAT = Format.HYBRID
......@@ -246,8 +249,8 @@ def _init_model(weight):
return model
def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None):
with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True):
with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=is_first_microbatch)
(y.sum() * loss_scale).backward()
debug_api.step()
......@@ -262,6 +265,18 @@ def _get_tensors():
return x, weight
LOGGING_CONFIG = """logging_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation, gradient, weight, output, wgrad, dgrad]
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
"""
DISABLE_FP8_CONFIG = Template(
"""disable_fp8_config:
enabled: True
......@@ -275,10 +290,30 @@ DISABLE_FP8_CONFIG = Template(
)
@create_config_file
def run_logging_zero_numel_tensor(feature_dirs, **kwargs):
kwargs["config_file"].write(LOGGING_CONFIG)
kwargs["config_file"].flush()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
x, weight = _get_tensors()
x1 = x[:0, :]
model = _init_model(weight)
_ = _run_forward_backward(x1, model, fp8=False)
_ = _run_forward_backward(x, model, fp8=False)
def test_logging_zero_numel_tensor(feature_dirs):
run_logging_zero_numel_tensor(feature_dirs)
@pytest.mark.parametrize("fprop_fp8", all_boolean)
@pytest.mark.parametrize("dgrad_fp8", all_boolean)
@pytest.mark.parametrize("wgrad_fp8", all_boolean)
def test_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8)
......@@ -318,6 +353,8 @@ def run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8, **kwarg
def test_disable_fp8_layer(feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
run_disable_fp8_layer(feature_dirs)
......@@ -363,6 +400,8 @@ subset_combinations = random.sample(all_combinations, 20)
def test_per_tensor_scaling(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
pytest.skip("Skipping test because all parameters are False")
run_per_tensor_scaling(
......@@ -535,6 +574,8 @@ def run_per_tensor_scaling(
def test_microbatching_per_tensor_scaling(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
pytest.skip("Skipping test because all parameters are False")
......@@ -624,6 +665,8 @@ subset_combinations = random.sample(all_combinations, 10)
def test_fake_quant_fp8(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
run_fake_quant_fp8(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
)
......
......@@ -2,27 +2,17 @@
#
# See LICENSE for license information.
import functools
import itertools
import os
import random
import tempfile
from string import Template
import pytest
import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import _default_sf_compute
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from test_numerics import create_config_file
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
B, S, H, D = 64, 64, 64, 64
model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"]
......@@ -104,4 +94,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
@pytest.mark.parametrize("fp8", [False, True])
@pytest.mark.parametrize("config_key", configs.keys())
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
_run_test(model_key, fp8, configs[config_key], feature_dirs)
......@@ -48,11 +48,6 @@ if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
)
# Disable TF32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# Quantization recipe setup
def quantization_recipe() -> Recipe:
if QUANTIZATION == "fp8":
......@@ -167,7 +162,7 @@ def _gather(tensor, dim=0):
def _constant(tensor):
return nn.init.constant_(tensor, 0.5)
return nn.init.constant_(tensor, 0.05)
def dist_print(msg, src=None, end="\n", error=False):
......@@ -190,7 +185,8 @@ def _get_tolerances(dtype):
if dtype == torch.bfloat16:
return {"rtol": 1.6e-2, "atol": 1e-5}
if dtype == torch.float32:
return {"rtol": 1.2e-4, "atol": 1e-4}
# TF32 has the same number of mantissa bits as FP16
return {"rtol": 1e-3, "atol": 1e-5}
raise ValueError(f"Unsupported dtype ({dtype})")
......@@ -521,8 +517,11 @@ def test_linear():
{"return_bias": True},
{"params_dtype": torch.float16},
{"delay_wgrad_compute": True},
{"save_original_input": True},
]
for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
for parallel_mode in ["column", "row"]:
for sequence_parallel in [False, True]:
_test_linear(parallel_mode, sequence_parallel, **kwargs)
......
......@@ -28,7 +28,6 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
from torch.utils.cpp_extension import IS_HIP_EXTENSION
......
......@@ -21,7 +21,6 @@ import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
......@@ -32,6 +31,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
......@@ -370,7 +370,7 @@ def _test_linear(
if quantized_compute:
tols = dtype_tols(
model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight)
if isinstance(model[0].weight, Float8Tensor)
else tex.DType.kFloat8E4M3
)
......
......@@ -89,7 +89,7 @@ def run_dpa_with_cp(
# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
config.head_dim_qk,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
......@@ -106,16 +106,22 @@ def run_dpa_with_cp(
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
k_input_shape = (
config.batch_size,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim_qk,
)
v_input_shape = (
config.batch_size,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim_v,
)
attn_output_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads * config.head_dim_qk,
config.num_heads * config.head_dim_v,
)
cu_seqlens_q = None
cu_seqlens_kv = None
......@@ -128,16 +134,22 @@ def run_dpa_with_cp(
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
k_input_shape = (
config.max_seqlen_kv,
config.batch_size,
config.num_gqa_groups,
config.head_dim_qk,
)
v_input_shape = (
config.max_seqlen_kv,
config.batch_size,
config.num_gqa_groups,
config.head_dim_v,
)
attn_output_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads * config.head_dim_qk,
config.num_heads * config.head_dim_v,
)
cu_seqlens_q = None
cu_seqlens_kv = None
......@@ -149,14 +161,19 @@ def run_dpa_with_cp(
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
k_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_gqa_groups,
config.head_dim_qk,
)
v_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_gqa_groups,
config.head_dim_v,
)
attn_output_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim_qk,
config.num_heads * config.head_dim_v,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
......@@ -177,8 +194,8 @@ def run_dpa_with_cp(
assert False, f"{qkv_format} is an unsupported qkv_format!"
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
......
......@@ -174,6 +174,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("Only fp8 works with fp8_mha=True!")
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
subprocess.run(
get_bash_arguments(
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
import argparse
import functools
import os
import pathlib
from typing import Optional
import pytest
import torch
import transformer_engine.pytorch as te
from utils import make_recipe
# Check supported quantization schemes
fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()
# Test cases for loading checkpoint files
_TestLoadCheckpoint_name_list: tuple[str, ...] = (
"linear",
"layernorm_linear",
"layernorm_mlp",
"layernorm",
"rmsnorm",
"transformer_layer",
"ops_linear",
"linear.fp8",
"ops_linear.fp8",
"linear.mxfp8",
"ops_linear.mxfp8",
)
class TestLoadCheckpoint:
"""Tests for loading checkpoint files
Tests assume that checkpoint files have already been created. In
order to regenerate checkpoint files, e.g. after a breaking change
in the checkpoint format, run this file directly as a Python
script: `python3 test_checkpoint.py --save-checkpoint all`.
"""
@staticmethod
def _make_module(name: str) -> torch.nn.Module:
"""Construct a module"""
if name == "linear":
return te.Linear(1, 1)
if name == "layernorm_linear":
return te.LayerNormLinear(1, 1)
if name == "layernorm_mlp":
return te.LayerNormMLP(1, 1)
if name == "layernorm":
return te.LayerNorm(1)
if name == "rmsnorm":
return te.RMSNorm(1)
if name == "transformer_layer":
return te.TransformerLayer(1, 1, 1)
if name == "ops_linear":
return te.ops.Linear(1, 1)
if name == "linear.fp8":
with te.fp8_model_init(recipe=make_recipe("fp8")):
return te.Linear(16, 16)
if name == "ops_linear.fp8":
with te.fp8_model_init(recipe=make_recipe("fp8")):
return te.ops.Linear(16, 16)
if name == "linear.mxfp8":
with te.fp8_model_init(recipe=make_recipe("mxfp8")):
return te.Linear(32, 32)
if name == "ops_linear.mxfp8":
with te.fp8_model_init(recipe=make_recipe("mxfp8")):
return te.ops.Linear(32, 32)
raise ValueError(f"Unrecognized module name ({name})")
@staticmethod
@functools.lru_cache(maxsize=None)
def _checkpoint_dir() -> pathlib.Path:
"""Path to directory with checkpoint files"""
# Check environment variable
path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH")
if path:
return pathlib.Path(path).resolve()
# Fallback to path in root dir
root_dir = pathlib.Path(__file__).resolve().parent.parent.parent
return root_dir / "artifacts" / "tests" / "pytorch" / "test_checkpoint"
@staticmethod
def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) -> None:
"""Save a module's checkpoint file"""
# Path to save checkpoint
if checkpoint_dir is None:
checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
checkpoint_dir.mkdir(exist_ok=True)
checkpoint_file = checkpoint_dir / f"{name}.pt"
# Create module and save checkpoint
module = TestLoadCheckpoint._make_module(name)
torch.save(module.state_dict(), checkpoint_file)
print(f"Saved checkpoint for {name} at {checkpoint_file}")
@pytest.mark.parametrize("name", _TestLoadCheckpoint_name_list)
def test_module(self, name: str) -> None:
"""Test for loading a module's checkpoint file"""
# Skip if quantization is not supported
quantization = None
if "." in name:
quantization = name.split(".")[1]
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Construct module
module = self._make_module(name)
# Load checkpoint from file
checkpoint_file = self._checkpoint_dir() / f"{name}.pt"
if not checkpoint_file.is_file():
raise FileNotFoundError(f"Could not find checkpoint file at {checkpoint_file}")
state_dict = torch.load(checkpoint_file, weights_only=False)
# Update module from checkpoint
module.load_state_dict(state_dict, strict=True)
def main() -> None:
"""Main function
Typically used to generate checkpoint files.
"""
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
"--save-checkpoint",
type=str,
default=None,
help="Save checkpoint file for a module",
)
parser.add_argument(
"--checkpoint-dir",
type=str,
default=None,
help="Directory to save checkpoint file in",
)
args = parser.parse_args()
# Save checkpoint files if needed
if args.save_checkpoint is not None:
checkpoint_dir = args.checkpoint_dir
if checkpoint_dir is not None:
checkpoint_dir = pathlib.Path(checkpoint_dir).resolve()
if args.save_checkpoint == "all":
for name in _TestLoadCheckpoint_name_list:
TestLoadCheckpoint._save_checkpoint(name, checkpoint_dir=checkpoint_dir)
else:
TestLoadCheckpoint._save_checkpoint(
args.save_checkpoint,
checkpoint_dir=checkpoint_dir,
)
if __name__ == "__main__":
main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import math
from typing import Optional, Dict, Tuple
from transformer_engine.pytorch.router import (
fused_topk_with_score_function,
fused_compute_score_for_moe_aux_loss,
fused_moe_aux_loss,
)
import pytest
from copy import deepcopy
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# Pytorch-based group topk
def group_limited_topk(
scores: torch.Tensor,
topk: int,
num_tokens: int,
num_experts: int,
num_groups: int,
group_topk: int,
):
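    # Score each group by the sum of its top (topk // group_topk) expert scores,
    # then keep only the `group_topk` highest-scoring groups for routing.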
group_scores = (
scores.view(num_tokens, num_groups, -1).topk(topk // group_topk, dim=-1)[0].sum(dim=-1)
)
group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
# Mask the experts based on selection groups
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_tokens, num_groups, num_experts // num_groups)
.reshape(num_tokens, -1)
)
masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)
return probs, top_indices
# Pytorch-based topk softmax/sigmoid
def topk_softmax_sigmoid_pytorch(
logits: torch.Tensor,
topk: int,
use_pre_softmax: bool = False,
num_groups: Optional[int] = None,
group_topk: Optional[int] = None,
scaling_factor: Optional[float] = None,
score_function: str = "softmax",
expert_bias: Optional[torch.Tensor] = None,
):
num_tokens, num_experts = logits.shape
def compute_topk(scores, topk, num_groups=None, group_topk=None):
if group_topk:
return group_limited_topk(
scores=scores,
topk=topk,
num_tokens=num_tokens,
num_experts=num_experts,
num_groups=num_groups,
group_topk=group_topk,
)
else:
return torch.topk(scores, k=topk, dim=1)
if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
else:
scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits.float()).type_as(logits)
if expert_bias is not None:
scores_for_routing = scores + expert_bias
_, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
else:
scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
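        # Renormalize the selected experts' sigmoid scores per token; the 1e-20
        # epsilon guards against division by zero for an all-zero row.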
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
else:
raise ValueError(f"Invalid score_function: {score_function}")
if scaling_factor:
probs = probs * scaling_factor
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
return topk_masked_gates, topk_map
# Pytorch-based compute routing scores for aux loss
def compute_scores_for_aux_loss_pytorch(
logits: torch.Tensor, topk: int, score_function: str
) -> Tuple[torch.Tensor, torch.Tensor]:
if score_function == "softmax":
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits)
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
else:
raise ValueError(f"Invalid score_function: {score_function}")
_, top_indices = torch.topk(scores, k=topk, dim=1)
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
return routing_map, scores
# Pytorch-based aux loss
def aux_loss_pytorch(
probs: torch.Tensor,
tokens_per_expert: torch.Tensor,
total_num_tokens: int,
topk: int,
num_experts: int,
moe_aux_loss_coeff: float,
):
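    # Load-balancing aux loss: correlate each expert's aggregated routing
    # probability with its token count. The normalization is chosen so that a
    # perfectly uniform routing yields a loss of exactly moe_aux_loss_coeff.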
aggregated_probs_per_expert = probs.sum(dim=0)
aux_loss = torch.sum(aggregated_probs_per_expert * tokens_per_expert) * (
num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens)
)
return aux_loss
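# Worked example (illustrative sketch): with 2 tokens, 2 experts, topk=1,
# uniform probs of 0.5 and tokens_per_expert=[1, 1], the aggregated probs
# are [1.0, 1.0], so aux_loss = 2 * coeff * (1.0 + 1.0) / (1 * 2 * 2) = coeff,
# the perfectly balanced baseline.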
def run_comparison(
dtype,
num_tokens,
num_experts,
topk,
use_pre_softmax,
num_groups,
group_topk,
scaling_factor,
score_function,
enable_bias,
):
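"""Compare the PyTorch reference against the fused implementation.
Runs both on identical logits (and an optional expert bias), checks the
forward gates and routing maps, then backpropagates a sum-of-gates loss
and checks that the input gradients match.
"""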
# Construct the input logits (and optional expert bias)
if score_function == "sigmoid":
# Construct well-separated logits so the sigmoid scores stay finite and distinct
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
if enable_bias and score_function == "sigmoid":
expert_bias = torch.arange(num_experts, device="cuda") * 0.1
expert_bias = torch.flip(expert_bias, dims=[0])
expert_bias.requires_grad = True
else:
expert_bias = None
# Clone the input tensor
logits_clone = deepcopy(logits)
logits_clone.requires_grad = True
if expert_bias is not None:
expert_bias_clone = deepcopy(expert_bias)
expert_bias_clone.requires_grad = True
else:
expert_bias_clone = None
# Run the original implementation
# We do not support the capacity factor case
probs, routing_map = topk_softmax_sigmoid_pytorch(
logits=logits,
topk=topk,
use_pre_softmax=use_pre_softmax,
num_groups=num_groups,
group_topk=group_topk,
scaling_factor=scaling_factor,
score_function=score_function,
expert_bias=expert_bias,
)
# Run the fused implementation
probs_fused, routing_map_fused = fused_topk_with_score_function(
logits=logits_clone,
topk=topk,
use_pre_softmax=use_pre_softmax,
num_groups=num_groups,
group_topk=group_topk,
scaling_factor=scaling_factor,
score_function=score_function,
expert_bias=expert_bias_clone,
)
torch.testing.assert_close(probs, probs_fused)
torch.testing.assert_close(routing_map, routing_map_fused)
# Fake the loss
loss = torch.sum(probs)
loss_fused = torch.sum(probs_fused)
# Backward the loss
loss.backward()
loss_fused.backward()
# Check the gradient
torch.testing.assert_close(logits.grad, logits_clone.grad)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("enable_bias", [True, False])
def test_topk_sigmoid(
dtype,
num_tokens,
num_experts,
topk,
group_topk,
scaling_factor,
enable_bias,
):
num_groups = 8 if group_topk else None
run_comparison(
dtype=dtype,
num_tokens=num_tokens,
num_experts=num_experts,
topk=topk,
use_pre_softmax=False,
num_groups=num_groups,
group_topk=group_topk,
scaling_factor=scaling_factor,
score_function="sigmoid",
enable_bias=enable_bias,
)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("use_pre_softmax", [True, False])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
def test_topk_softmax(
dtype,
num_tokens,
num_experts,
topk,
use_pre_softmax,
group_topk,
scaling_factor,
):
num_groups = 8 if group_topk else None
run_comparison(
dtype=dtype,
num_tokens=num_tokens,
num_experts=num_experts,
topk=topk,
use_pre_softmax=use_pre_softmax,
num_groups=num_groups,
group_topk=group_topk,
scaling_factor=scaling_factor,
score_function="softmax",
enable_bias=False,
)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
if score_function == "sigmoid":
# Construct well-separated logits so the sigmoid scores stay finite and distinct
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
logits_clone = deepcopy(logits)
logits_clone.requires_grad = True
routing_map, scores = compute_scores_for_aux_loss_pytorch(
logits=logits,
topk=topk,
score_function=score_function,
)
routing_map_fused, scores_fused = fused_compute_score_for_moe_aux_loss(
logits=logits_clone,
topk=topk,
score_function=score_function,
)
torch.testing.assert_close(scores, scores_fused)
torch.testing.assert_close(routing_map, routing_map_fused)
loss = torch.sum(scores)
loss.backward()
loss_fused = torch.sum(scores_fused)
loss_fused.backward()
torch.testing.assert_close(logits.grad, logits_clone.grad)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
# Construct evenly spaced, strictly positive probs for a well-conditioned comparison
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
probs = probs.view(num_tokens, num_experts)
probs.requires_grad = True
tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32)
coeff = 0.01
probs_clone = deepcopy(probs)
probs_clone.requires_grad = True
aux_loss = aux_loss_pytorch(
probs=probs,
tokens_per_expert=tokens_per_expert,
total_num_tokens=num_tokens,
topk=topk,
num_experts=num_experts,
moe_aux_loss_coeff=coeff,
)
aux_loss_fused = fused_moe_aux_loss(
probs=probs_clone,
tokens_per_expert=tokens_per_expert,
total_num_tokens=num_tokens,
num_experts=num_experts,
topk=topk,
coeff=coeff,
)
torch.testing.assert_close(aux_loss, aux_loss_fused)
# Backward
aux_loss.backward()
aux_loss_fused.backward()
torch.testing.assert_close(probs.grad, probs_clone.grad)
def profile_topk_softmax(
dtype,
num_tokens,
num_experts,
topk,
enable_bias,
use_pre_softmax,
):
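"""Convenience entry point for profiling: reruns the sigmoid and softmax
comparison tests with fixed group-topk and scaling-factor settings."""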
group_topk = 4
scaling_factor = 1.2
test_topk_sigmoid(
torch.float32, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias
)
test_topk_softmax(
torch.float32, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor
)
if __name__ == "__main__":
test_fused_scores_for_aux_loss(
dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax"
)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
@@ -20,8 +20,8 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import (
BackwardBiasActivation,
BackwardLinearAdd,
ForwardLinearBiasActivation,
ForwardLinearBiasAdd,
@@ -162,7 +162,7 @@ def make_reference_and_test_tensors(
return ref, test
class TestSequential:
class TestSequentialContainer:
"""Tests for sequential container"""
def test_modules(self) -> None:
@@ -1878,6 +1878,98 @@ class TestFusedOps:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("activation", ("relu", "gelu"))
@pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_bias_activation(
self,
*,
activation: str,
out_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
) -> None:
"""Backward dbias + dact + quantize"""
# Tensor dimensions
in_shape = list(out_shape)
hidden_size = in_shape[-1]
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, device=device)
if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
pytest.skip("Unsupported tensor size for MXFP8")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
b_ref, b_test = make_reference_and_test_tensors(
hidden_size,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [hidden_size])
if activation == "gelu":
y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
elif activation == "relu":
y_ref = torch.nn.functional.relu(y_ref)
else:
raise ValueError(f"Unexpected activation function ({activation})")
y_ref.backward(dy_ref)
# Implementation with fusible operations
recipe = make_recipe(quantization)
act_type = te_ops.GELU if activation == "gelu" else te_ops.ReLU
model = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=True),
te_ops.Bias(hidden_size, device=device, dtype=dtype),
act_type(),
)
with torch.no_grad():
model[1].bias.copy_(b_test)
del b_test
with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
# Check that backward operations have been fused
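# With FP8 delayed scaling or MXFP8, dact + dbias are expected to fuse
# into a single BackwardBiasActivation op, followed by a standalone
# Quantize; otherwise the three ops stay separate.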
backward_ops = model._module_groups[0]._backward_ops
if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
assert len(backward_ops) == 2
assert isinstance(backward_ops[0][0], BackwardBiasActivation)
assert isinstance(backward_ops[1][0], te_ops.Quantize)
else:
assert len(backward_ops) == 3
assert isinstance(backward_ops[0][0], act_type)
assert isinstance(backward_ops[1][0], te_ops.Bias)
assert isinstance(backward_ops[2][0], te_ops.Quantize)
# Expected numerical error
tols = dtype_tols(dtype)
if with_quantization:
tols = dtype_tols(tex.DType.kFloat8E4M3)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_add(
@@ -2093,3 +2185,109 @@ class TestCheckpointing:
torch.testing.assert_close(y_load, y_save, **tols)
for x_load, x_save in zip(xs_load, xs_save):
torch.testing.assert_close(x_load.grad, x_save.grad, **tols)
class TestSequentialModules:
"""Test for larger Sequentials with modules commonly used together"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
@pytest.mark.parametrize("quantized_compute", (False, True))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_layernorm_mlp(
self,
*,
bias: bool,
normalization: str,
quantized_compute: bool,
quantized_weight: bool,
dtype: torch.dtype,
quantization: Optional[str],
device: torch.device = "cuda",
hidden_size: int = 32,
sequence_length: int = 512,
batch_size: int = 4,
ffn_hidden_size: int = 64,
layernorm_epsilon: float = 1e-5,
) -> None:
"""
LayerNorm/RMSNorm + Linear + GELU + Linear
Note that this test only checks that the module runs,
since it is hard to validate numerical accuracy when
chaining multiple modules.
"""
# Make input shape
in_shape = (sequence_length, batch_size, hidden_size)
ffn_shape = in_shape[:-1] + (ffn_hidden_size,)
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
quantization_needed = quantized_compute or quantized_weight
if quantization is None and quantization_needed:
pytest.skip("Quantization scheme is not specified")
if quantization is not None and not quantization_needed:
pytest.skip("Quantization scheme is not used")
# Random data
_, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
_, dy_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
if normalization == "LayerNorm":
norm = te_ops.LayerNorm(
hidden_size,
eps=layernorm_epsilon,
device=device,
dtype=dtype,
)
else:
norm = te_ops.RMSNorm(
hidden_size,
eps=layernorm_epsilon,
device=device,
dtype=dtype,
)
ffn1 = te_ops.Linear(
hidden_size,
ffn_hidden_size,
bias=bias,
device=device,
dtype=dtype,
)
act = te_ops.GELU()
ffn2 = te_ops.Linear(
ffn_hidden_size,
hidden_size,
bias=bias,
device=device,
dtype=dtype,
)
forward = te_ops.Sequential(norm, ffn1, act, ffn2)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)