"transformer_engine/pytorch/attention/inference.py" did not exist on "4f33ece48b542ac29b5a483ffafc2245cb6a7334"
Commit 44740c6c authored by yuguo's avatar yuguo
Browse files

Merge commit '7a9a0825' of...

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include <cstdio>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/padding.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
/**
 * @brief Reference (CPU) implementation of multi-tensor unpadding.
 *
 * For each tensor, copies the valid region (height x width) of the row-major
 * padded input into the output, converting element-wise through a float
 * compute type. Rows at index >= height (the padding) are ignored.
 *
 * @param input_list          Per-tensor padded input buffers (padded_height x width).
 * @param output_list         Per-tensor output buffers (height x width), pre-sized by the caller.
 * @param height_list         Valid (unpadded) height of each tensor.
 * @param width_list          Width of each tensor.
 * @param padded_height_list  Padded height of each tensor; retained for interface
 *                            compatibility but not needed for the copy itself.
 */
template <typename InputType, typename OutputType>
void compute_unpadding_ref(const std::vector<std::vector<InputType>>& input_list,
                           std::vector<std::vector<OutputType>>& output_list,
                           const std::vector<size_t>& height_list,
                           const std::vector<size_t>& width_list,
                           [[maybe_unused]] const std::vector<int>& padded_height_list) {
  // Convert through a common compute type so narrow types round consistently.
  using compute_t = float;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto& input = input_list[tensor_id];
    auto& output = output_list[tensor_id];
    const size_t height = height_list[tensor_id];
    const size_t width = width_list[tensor_id];
    // Fix: the original declared `const size_t padded_height = padded_height_list[tensor_id];`
    // here but never read it (unused-variable warning + implicit int -> size_t conversion).
    // Only copy the valid (unpadded) portion; padded rows are dropped.
    for (size_t i = 0; i < height; ++i) {
      for (size_t j = 0; j < width; ++j) {
        const compute_t x = static_cast<compute_t>(input[i * width + j]);
        const OutputType y = static_cast<OutputType>(x);
        output[i * width + j] = y;
      }
    }
  }
}
// End-to-end check of nvte_multi_unpadding: builds a batch of padded tensors,
// runs the TE unpadding kernel on stream 0, and compares against the CPU
// reference implementation above.
template <typename InputType, typename OutputType>
void performUnpaddingTest() {
  using namespace test;
  const DType itype = TypeInfo<InputType>::dtype;
  const DType otype = TypeInfo<OutputType>::dtype;
  // (height, width) pairs covering 1-element, row/column vectors, and sizes
  // that are both multiples of and not multiples of the alignment.
  const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
                                                              {1,768},
                                                              {768,1},
                                                              {768,768},
                                                              {43,43},
                                                              {43,256},
                                                              {256,43},
                                                              {256,256}};
  const size_t num_tensors = tensor_dims.size();
  // Heights are padded up to a multiple of this alignment.
  constexpr int align = 16;
  // Buffers for Transformer Engine implementation
  std::vector<Tensor> padded_input_list, unpadded_output_list;
  // Buffers for reference implementation
  std::vector<std::vector<InputType>> ref_padded_input_list;
  std::vector<std::vector<OutputType>> ref_unpadded_output_list;
  std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
  std::vector<int> ref_padded_height_list(num_tensors);
  // Initialize buffers
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    const size_t original_height = tensor_dims[tensor_id].first;
    const size_t width = tensor_dims[tensor_id].second;
    // Round height up to the next multiple of `align`.
    const size_t padded_height = (original_height + align - 1) / align * align;
    // Input is padded tensor (padded_height x width)
    padded_input_list.emplace_back(
        Tensor("padded_input_" + std::to_string(tensor_id),
               std::vector<size_t>{padded_height, width}, itype));
    // Output is unpadded tensor (original_height x width)
    unpadded_output_list.emplace_back(
        Tensor("unpadded_output_" + std::to_string(tensor_id),
               std::vector<size_t>{original_height, width}, otype));
    auto& padded_input = padded_input_list.back();
    auto& unpadded_output = unpadded_output_list.back();
    // Fill padded input with random data (including padding area)
    fillUniform(&padded_input);
    // NOTE(review): presumably seeds the output tensor's scale factor for
    // quantized dtypes — confirm against test_common.h.
    setRandomScale(&unpadded_output);
    // Initialize reference buffers
    ref_padded_input_list.emplace_back(padded_height * width);
    ref_unpadded_output_list.emplace_back(original_height * width);
    // Copy data to reference buffers
    std::copy(padded_input.rowwise_cpu_dptr<InputType>(),
              padded_input.rowwise_cpu_dptr<InputType>() + padded_height * width,
              ref_padded_input_list.back().begin());
    ref_height_list[tensor_id] = original_height;
    ref_width_list[tensor_id] = width;
    ref_padded_height_list[tensor_id] = padded_height;
  }
  // Transformer Engine implementation
  // Collect the raw NVTETensor handles; the temporaries returned below live
  // until the end of the full nvte_multi_unpadding() call expression.
  auto make_nvte_vector = [](std::vector<Tensor>& tensor_list)
      -> std::vector<NVTETensor> {
    std::vector<NVTETensor> nvte_tensor_list;
    for (auto& tensor : tensor_list) {
      nvte_tensor_list.emplace_back(tensor.data());
    }
    return nvte_tensor_list;
  };
  // Convert height_list to int for the API
  std::vector<int> original_height_list_int(num_tensors);
  for (size_t i = 0; i < num_tensors; ++i) {
    original_height_list_int[i] = static_cast<int>(ref_height_list[i]);
  }
  // Call unpadding API
  nvte_multi_unpadding(num_tensors,
                       make_nvte_vector(padded_input_list).data(),
                       make_nvte_vector(unpadded_output_list).data(),
                       original_height_list_int.data(),
                       0);  // stream 0 (default stream)
  // Synchronize before reading results and surface any async kernel error.
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  // Reference implementation
  compute_unpadding_ref<InputType, OutputType>(ref_padded_input_list,
                                               ref_unpadded_output_list,
                                               ref_height_list,
                                               ref_width_list,
                                               ref_padded_height_list);
  // Check correctness
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    auto [atol, rtol] = getTolerances(otype);
    compareResults("unpadded_output",
                   unpadded_output_list[tensor_id],
                   ref_unpadded_output_list[tensor_id].data(),
                   true,
                   atol, rtol);
  }
}
} // namespace
// Parameterized test suite; the parameter is the tensor element dtype
// (input and output use the same dtype).
class MultiUnpaddingTestSuite
    : public ::testing::TestWithParam<transformer_engine::DType> {};
// Dispatches the runtime DType parameter to the templated test via the
// project's TYPE_SWITCH macro; output dtype mirrors the input dtype.
TEST_P(MultiUnpaddingTestSuite, TestMultiUnpadding) {
  using namespace transformer_engine;
  using namespace test;
  const DType input_type = GetParam();
  const DType output_type = input_type;
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performUnpaddingTest<InputType, OutputType>();
    );
  );
}
// Instantiate the suite over all floating-point dtypes; test names are the
// dtype names produced by test::typeName.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    MultiUnpaddingTestSuite,
    ::testing::ValuesIn(test::all_fp_types),
    [](const testing::TestParamInfo<MultiUnpaddingTestSuite::ParamType>& info) {
      return test::typeName(info.param);
    });
...@@ -34,8 +34,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -34,8 +34,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
return; return;
} }
if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) { if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!"; GTEST_SKIP() << "cuDNN normalizations not supported on pre-Hopper GPUs yet!";
} }
using WeightType = InputType; using WeightType = InputType;
......
...@@ -38,7 +38,7 @@ TEST(UtilTest, ToStringLike) { // to_string_like ...@@ -38,7 +38,7 @@ TEST(UtilTest, ToStringLike) { // to_string_like
// Non-zero integer types // Non-zero integer types
EXPECT_EQ(to_string_like(static_cast<char>(1)), "1"); EXPECT_EQ(to_string_like(static_cast<char>(1)), "1");
EXPECT_EQ(to_string_like(static_cast<char>(-1)), "-1"); EXPECT_EQ(to_string_like(static_cast<signed char>(-1)), "-1");
EXPECT_EQ(to_string_like(static_cast<unsigned char>(2)), "2"); EXPECT_EQ(to_string_like(static_cast<unsigned char>(2)), "2");
EXPECT_EQ(to_string_like(static_cast<short>(3)), "3"); EXPECT_EQ(to_string_like(static_cast<short>(3)), "3");
EXPECT_EQ(to_string_like(static_cast<short>(-5)), "-5"); EXPECT_EQ(to_string_like(static_cast<short>(-5)), "-5");
......
...@@ -13,6 +13,7 @@ import operator ...@@ -13,6 +13,7 @@ import operator
from utils import ( from utils import (
assert_allclose, assert_allclose,
pytest_parametrize_wrapper, pytest_parametrize_wrapper,
use_jax_gemm,
) )
from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp from transformer_engine.jax.layernorm_mlp import layernorm_mlp
...@@ -30,7 +31,6 @@ from transformer_engine.jax.cpp_extensions.quantization import ( ...@@ -30,7 +31,6 @@ from transformer_engine.jax.cpp_extensions.quantization import (
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import ( from transformer_engine.jax.quantize import (
DelayedScaleQuantizer,
ScaledTensor, ScaledTensor,
ScaledTensor1x, ScaledTensor1x,
ScaledTensor2x, ScaledTensor2x,
...@@ -109,8 +109,8 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray): ...@@ -109,8 +109,8 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
else: else:
assert_allclose(a.dequantize(), b, dtype=a.data.dtype) assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
elif isinstance(a, ScaledTensor2x): elif isinstance(a, ScaledTensor2x):
assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b) assert_dequantized_scaled_tensor(a.rowwise_tensor, b)
assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b) assert_dequantized_scaled_tensor(a.colwise_tensor, b)
else: else:
pytest.fail("a must be a ScaledTensor object") pytest.fail("a must be a ScaledTensor object")
...@@ -139,10 +139,10 @@ def assert_dequantized_grouped_scaled_tensor( ...@@ -139,10 +139,10 @@ def assert_dequantized_grouped_scaled_tensor(
dq_a_i = dq_a_i.reshape(b_i.shape) dq_a_i = dq_a_i.reshape(b_i.shape)
assert_allclose(dq_a_i, b_i, dtype=a.data.dtype) assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
elif isinstance(a, ScaledTensor2x): elif isinstance(a, ScaledTensor2x):
assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x) assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x) assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b) assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b) assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
else: else:
pytest.fail("a must be a GroupedScaledTensor object") pytest.fail("a must be a GroupedScaledTensor object")
...@@ -851,6 +851,22 @@ class TestFusedQuantize: ...@@ -851,6 +851,22 @@ class TestFusedQuantize:
) )
valid_fp8_gemm_operand_types = [
(jnp.float8_e4m3fn, jnp.float8_e4m3fn),
(jnp.float8_e5m2, jnp.float8_e4m3fn),
(jnp.float8_e4m3fn, jnp.float8_e5m2),
]
def _use_jax_fp8_gemm(enabled=False):
import os
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
class TestDense: class TestDense:
def _ref_gemm_with_jnp_dot(self, a, b, data_layout): def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T": if data_layout[0] == "T":
...@@ -883,27 +899,47 @@ class TestDense: ...@@ -883,27 +899,47 @@ class TestDense:
def test_gemm_bf16(self, m, n, k, data_layout): def test_gemm_bf16(self, m, n, k, data_layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
primitive_out = tex.gemm(x, w, contracting_dims) primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout): @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
if (
not with_jax_gemm
and scaling_mode.is_1d_block_scaling()
and jnp.float8_e5m2 in (x_qtype, w_qtype)
):
pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
quantizer_set = QuantizerFactory.create_set( quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False scaling_mode=scaling_mode,
) fwd_dtype=jnp.float8_e4m3fn,
primitive_out = tex.gemm( bwd_dtype=jnp.float8_e5m2,
x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set is_2x2x=False,
) )
with use_jax_gemm(enabled=with_jax_gemm):
primitive_out = tex.gemm(
x,
w,
contracting_dims=contracting_dims,
lhs_quantizer=(
quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
),
rhs_quantizer=(
quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
),
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=q_dtype) assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
def test_dense_grad_bf16(self, m, n, k): def test_dense_grad_bf16(self, m, n, k):
...@@ -932,9 +968,9 @@ class TestDense: ...@@ -932,9 +968,9 @@ class TestDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode): @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
data_layout = "NN" data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
...@@ -956,23 +992,27 @@ class TestDense: ...@@ -956,23 +992,27 @@ class TestDense:
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
quantizer_set = QuantizerFactory.create_set( quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
) )
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations): with use_jax_gemm(enabled=with_jax_gemm):
primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( for _ in range(n_iterations):
value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
) value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
)
ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func( ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
x, w, bias, data_layout x, w, bias, data_layout
) )
assert_allclose(primitive_out, ref_out, dtype=q_dtype) assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype) assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype) assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)
@pytest.fixture(name="random_inputs") @pytest.fixture(name="random_inputs")
...@@ -996,20 +1036,13 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan ...@@ -996,20 +1036,13 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense: class TestFusedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
""" """
Test layernorm_dense VJP Rule Test layernorm_dense VJP Rule
""" """
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm # zero_centered_gamma is already tested in TestNorm
zero_centered_gamma = False zero_centered_gamma = False
eps = 1e-6 eps = 1e-6
...@@ -1025,8 +1058,8 @@ class TestFusedDense: ...@@ -1025,8 +1058,8 @@ class TestFusedDense:
quantizer_set = QuantizerFactory.create_set( quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fwd_dtype=q_dtype, fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=q_dtype, bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True, is_2x2x=True,
) )
...@@ -1064,41 +1097,35 @@ class TestFusedDense: ...@@ -1064,41 +1097,35 @@ class TestFusedDense:
) )
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations): with use_jax_gemm(enabled=with_jax_gemm):
prim_out, ( for _ in range(n_iterations):
prim_x_grad, prim_out, (
prim_w_grad, prim_x_grad,
prim_gamma_grad, prim_w_grad,
prim_beta_grad, prim_gamma_grad,
) = value_n_grad_prim_func(x, w, gamma, beta) prim_beta_grad,
) = value_n_grad_prim_func(x, w, gamma, beta)
assert_allclose(prim_out, ref_out, dtype=q_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype) assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
if beta is not None: if beta is not None:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad( def test_layernorm_mlp_grad(
self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
): ):
""" """
Test layernorm_mlp VJP Rule Test layernorm_mlp VJP Rule
""" """
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm # zero_centered_gamma is already tested in TestNorm
zero_centered_gamma = False zero_centered_gamma = False
eps = 1e-6 eps = 1e-6
...@@ -1123,8 +1150,8 @@ class TestFusedDense: ...@@ -1123,8 +1150,8 @@ class TestFusedDense:
quantizer_sets = QuantizerFactory.create_set( quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2, n_quantizer_sets=2,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fwd_dtype=q_dtype, fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=q_dtype, bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True, is_2x2x=True,
) )
...@@ -1153,14 +1180,13 @@ class TestFusedDense: ...@@ -1153,14 +1180,13 @@ class TestFusedDense:
ln_out = _ref_jax_norm_impl( ln_out = _ref_jax_norm_impl(
x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
) )
# TODO: replace gemm with jnp.dot linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
if use_bias: if use_bias:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape) linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = _jax_act_lu(linear_1_out, activation_type) x = _jax_act_lu(linear_1_out, activation_type)
linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,))) linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
if use_bias: if use_bias:
bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
linear_2_out += jnp.reshape(bias_2, bias_2_shape) linear_2_out += jnp.reshape(bias_2, bias_2_shape)
...@@ -1174,15 +1200,16 @@ class TestFusedDense: ...@@ -1174,15 +1200,16 @@ class TestFusedDense:
value_n_grad_ref_func = value_and_grad(ref_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6))
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations): with use_jax_gemm(enabled=with_jax_gemm):
prim_out, ( for _ in range(n_iterations):
prim_x_grad, prim_out, (
prim_gamma_grad, prim_x_grad,
prim_kernel_1_grad, prim_gamma_grad,
prim_kernel_2_grad, prim_kernel_1_grad,
prim_bias_1_grad, prim_kernel_2_grad,
prim_bias_2_grad, prim_bias_1_grad,
) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) prim_bias_2_grad,
) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
ref_out, ( ref_out, (
ref_x_grad, ref_x_grad,
...@@ -1193,18 +1220,18 @@ class TestFusedDense: ...@@ -1193,18 +1220,18 @@ class TestFusedDense:
ref_bias_2_grad, ref_bias_2_grad,
) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
assert_allclose(prim_out, ref_out, dtype=q_dtype) assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype) assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
if use_bias: if use_bias:
assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype) assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype) assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
if use_bias: if use_bias:
assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype) assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
# E5M2 * E5M2 is not supported # E5M2 * E5M2 is not supported
...@@ -1238,7 +1265,9 @@ class TestGroupedDense: ...@@ -1238,7 +1265,9 @@ class TestGroupedDense:
ref_out = [] ref_out = []
dim_num = (contracting_dims, ((), ())) dim_num = (contracting_dims, ((), ()))
for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias): for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0) out_i = jax.lax.dot_general(
lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST
) + jnp.expand_dims(bias_i, axis=0)
ref_out.append(jnp.squeeze(out_i)) ref_out.append(jnp.squeeze(out_i))
return ref_out return ref_out
...@@ -1250,6 +1279,9 @@ class TestGroupedDense: ...@@ -1250,6 +1279,9 @@ class TestGroupedDense:
group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
group_sizes = jnp.diff(group_sizes) group_sizes = jnp.diff(group_sizes)
# Make one empty input lhs to test empty GEMM handling
group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
group_sizes = group_sizes.at[1].set(0)
assert group_sizes.sum() == m assert group_sizes.sum() == m
# *32 to make sure that input shape works for MXFP8 # *32 to make sure that input shape works for MXFP8
...@@ -1301,9 +1333,6 @@ class TestGroupedDense: ...@@ -1301,9 +1333,6 @@ class TestGroupedDense:
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["NN"]) @pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_gemm yet")
fwd_dtype, bwd_dtype = fwd_bwd_dtype fwd_dtype, bwd_dtype = fwd_bwd_dtype
quantizer_set = QuantizerFactory.create_set( quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
...@@ -1343,9 +1372,10 @@ class TestGroupedDense: ...@@ -1343,9 +1372,10 @@ class TestGroupedDense:
def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims): def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims) out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
# Note: we use jnp.sum instead of jnp.mean to make the gradient larger # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
# and prevent them from being clamp to zero # and prevent them from being clamp to zero in FP8. / sqrt(x.size) is used to
# normalize the output and prevent the gradient from being too large for FP8.
out_sum_list = [jnp.sum(out) for out in out_list] out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list)) return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)
def _primitive_sum_grouped_dense( def _primitive_sum_grouped_dense(
self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
...@@ -1353,7 +1383,7 @@ class TestGroupedDense: ...@@ -1353,7 +1383,7 @@ class TestGroupedDense:
out = grouped_dense( out = grouped_dense(
x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
) )
return jnp.sum(jnp.asarray(out)) return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
def test_grouped_dense_grad_fp16(self, dtype, input_shape): def test_grouped_dense_grad_fp16(self, dtype, input_shape):
...@@ -1388,9 +1418,6 @@ class TestGroupedDense: ...@@ -1388,9 +1418,6 @@ class TestGroupedDense:
) )
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_dense yet")
fwd_dtype, bwd_dtype = fwd_bwd_dtype fwd_dtype, bwd_dtype = fwd_bwd_dtype
dtype = jnp.bfloat16 dtype = jnp.bfloat16
x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
......
...@@ -75,8 +75,6 @@ class TestDistributedLayernorm: ...@@ -75,8 +75,6 @@ class TestDistributedLayernorm:
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
) )
other_bytes = 0 other_bytes = 0
if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
other_bytes = 384 # required for small scale shapes that require padding
if fp8_recipe == recipe.Float8CurrentScaling(): if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
return generate_collectives_count( return generate_collectives_count(
......
...@@ -13,6 +13,7 @@ from utils import ( ...@@ -13,6 +13,7 @@ from utils import (
assert_tree_like_allclose, assert_tree_like_allclose,
is_devices_enough, is_devices_enough,
pytest_parametrize_wrapper, pytest_parametrize_wrapper,
use_jax_gemm,
) )
from transformer_engine.common import recipe from transformer_engine.common import recipe
...@@ -33,6 +34,7 @@ from transformer_engine.jax.sharding import ( ...@@ -33,6 +34,7 @@ from transformer_engine.jax.sharding import (
) )
from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory from transformer_engine.jax.quantize import QuantizerFactory
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
...@@ -146,7 +148,15 @@ class TestDistributedLayernormMLP: ...@@ -146,7 +148,15 @@ class TestDistributedLayernormMLP:
) )
def _test_layernorm_mlp_grad( def _test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
use_shardy,
with_jax_gemm,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy) jax.config.update("jax_use_shardy_partitioner", use_shardy)
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
...@@ -156,72 +166,83 @@ class TestDistributedLayernormMLP: ...@@ -156,72 +166,83 @@ class TestDistributedLayernormMLP:
input_shape, activation_type, use_bias, dtype input_shape, activation_type, use_bias, dtype
) )
static_inputs = [layernorm_type, activation_type] static_inputs = [layernorm_type, activation_type]
value_and_grad_func = jax.value_and_grad(
self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
)
# Single GPU with use_jax_gemm(enabled=with_jax_gemm):
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): value_and_grad_func = jax.value_and_grad(
single_jitter = jax.jit( self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
value_and_grad_func,
static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
) )
single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
# Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]
# Position ref for sharding pspec lists
# x, gamma, k1, k2, b1,
# b2
in_shardings = (
None,
None,
k1_sharding,
k2_sharding,
b1_sharding,
None,
)
out_shardings = (
None,
(None, None, k1_sharding, k2_sharding, b1_sharding, None),
)
multi_jitter = jax.jit(
value_and_grad_func,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) # Single GPU
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
assert_allclose(multi_fwd, single_fwd, dtype=dtype) single_jitter = jax.jit(
value_and_grad_func,
static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
)
single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
# Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]
# Position ref for sharding pspec lists
# x, gamma, k1, k2, b1,
# b2
in_shardings = (
None,
None,
k1_sharding,
k2_sharding,
b1_sharding,
None,
)
out_shardings = (
None,
(None, None, k1_sharding, k2_sharding, b1_sharding, None),
)
multi_jitter = jax.jit(
value_and_grad_func,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=range(
len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
),
) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
for i in range(len(inputs)): for i in range(len(inputs)):
if multi_grads[i] is not None: if multi_grads[i] is not None:
if isinstance(multi_grads[i], list): if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list) assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
assert_allclose( assert_allclose(
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close" m_grad,
s_grad,
dtype=bwd_test_type,
err_msg=f"multi_grads[{i}] is not close",
) )
else: else:
assert_allclose( assert_allclose(
multi_grads[i], multi_grads[i],
single_grads[i], single_grads[i],
dtype=dtype, dtype=bwd_test_type,
err_msg=f"multi_grads[{i}] is not close", err_msg=f"multi_grads[{i}] is not close",
) )
...@@ -232,8 +253,16 @@ class TestDistributedLayernormMLP: ...@@ -232,8 +253,16 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad( def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
with_jax_gemm,
): ):
self._test_layernorm_mlp_grad( self._test_layernorm_mlp_grad(
mesh_config, mesh_config,
...@@ -243,6 +272,7 @@ class TestDistributedLayernormMLP: ...@@ -243,6 +272,7 @@ class TestDistributedLayernormMLP:
dtype, dtype,
fp8_recipe, fp8_recipe,
use_shardy=False, use_shardy=False,
with_jax_gemm=with_jax_gemm,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...@@ -251,19 +281,29 @@ class TestDistributedLayernormMLP: ...@@ -251,19 +281,29 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad_shardy( def test_layernorm_mlp_grad_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
with_jax_gemm,
): ):
# We don't test block scaling with Shardy because at the time of writing, if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
# it is not supported in JAX's scaled_matmul_stablehlo. pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
self._test_layernorm_mlp_grad( self._test_layernorm_mlp_grad(
mesh_config, mesh_config,
activation_type, activation_type,
use_bias, use_bias,
input_shape, input_shape,
dtype, dtype,
fp8_recipe=recipe.DelayedScaling(), fp8_recipe=fp8_recipe,
use_shardy=True, use_shardy=True,
with_jax_gemm=with_jax_gemm,
) )
def _test_layernorm_mlp( def _test_layernorm_mlp(
...@@ -276,6 +316,7 @@ class TestDistributedLayernormMLP: ...@@ -276,6 +316,7 @@ class TestDistributedLayernormMLP:
use_fp8, use_fp8,
fp8_recipe, fp8_recipe,
use_shardy, use_shardy,
with_jax_gemm,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy) jax.config.update("jax_use_shardy_partitioner", use_shardy)
batch, seqlen, hidden_in = input_shape batch, seqlen, hidden_in = input_shape
...@@ -287,62 +328,95 @@ class TestDistributedLayernormMLP: ...@@ -287,62 +328,95 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {"params": subkeys[1]} init_rngs = {"params": subkeys[1]}
# Single GPUs with use_jax_gemm(enabled=with_jax_gemm):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): # Single GPUs
ln_mlp_single = LayerNormMLP( with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
layernorm_type=layernorm_type, ln_mlp_single = LayerNormMLP(
transpose_batch_sequence=False, # input: [batch, seqlen, hidden] layernorm_type=layernorm_type,
intermediate_dim=INTERMEDIATE, transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
activations=activation_type, intermediate_dim=INTERMEDIATE,
use_bias=use_bias, activations=activation_type,
) use_bias=use_bias,
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) )
mlp_out_single, ln_out_single = ln_mlp_single.apply( params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
params_single, x, deterministic=True mlp_out_single, ln_out_single = ln_mlp_single.apply(
) params_single, x, deterministic=True
)
# Multi GPUs
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config # Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
mesh = Mesh(devices, mesh_axes) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
with mesh, fp8_autocast( mesh = Mesh(devices, mesh_axes)
enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource with mesh, fp8_autocast(
): enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
ln_mlp_sharded = LayerNormMLP( ):
layernorm_type=layernorm_type, ln_mlp_sharded = LayerNormMLP(
transpose_batch_sequence=False, layernorm_type=layernorm_type,
intermediate_dim=INTERMEDIATE, transpose_batch_sequence=False,
activations=activation_type, intermediate_dim=INTERMEDIATE,
scale_axes=LN_SCALE_AXES, activations=activation_type,
ln_bias_axes=LN_BIAS_AXES, scale_axes=LN_SCALE_AXES,
kernel_axes_1=KERNEL_1_AXES, ln_bias_axes=LN_BIAS_AXES,
kernel_axes_2=KERNEL_2_AXES, kernel_axes_1=KERNEL_1_AXES,
use_bias=use_bias, kernel_axes_2=KERNEL_2_AXES,
bias_axes_1=BIAS_1_AXES, use_bias=use_bias,
bias_axes_2=BIAS_2_AXES, bias_axes_1=BIAS_1_AXES,
layernorm_input_axes=LAYERNORM_INPUT_AXES, bias_axes_2=BIAS_2_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES, layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES, dot_1_input_axes=DOT_1_INPUT_AXES,
name="mlp", dot_2_input_axes=DOT_2_INPUT_AXES,
) name="mlp",
params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) )
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
params_sharded, x, deterministic=True mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
) params_sharded, x, deterministic=True
)
# Make sure params values are the same # Make sure params values are the same
assert_tree_like_allclose(params_sharded["params"], params_single["params"]) assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype) assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
atol = None
rtol = None
l40_tolerance_update = (
get_min_device_compute_capability() == 89
and fp8_recipe == recipe.DelayedScaling()
and use_fp8
and dtype == jnp.float16
and activation_type == ("gelu",)
)
if l40_tolerance_update:
atol = 0.04
rtol = 11
# JAX's FP8 GEMM, jax.lax.dot_general, now uses the
# Triton backend by default. The error of
# the Triton FP8 gemm has been verified to be less than or equal
# to the error of the cuDNN FP8 gemm w.r.t a float32 ground truth.
# However, Triton can auto-tune a different kernel for the single GPU
# and multi-GPU run in this test, meaning the diff between single GPU
# and multi-GPU can be larger in some cases, even though both are
# within tolerance to the float32 ground truth.
jax_triton_gemm_precision_tolerance_update = (
with_jax_gemm
and isinstance(fp8_recipe, recipe.Float8CurrentScaling)
and dtype == jnp.bfloat16
and activation_type == ("gelu", "linear")
)
if jax_triton_gemm_precision_tolerance_update:
atol = 0.08
rtol = 15
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("use_shardy", [False, True]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer( def test_layernorm_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
): ):
self._test_layernorm_mlp( self._test_layernorm_mlp(
mesh_config, mesh_config,
...@@ -352,7 +426,8 @@ class TestDistributedLayernormMLP: ...@@ -352,7 +426,8 @@ class TestDistributedLayernormMLP:
dtype, dtype,
use_fp8=False, use_fp8=False,
fp8_recipe=None, fp8_recipe=None,
use_shardy=use_shardy, use_shardy=False,
with_jax_gemm=with_jax_gemm,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...@@ -362,8 +437,9 @@ class TestDistributedLayernormMLP: ...@@ -362,8 +437,9 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8( def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
): ):
self._test_layernorm_mlp( self._test_layernorm_mlp(
mesh_config, mesh_config,
...@@ -374,4 +450,51 @@ class TestDistributedLayernormMLP: ...@@ -374,4 +450,51 @@ class TestDistributedLayernormMLP:
use_fp8=True, use_fp8=True,
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
use_shardy=False, use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
):
self._test_layernorm_mlp(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
):
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
self._test_layernorm_mlp(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=True,
fp8_recipe=fp8_recipe,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
) )
...@@ -92,7 +92,7 @@ class TestFP8Functions(unittest.TestCase): ...@@ -92,7 +92,7 @@ class TestFP8Functions(unittest.TestCase):
self._check_default_state() self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self): def test_fp8_autocast_current_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state() self._check_default_state()
...@@ -116,7 +116,7 @@ class TestFP8Functions(unittest.TestCase): ...@@ -116,7 +116,7 @@ class TestFP8Functions(unittest.TestCase):
self._check_default_state() self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self): def test_fp8_autocast_mxfp8_block_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state() self._check_default_state()
......
...@@ -3,11 +3,12 @@ ...@@ -3,11 +3,12 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Utility for the TE layer tests""" """Utility for the TE layer tests"""
import os
import functools import functools
import math import math
import operator import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType
import os from contextlib import contextmanager
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -20,7 +21,6 @@ from jax import random as jax_random ...@@ -20,7 +21,6 @@ from jax import random as jax_random
import pytest import pytest
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
AttnMaskType,
canonicalize_attn_mask_type, canonicalize_attn_mask_type,
make_swa_mask, make_swa_mask,
) )
...@@ -28,8 +28,8 @@ from transformer_engine.jax.quantize.helper import DType as TEDType ...@@ -28,8 +28,8 @@ from transformer_engine.jax.quantize.helper import DType as TEDType
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
DType = jnp.dtype DType = NewType("DType", jnp.dtype)
Array = Any Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[ PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
] ]
...@@ -1519,7 +1519,7 @@ def dtype_tols( ...@@ -1519,7 +1519,7 @@ def dtype_tols(
TEDType.kFloat8E5M2: jnp.float8_e5m2, TEDType.kFloat8E5M2: jnp.float8_e5m2,
}[dtype] }[dtype]
elif isinstance(dtype, np.dtype): elif isinstance(dtype, np.dtype):
dtype = jnp.dtype(dtype) dtype = DType(dtype)
# Expect bit-wise accuracy for integer dtypes # Expect bit-wise accuracy for integer dtypes
if not jnp.issubdtype(dtype, jnp.floating): if not jnp.issubdtype(dtype, jnp.floating):
...@@ -1600,3 +1600,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False): ...@@ -1600,3 +1600,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
fmt = fmt + "\n {}\n {}" fmt = fmt + "\n {}\n {}"
jax.debug.print(fmt, *args) jax.debug.print(fmt, *args)
@contextmanager
def use_jax_gemm(enabled=False):
orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
try:
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
yield
finally:
if enabled:
if orig_custom_calls_filter is None:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
else:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
...@@ -16,7 +16,7 @@ import transformer_engine ...@@ -16,7 +16,7 @@ import transformer_engine
import transformer_engine_torch as tex import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from test_numerics import ( from test_numerics import (
_emulate_linear, _emulate_linear,
...@@ -45,6 +45,8 @@ FEATURE_DIRS = None ...@@ -45,6 +45,8 @@ FEATURE_DIRS = None
all_boolean = [True, False] all_boolean = [True, False]
TEST_NR = 0 TEST_NR = 0
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None): def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
if tp_size is None: if tp_size is None:
...@@ -221,7 +223,7 @@ def run_debug_test(func): ...@@ -221,7 +223,7 @@ def run_debug_test(func):
return wrapper return wrapper
CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed: CONFIG_LOG_TEST_DISTRIBUTED_FP8 = """log_distributed:
layers: layers:
layer_types: [linear] layer_types: [linear]
enabled: enabled:
...@@ -241,11 +243,27 @@ CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed: ...@@ -241,11 +243,27 @@ CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
end_step: 1 end_step: 1
""" """
CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8 = """log_distributed:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation, gradient, weight, output, wgrad, dgrad]
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
"""
def _prepare_config_test_log_distributed(config_file): def _prepare_config_test_log_distributed(config_file):
if WORLD_RANK != 0: if WORLD_RANK != 0:
return return
config_file.write(CONFIG_LOG_TEST_DISTRIBUTED) config_file.write(
CONFIG_LOG_TEST_DISTRIBUTED_FP8 if fp8_available else CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8
)
config_file.flush() config_file.flush()
...@@ -361,13 +379,13 @@ def test_log_expert_parallel(**kwargs): ...@@ -361,13 +379,13 @@ def test_log_expert_parallel(**kwargs):
) # data parallel ) # data parallel
model = _init_model(weight, parallel_mode=None, name="linear1") model = _init_model(weight, parallel_mode=None, name="linear1")
model1 = _init_model(weight, parallel_mode=None, name="linear2") model1 = _init_model(weight, parallel_mode=None, name="linear2")
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
y1 = model(x) y1 = model(x)
y2 = model1(x) y2 = model1(x)
y = y1 + y2 y = y1 + y2
y.sum().backward() y.sum().backward()
debug_api.step() debug_api.step()
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
y = model(x) y = model(x)
if WORLD_RANK != 0: if WORLD_RANK != 0:
y = y + model1(x) y = y + model1(x)
...@@ -620,28 +638,29 @@ if __name__ == "__main__": ...@@ -620,28 +638,29 @@ if __name__ == "__main__":
for gather_weight in [True, False]: for gather_weight in [True, False]:
test_log_distributed(parallel_mode, gather_weight) test_log_distributed(parallel_mode, gather_weight)
for parallel_mode in ["row", "column"]: if fp8_available:
test_disable_fp8_layer(parallel_mode) for parallel_mode in ["row", "column"]:
test_disable_fp8_layer(parallel_mode)
# test_disable_fp8_gemms # test_disable_fp8_gemms
_run_test_with_combinations( _run_test_with_combinations(
test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"] test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]
) )
# test_fake_quant_fp8 # test_fake_quant_fp8
dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None] dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None]
_run_test_with_combinations( _run_test_with_combinations(
test_fake_quant_fp8, test_fake_quant_fp8,
dtype_options, dtype_options,
num_repeat=6, num_repeat=6,
extra_args=["column", "row"], extra_args=["column", "row"],
sample_size=20, sample_size=20,
) )
_run_test_with_combinations( _run_test_with_combinations(
test_per_tensor_scaling, test_per_tensor_scaling,
all_boolean, all_boolean,
num_repeat=6, num_repeat=6,
extra_args=["column"], extra_args=["column"],
sample_size=20, sample_size=20,
) )
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import os import os
import subprocess import subprocess
from pathlib import Path from pathlib import Path
import pytest import pytest
import torch import torch
...@@ -21,7 +20,6 @@ import torch ...@@ -21,7 +20,6 @@ import torch
""" """
if torch.cuda.device_count() < 2: if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.") pytest.skip("Distributed training needs at least 2 GPUs.")
...@@ -34,6 +32,6 @@ def test_debug_distributed(feature_dirs): ...@@ -34,6 +32,6 @@ def test_debug_distributed(feature_dirs):
test_path = TEST_ROOT / "run_distributed.py" test_path = TEST_ROOT / "run_distributed.py"
test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"] test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"]
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) result = subprocess.run(test_cmd, env=os.environ, check=False, text=True)
if result.returncode != 0: if result.returncode != 0:
raise AssertionError(result.stderr.decode()) raise AssertionError(f"torchrun exited with {result.returncode}")
...@@ -27,6 +27,9 @@ from transformer_engine.pytorch.module.base import ( ...@@ -27,6 +27,9 @@ from transformer_engine.pytorch.module.base import (
_2X_ACC_FPROP, _2X_ACC_FPROP,
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
all_boolean = [True, False] all_boolean = [True, False]
FP8_FORMAT = Format.HYBRID FP8_FORMAT = Format.HYBRID
...@@ -246,8 +249,8 @@ def _init_model(weight): ...@@ -246,8 +249,8 @@ def _init_model(weight):
return model return model
def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None): def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True):
with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=is_first_microbatch) y = model(x, is_first_microbatch=is_first_microbatch)
(y.sum() * loss_scale).backward() (y.sum() * loss_scale).backward()
debug_api.step() debug_api.step()
...@@ -262,6 +265,18 @@ def _get_tensors(): ...@@ -262,6 +265,18 @@ def _get_tensors():
return x, weight return x, weight
LOGGING_CONFIG = """logging_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation, gradient, weight, output, wgrad, dgrad]
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
"""
DISABLE_FP8_CONFIG = Template( DISABLE_FP8_CONFIG = Template(
"""disable_fp8_config: """disable_fp8_config:
enabled: True enabled: True
...@@ -275,10 +290,30 @@ DISABLE_FP8_CONFIG = Template( ...@@ -275,10 +290,30 @@ DISABLE_FP8_CONFIG = Template(
) )
@create_config_file
def run_logging_zero_numel_tensor(feature_dirs, **kwargs):
kwargs["config_file"].write(LOGGING_CONFIG)
kwargs["config_file"].flush()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
x, weight = _get_tensors()
x1 = x[:0, :]
model = _init_model(weight)
_ = _run_forward_backward(x1, model, fp8=False)
_ = _run_forward_backward(x, model, fp8=False)
def test_logging_zero_numel_tensor(feature_dirs):
run_logging_zero_numel_tensor(feature_dirs)
@pytest.mark.parametrize("fprop_fp8", all_boolean) @pytest.mark.parametrize("fprop_fp8", all_boolean)
@pytest.mark.parametrize("dgrad_fp8", all_boolean) @pytest.mark.parametrize("dgrad_fp8", all_boolean)
@pytest.mark.parametrize("wgrad_fp8", all_boolean) @pytest.mark.parametrize("wgrad_fp8", all_boolean)
def test_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8): def test_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8) run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8)
...@@ -318,6 +353,8 @@ def run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8, **kwarg ...@@ -318,6 +353,8 @@ def run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8, **kwarg
def test_disable_fp8_layer(feature_dirs): def test_disable_fp8_layer(feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
run_disable_fp8_layer(feature_dirs) run_disable_fp8_layer(feature_dirs)
...@@ -363,6 +400,8 @@ subset_combinations = random.sample(all_combinations, 20) ...@@ -363,6 +400,8 @@ subset_combinations = random.sample(all_combinations, 20)
def test_per_tensor_scaling( def test_per_tensor_scaling(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
): ):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]): if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
pytest.skip("Skipping test because all parameters are False") pytest.skip("Skipping test because all parameters are False")
run_per_tensor_scaling( run_per_tensor_scaling(
...@@ -535,6 +574,8 @@ def run_per_tensor_scaling( ...@@ -535,6 +574,8 @@ def run_per_tensor_scaling(
def test_microbatching_per_tensor_scaling( def test_microbatching_per_tensor_scaling(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
): ):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]): if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
pytest.skip("Skipping test because all parameters are False") pytest.skip("Skipping test because all parameters are False")
...@@ -624,6 +665,8 @@ subset_combinations = random.sample(all_combinations, 10) ...@@ -624,6 +665,8 @@ subset_combinations = random.sample(all_combinations, 10)
def test_fake_quant_fp8( def test_fake_quant_fp8(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
): ):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
run_fake_quant_fp8( run_fake_quant_fp8(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
) )
......
...@@ -2,27 +2,17 @@ ...@@ -2,27 +2,17 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import functools
import itertools
import os
import random
import tempfile
from string import Template
import pytest import pytest
import torch import torch
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import _default_sf_compute
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from test_numerics import create_config_file from test_numerics import create_config_file
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
B, S, H, D = 64, 64, 64, 64 B, S, H, D = 64, 64, 64, 64
model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"] model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"]
...@@ -104,4 +94,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir): ...@@ -104,4 +94,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
@pytest.mark.parametrize("fp8", [False, True]) @pytest.mark.parametrize("fp8", [False, True])
@pytest.mark.parametrize("config_key", configs.keys()) @pytest.mark.parametrize("config_key", configs.keys())
def test_sanity_debug(model_key, fp8, config_key, feature_dirs): def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
_run_test(model_key, fp8, configs[config_key], feature_dirs) _run_test(model_key, fp8, configs[config_key], feature_dirs)
...@@ -48,11 +48,6 @@ if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False): ...@@ -48,11 +48,6 @@ if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
) )
# Disable TF32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# Quantization recipe setup # Quantization recipe setup
def quantization_recipe() -> Recipe: def quantization_recipe() -> Recipe:
if QUANTIZATION == "fp8": if QUANTIZATION == "fp8":
...@@ -167,7 +162,7 @@ def _gather(tensor, dim=0): ...@@ -167,7 +162,7 @@ def _gather(tensor, dim=0):
def _constant(tensor): def _constant(tensor):
return nn.init.constant_(tensor, 0.5) return nn.init.constant_(tensor, 0.05)
def dist_print(msg, src=None, end="\n", error=False): def dist_print(msg, src=None, end="\n", error=False):
...@@ -190,7 +185,8 @@ def _get_tolerances(dtype): ...@@ -190,7 +185,8 @@ def _get_tolerances(dtype):
if dtype == torch.bfloat16: if dtype == torch.bfloat16:
return {"rtol": 1.6e-2, "atol": 1e-5} return {"rtol": 1.6e-2, "atol": 1e-5}
if dtype == torch.float32: if dtype == torch.float32:
return {"rtol": 1.2e-4, "atol": 1e-4} # TF32 has same mantissa bits as FP16
return {"rtol": 1e-3, "atol": 1e-5}
raise ValueError(f"Unsupported dtype ({dtype})") raise ValueError(f"Unsupported dtype ({dtype})")
...@@ -521,8 +517,11 @@ def test_linear(): ...@@ -521,8 +517,11 @@ def test_linear():
{"return_bias": True}, {"return_bias": True},
{"params_dtype": torch.float16}, {"params_dtype": torch.float16},
{"delay_wgrad_compute": True}, {"delay_wgrad_compute": True},
{"save_original_input": True},
] ]
for kwargs in kwargs_list: for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
for parallel_mode in ["column", "row"]: for parallel_mode in ["column", "row"]:
for sequence_parallel in [False, True]: for sequence_parallel in [False, True]:
_test_linear(parallel_mode, sequence_parallel, **kwargs) _test_linear(parallel_mode, sequence_parallel, **kwargs)
......
...@@ -28,7 +28,6 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -28,7 +28,6 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
) )
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex import transformer_engine_torch as tex
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
......
...@@ -21,7 +21,6 @@ import transformer_engine.pytorch as te ...@@ -21,7 +21,6 @@ import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import ( from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear, UserbuffersBackwardLinear,
UserbuffersForwardLinear, UserbuffersForwardLinear,
...@@ -32,6 +31,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -32,6 +31,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
) )
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions # Import utility functions
...@@ -370,7 +370,7 @@ def _test_linear( ...@@ -370,7 +370,7 @@ def _test_linear(
if quantized_compute: if quantized_compute:
tols = dtype_tols( tols = dtype_tols(
model[0].weight._fp8_dtype model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight) if isinstance(model[0].weight, Float8Tensor)
else tex.DType.kFloat8E4M3 else tex.DType.kFloat8E4M3
) )
......
...@@ -89,7 +89,7 @@ def run_dpa_with_cp( ...@@ -89,7 +89,7 @@ def run_dpa_with_cp(
# instantiate core attn module # instantiate core attn module
core_attn = DotProductAttention( core_attn = DotProductAttention(
config.num_heads, config.num_heads,
config.head_dim_qk, (config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups, num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p, attention_dropout=config.dropout_p,
qkv_format=qkv_format, qkv_format=qkv_format,
...@@ -106,16 +106,22 @@ def run_dpa_with_cp( ...@@ -106,16 +106,22 @@ def run_dpa_with_cp(
config.num_heads, config.num_heads,
config.head_dim_qk, config.head_dim_qk,
) )
kv_input_shape = ( k_input_shape = (
config.batch_size, config.batch_size,
config.max_seqlen_kv, config.max_seqlen_kv,
config.num_gqa_groups, config.num_gqa_groups,
config.head_dim_qk, config.head_dim_qk,
) )
v_input_shape = (
config.batch_size,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim_v,
)
attn_output_shape = ( attn_output_shape = (
config.batch_size, config.batch_size,
config.max_seqlen_q, config.max_seqlen_q,
config.num_heads * config.head_dim_qk, config.num_heads * config.head_dim_v,
) )
cu_seqlens_q = None cu_seqlens_q = None
cu_seqlens_kv = None cu_seqlens_kv = None
...@@ -128,16 +134,22 @@ def run_dpa_with_cp( ...@@ -128,16 +134,22 @@ def run_dpa_with_cp(
config.num_heads, config.num_heads,
config.head_dim_qk, config.head_dim_qk,
) )
kv_input_shape = ( k_input_shape = (
config.max_seqlen_kv, config.max_seqlen_kv,
config.batch_size, config.batch_size,
config.num_gqa_groups, config.num_gqa_groups,
config.head_dim_qk, config.head_dim_qk,
) )
v_input_shape = (
config.max_seqlen_kv,
config.batch_size,
config.num_gqa_groups,
config.head_dim_v,
)
attn_output_shape = ( attn_output_shape = (
config.max_seqlen_q, config.max_seqlen_q,
config.batch_size, config.batch_size,
config.num_heads * config.head_dim_qk, config.num_heads * config.head_dim_v,
) )
cu_seqlens_q = None cu_seqlens_q = None
cu_seqlens_kv = None cu_seqlens_kv = None
...@@ -149,14 +161,19 @@ def run_dpa_with_cp( ...@@ -149,14 +161,19 @@ def run_dpa_with_cp(
config.num_heads, config.num_heads,
config.head_dim_qk, config.head_dim_qk,
) )
kv_input_shape = ( k_input_shape = (
config.batch_size * config.max_seqlen_q, config.batch_size * config.max_seqlen_q,
config.num_gqa_groups, config.num_gqa_groups,
config.head_dim_qk, config.head_dim_qk,
) )
v_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_gqa_groups,
config.head_dim_v,
)
attn_output_shape = ( attn_output_shape = (
config.batch_size * config.max_seqlen_q, config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim_qk, config.num_heads * config.head_dim_v,
) )
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32) seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2) seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
...@@ -177,8 +194,8 @@ def run_dpa_with_cp( ...@@ -177,8 +194,8 @@ def run_dpa_with_cp(
assert False, f"{qkv_format} is an unsupported qkv_format!" assert False, f"{qkv_format} is an unsupported qkv_format!"
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda() q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
dout_quantizer = Float8Quantizer( dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2, fp8_dtype=tex.DType.kFloat8E5M2,
......
...@@ -174,6 +174,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ...@@ -174,6 +174,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("Only fp8 works with fp8_mha=True!") pytest.skip("Only fp8 works with fp8_mha=True!")
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!") pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
subprocess.run( subprocess.run(
get_bash_arguments( get_bash_arguments(
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
import argparse
import functools
import os
import pathlib
from typing import Optional

import pytest
import torch

import transformer_engine.pytorch as te
from utils import make_recipe
# Check supported quantization schemes; tests that need an unavailable
# scheme skip themselves (see TestLoadCheckpoint.test_module).
fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()

# Test cases for loading checkpoint files.
# A ".fp8"/".mxfp8" suffix selects a quantized variant of the module
# (see TestLoadCheckpoint._make_module).
_TestLoadCheckpoint_name_list: tuple[str, ...] = (
    "linear",
    "layernorm_linear",
    "layernorm_mlp",
    "layernorm",
    "rmsnorm",
    "transformer_layer",
    "ops_linear",
    "linear.fp8",
    "ops_linear.fp8",
    "linear.mxfp8",
    "ops_linear.mxfp8",
)
class TestLoadCheckpoint:
    """Tests for loading checkpoint files

    Tests assume that checkpoint files have already been created. In
    order to regenerate checkpoint files, e.g. after a breaking change
    in the checkpoint format, run this file directly as a Python
    script: `python3 test_checkpoint.py --save-checkpoint all`.
    """

    @staticmethod
    def _make_module(name: str) -> torch.nn.Module:
        """Construct a module for a given test-case name.

        Names with a ".fp8"/".mxfp8" suffix build the module inside
        `te.fp8_model_init` so its parameters are quantized.
        """
        if name == "linear":
            return te.Linear(1, 1)
        if name == "layernorm_linear":
            return te.LayerNormLinear(1, 1)
        if name == "layernorm_mlp":
            return te.LayerNormMLP(1, 1)
        if name == "layernorm":
            return te.LayerNorm(1)
        if name == "rmsnorm":
            return te.RMSNorm(1)
        if name == "transformer_layer":
            return te.TransformerLayer(1, 1, 1)
        if name == "ops_linear":
            return te.ops.Linear(1, 1)
        if name == "linear.fp8":
            with te.fp8_model_init(recipe=make_recipe("fp8")):
                return te.Linear(16, 16)
        if name == "ops_linear.fp8":
            with te.fp8_model_init(recipe=make_recipe("fp8")):
                return te.ops.Linear(16, 16)
        if name == "linear.mxfp8":
            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
                return te.Linear(32, 32)
        if name == "ops_linear.mxfp8":
            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
                return te.ops.Linear(32, 32)
        raise ValueError(f"Unrecognized module name ({name})")

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _checkpoint_dir() -> pathlib.Path:
        """Path to directory with checkpoint files"""

        # Check environment variable
        path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH")
        if path:
            return pathlib.Path(path).resolve()

        # Fallback to path in root dir
        root_dir = pathlib.Path(__file__).resolve().parent.parent.parent
        return root_dir / "artifacts" / "tests" / "pytorch" / "test_checkpoint"

    @staticmethod
    def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) -> None:
        """Save a module's checkpoint file

        Args:
            name: Test-case name (see `_make_module`).
            checkpoint_dir: Directory to write the checkpoint into.
                Defaults to `_checkpoint_dir()`.
        """

        # Path to save checkpoint.
        # Fix: create missing parent directories too — the default
        # checkpoint dir is several levels below the repo root, so a
        # non-recursive mkdir would raise FileNotFoundError.
        if checkpoint_dir is None:
            checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_file = checkpoint_dir / f"{name}.pt"

        # Create module and save checkpoint
        module = TestLoadCheckpoint._make_module(name)
        torch.save(module.state_dict(), checkpoint_file)
        print(f"Saved checkpoint for {name} at {checkpoint_file}")

    @pytest.mark.parametrize("name", _TestLoadCheckpoint_name_list)
    def test_module(self, name: str) -> None:
        """Test for loading a module's checkpoint file"""

        # Skip if quantization is not supported
        quantization = None
        if "." in name:
            quantization = name.split(".")[1]
        if quantization == "fp8" and not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if quantization == "mxfp8" and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)

        # Construct module
        module = self._make_module(name)

        # Load checkpoint from file
        checkpoint_file = self._checkpoint_dir() / f"{name}.pt"
        if not checkpoint_file.is_file():
            raise FileNotFoundError(f"Could not find checkpoint file at {checkpoint_file}")
        state_dict = torch.load(checkpoint_file, weights_only=False)

        # Update module from checkpoint
        module.load_state_dict(state_dict, strict=True)
def main() -> None:
    """Command-line entry point, typically used to generate checkpoint files."""

    # Command-line interface
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-checkpoint",
        type=str,
        default=None,
        help="Save checkpoint file for a module",
    )
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default=None,
        help="Directory to save checkpoint file in",
    )
    cli_args = parser.parse_args()

    # Nothing to do unless checkpoint generation was requested
    if cli_args.save_checkpoint is None:
        return

    # Resolve the output directory if one was given
    out_dir = cli_args.checkpoint_dir
    if out_dir is not None:
        out_dir = pathlib.Path(out_dir).resolve()

    # Generate one checkpoint, or all of them
    if cli_args.save_checkpoint == "all":
        targets = _TestLoadCheckpoint_name_list
    else:
        targets = (cli_args.save_checkpoint,)
    for module_name in targets:
        TestLoadCheckpoint._save_checkpoint(module_name, checkpoint_dir=out_dir)


if __name__ == "__main__":
    main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import math
from typing import Optional, Dict
from transformer_engine.pytorch.router import (
fused_topk_with_score_function,
fused_compute_score_for_moe_aux_loss,
fused_moe_aux_loss,
)
import pytest
from copy import deepcopy
# Fix RNG seeds so randomly generated test data (e.g. the torch.randint
# call in test_fused_moe_aux_loss) is reproducible across runs.
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
# Pytorch-based group topk
def group_limited_topk(
    scores: torch.Tensor,
    topk: int,
    num_tokens: int,
    num_experts: int,
    num_groups: int,
    group_topk: int,
):
    """Reference group-limited top-k routing.

    Ranks expert groups by the sum of their best (topk // group_topk)
    member scores, keeps the `group_topk` best groups, and then takes a
    plain top-k over the experts of the surviving groups only.
    """
    experts_per_group = num_experts // num_groups

    # Score each group by its strongest members
    grouped = scores.view(num_tokens, num_groups, -1)
    best_per_group = grouped.topk(topk // group_topk, dim=-1)[0]
    group_scores = best_per_group.sum(dim=-1)

    # Select the winning groups and expand to a per-expert keep mask
    winning_groups = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
    keep_groups = torch.zeros_like(group_scores)
    keep_groups.scatter_(1, winning_groups, 1)
    keep_experts = (
        keep_groups.unsqueeze(-1)
        .expand(num_tokens, num_groups, experts_per_group)
        .reshape(num_tokens, -1)
    )

    # Suppress experts outside the winning groups, then the final top-k
    masked = scores.masked_fill(~keep_experts.bool(), float("-inf"))
    probs, top_indices = torch.topk(masked, k=topk, dim=-1)

    return probs, top_indices
# Pytorch-based topk softmax/sigmoid
def topk_softmax_sigmoid_pytorch(
    logits: torch.Tensor,
    topk: int,
    use_pre_softmax: bool = False,
    num_groups: Optional[int] = None,
    group_topk: Optional[int] = None,
    scaling_factor: Optional[float] = None,
    score_function: str = "softmax",
    expert_bias: Optional[torch.Tensor] = None,
):
    """Reference top-k router with softmax or sigmoid scoring.

    Returns dense gate values (zeros for unselected experts) and a
    boolean routing map, both shaped like `logits`.
    """
    num_tokens, num_experts = logits.shape

    def _select(scores, k, groups=None, per_group=None):
        # Group-limited selection when requested, plain top-k otherwise
        if per_group:
            return group_limited_topk(
                scores=scores,
                topk=k,
                num_tokens=num_tokens,
                num_experts=num_experts,
                num_groups=groups,
                group_topk=per_group,
            )
        return torch.topk(scores, k=k, dim=1)

    if score_function == "softmax":
        if use_pre_softmax:
            # Softmax over all experts first, then pick the top-k gates
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            probs, top_indices = _select(scores, topk, num_groups, group_topk)
        else:
            # Pick top-k logits first, softmax only over the winners
            scores, top_indices = _select(logits, topk, num_groups, group_topk)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits.float()).type_as(logits)
        if expert_bias is not None:
            # Bias influences routing only; gate values use unbiased scores
            _, top_indices = _select(scores + expert_bias, topk, num_groups, group_topk)
            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
        else:
            scores, top_indices = _select(scores, topk, num_groups, group_topk)
        # Normalize gates across the selected experts (skip for top-1)
        probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    if scaling_factor:
        probs = probs * scaling_factor

    # Scatter sparse top-k results back to a dense [tokens, experts] layout
    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
    topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

    return topk_masked_gates, topk_map
# Pytorch-based compute routing scores for aux loss
def compute_scores_for_aux_loss_pytorch(
    logits: torch.Tensor, topk: int, score_function: str
) -> torch.Tensor:
    """Reference routing map + dense scores used by the aux-loss test.

    Returns a boolean top-k routing map and the dense score matrix.
    """
    if score_function == "softmax":
        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits)
        if topk > 1:
            # Normalize sigmoid scores so each token's row sums to ~1
            scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    top_indices = torch.topk(scores, k=topk, dim=1)[1]
    routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    return routing_map, scores
# Pytorch-based aux loss
def aux_loss_pytorch(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    total_num_tokens: int,
    topk: int,
    num_experts: int,
    moe_aux_loss_coeff: float,
):
    """Reference MoE load-balancing auxiliary loss.

    Dots the per-expert probability mass against the per-expert token
    counts and scales by num_experts * coeff / (topk * tokens^2).
    """
    prob_mass_per_expert = probs.sum(dim=0)
    scale = num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens)
    return torch.sum(prob_mass_per_expert * tokens_per_expert) * scale
def run_comparison(
    dtype,
    num_tokens,
    num_experts,
    topk,
    use_pre_softmax,
    num_groups,
    group_topk,
    scaling_factor,
    score_function,
    enable_bias,
):
    """Compare the fused router kernel against the PyTorch reference.

    Builds deterministic logits, runs both implementations forward and
    backward, and asserts that gate values, routing maps, and input
    gradients match.
    """
    # Deterministic logits; the sigmoid case uses a gentler ramp so the
    # sigmoid function never saturates to inf.
    if score_function == "sigmoid":
        row_shift = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + row_shift.unsqueeze(1)
    else:
        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
        logits = logits.view(num_tokens, num_experts)
    logits.requires_grad = True

    # Optional expert bias (sigmoid scoring only), descending values
    expert_bias = None
    if enable_bias and score_function == "sigmoid":
        expert_bias = torch.flip(torch.arange(num_experts, device="cuda") * 0.1, dims=[0])
        expert_bias.requires_grad = True

    # Independent copies of the inputs for the fused implementation
    logits_clone = deepcopy(logits)
    logits_clone.requires_grad = True
    expert_bias_clone = None
    if expert_bias is not None:
        expert_bias_clone = deepcopy(expert_bias)
        expert_bias_clone.requires_grad = True

    # Reference implementation
    # (the capacity-factor case is not supported here)
    probs, routing_map = topk_softmax_sigmoid_pytorch(
        logits=logits,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
        expert_bias=expert_bias,
    )

    # Fused implementation
    probs_fused, routing_map_fused = fused_topk_with_score_function(
        logits=logits_clone,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
        expert_bias=expert_bias_clone,
    )

    # Forward outputs must match
    torch.testing.assert_close(probs, probs_fused)
    torch.testing.assert_close(routing_map, routing_map_fused)

    # Backward through a dummy scalar loss on each side
    torch.sum(probs).backward()
    torch.sum(probs_fused).backward()

    # Input gradients must match
    torch.testing.assert_close(logits.grad, logits_clone.grad)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("enable_bias", [True, False])
def test_topk_sigmoid(
    dtype,
    num_tokens,
    num_experts,
    topk,
    group_topk,
    scaling_factor,
    enable_bias,
):
    """Fused vs. reference router with sigmoid scoring."""
    # Group-limited routing uses 8 groups whenever grouping is enabled
    kwargs = dict(
        dtype=dtype,
        num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        use_pre_softmax=False,
        num_groups=8 if group_topk else None,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function="sigmoid",
        enable_bias=enable_bias,
    )
    run_comparison(**kwargs)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("use_pre_softmax", [True, False])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
def test_topk_softmax(
    dtype,
    num_tokens,
    num_experts,
    topk,
    use_pre_softmax,
    group_topk,
    scaling_factor,
):
    """Fused vs. reference router with softmax scoring."""
    # Group-limited routing uses 8 groups whenever grouping is enabled
    kwargs = dict(
        dtype=dtype,
        num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=8 if group_topk else None,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function="softmax",
        enable_bias=False,
    )
    run_comparison(**kwargs)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
    """Fused vs. reference aux-loss score computation, forward and backward."""
    # Deterministic logits; gentler ramp for sigmoid to avoid saturation
    if score_function == "sigmoid":
        row_shift = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + row_shift.unsqueeze(1)
    else:
        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
        logits = logits.view(num_tokens, num_experts)
    logits.requires_grad = True

    # Independent copy for the fused implementation
    logits_clone = deepcopy(logits)
    logits_clone.requires_grad = True

    # Reference
    routing_map, scores = compute_scores_for_aux_loss_pytorch(
        logits=logits,
        topk=topk,
        score_function=score_function,
    )
    # Fused
    routing_map_fused, scores_fused = fused_compute_score_for_moe_aux_loss(
        logits=logits_clone,
        topk=topk,
        score_function=score_function,
    )

    # Forward outputs must match
    torch.testing.assert_close(scores, scores_fused)
    torch.testing.assert_close(routing_map, routing_map_fused)

    # Backward through dummy losses; input gradients must match
    torch.sum(scores).backward()
    torch.sum(scores_fused).backward()
    torch.testing.assert_close(logits.grad, logits_clone.grad)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
    """Fused vs. reference MoE auxiliary loss, forward and backward."""
    # Deterministic probs built from a gentle ramp (avoids extreme values)
    row_shift = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
    probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
    probs = probs.unsqueeze(0).repeat(num_tokens, 1) + row_shift.unsqueeze(1)
    probs = probs.view(num_tokens, num_experts)
    probs.requires_grad = True

    # Random (globally seeded) per-expert token counts
    tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32)
    coeff = 0.01

    # Independent copy for the fused implementation
    probs_clone = deepcopy(probs)
    probs_clone.requires_grad = True

    reference = aux_loss_pytorch(
        probs=probs,
        tokens_per_expert=tokens_per_expert,
        total_num_tokens=num_tokens,
        topk=topk,
        num_experts=num_experts,
        moe_aux_loss_coeff=coeff,
    )
    fused = fused_moe_aux_loss(
        probs=probs_clone,
        tokens_per_expert=tokens_per_expert,
        total_num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        coeff=coeff,
    )

    # Forward values must match
    torch.testing.assert_close(reference, fused)

    # Backward: gradients w.r.t. probs must match
    reference.backward()
    fused.backward()
    torch.testing.assert_close(probs.grad, probs_clone.grad)
def profile_topk_softmax(
    dtype,
    num_tokens,
    num_experts,
    topk,
    enable_bias,
    use_pre_softmax,
):
    """Run the sigmoid and softmax router comparisons (e.g. for profiling).

    Fix: the `dtype` argument is now forwarded to the underlying tests
    instead of being silently ignored in favor of a hard-coded
    torch.float32. Callers that passed torch.float32 see no change.
    """
    group_topk = 4
    scaling_factor = 1.2
    test_topk_sigmoid(
        dtype, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias
    )
    test_topk_softmax(
        dtype, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor
    )
if __name__ == "__main__":
    # Quick manual smoke run over a sweep of problem sizes
    test_fused_scores_for_aux_loss(
        dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax"
    )
    for n_tokens in (2048, 7168, 14234):
        for n_experts in (32, 128, 256):
            test_fused_moe_aux_loss(
                dtype=torch.float32, num_tokens=n_tokens, num_experts=n_experts, topk=4
            )
...@@ -20,8 +20,8 @@ import transformer_engine.common.recipe ...@@ -20,8 +20,8 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import ( from transformer_engine.pytorch.ops.fused import (
BackwardBiasActivation,
BackwardLinearAdd, BackwardLinearAdd,
ForwardLinearBiasActivation, ForwardLinearBiasActivation,
ForwardLinearBiasAdd, ForwardLinearBiasAdd,
...@@ -162,7 +162,7 @@ def make_reference_and_test_tensors( ...@@ -162,7 +162,7 @@ def make_reference_and_test_tensors(
return ref, test return ref, test
class TestSequential: class TestSequentialContainer:
"""Tests for sequential container""" """Tests for sequential container"""
def test_modules(self) -> None: def test_modules(self) -> None:
...@@ -1878,6 +1878,98 @@ class TestFusedOps: ...@@ -1878,6 +1878,98 @@ class TestFusedOps:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("activation", ("relu", "gelu"))
@pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_bias_activation(
self,
*,
activation: str,
out_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
) -> None:
"""Backward dbias + dact + quantize"""
# Tensor dimensions
in_shape = list(out_shape)
hidden_size = in_shape[-1]
# Skip invalid configurations
with_quantization = quantization is not None
maybe_skip_quantization(quantization, device=device)
if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
pytest.skip("Unsupported tensor size for MXFP8")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
b_ref, b_test = make_reference_and_test_tensors(
hidden_size,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [hidden_size])
if activation == "gelu":
y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
elif activation == "relu":
y_ref = torch.nn.functional.relu(y_ref)
else:
raise ValueError(f"Unexpected activation function ({activation})")
y_ref.backward(dy_ref)
# Implementation with fusible operations
recipe = make_recipe(quantization)
act_type = te_ops.GELU if activation == "gelu" else te_ops.ReLU
model = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=True),
te_ops.Bias(hidden_size, device=device, dtype=dtype),
act_type(),
)
with torch.no_grad():
model[1].bias.copy_(b_test)
del b_test
with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
assert len(backward_ops) == 2
assert isinstance(backward_ops[0][0], BackwardBiasActivation)
assert isinstance(backward_ops[1][0], te_ops.Quantize)
else:
assert len(backward_ops) == 3
assert isinstance(backward_ops[0][0], act_type)
assert isinstance(backward_ops[1][0], te_ops.Bias)
assert isinstance(backward_ops[2][0], te_ops.Quantize)
# Expected numerical error
tols = dtype_tols(dtype)
if with_quantization:
tols = dtype_tols(tex.DType.kFloat8E4M3)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_add( def test_backward_linear_add(
...@@ -2093,3 +2185,109 @@ class TestCheckpointing: ...@@ -2093,3 +2185,109 @@ class TestCheckpointing:
torch.testing.assert_close(y_load, y_save, **tols) torch.testing.assert_close(y_load, y_save, **tols)
for x_load, x_save in zip(xs_load, xs_save): for x_load, x_save in zip(xs_load, xs_save):
torch.testing.assert_close(x_load.grad, x_save.grad, **tols) torch.testing.assert_close(x_load.grad, x_save.grad, **tols)
class TestSequentialModules:
    """Tests for longer te_ops.Sequential pipelines of commonly combined modules"""

    @staticmethod
    def setup_class(cls) -> None:
        # Seed both CPU and CUDA RNGs so tensor initialization is reproducible
        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)

    @pytest.mark.parametrize("bias", (False, True))
    @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
    @pytest.mark.parametrize("quantized_compute", (False, True))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("quantization", _quantization_list)
    def test_layernorm_mlp(
        self,
        *,
        bias: bool,
        normalization: str,
        quantized_compute: bool,
        quantized_weight: bool,
        dtype: torch.dtype,
        quantization: Optional[str],
        device: torch.device = "cuda",
        hidden_size: int = 32,
        sequence_length: int = 512,
        batch_size: int = 4,
        ffn_hidden_size: int = 64,
        layernorm_epsilon: float = 1e-5,
    ) -> None:
        """
        LayerNorm/RMSNorm + Linear + GELU + Linear

        Smoke test only: it checks that the chained modules run
        end-to-end (forward and backward). Numerical accuracy is not
        validated because errors are hard to attribute once several
        modules are chained together.
        """

        # Shapes of the MLP input and the intermediate FFN activation
        in_shape = (sequence_length, batch_size, hidden_size)
        ffn_shape = in_shape[:-1] + (ffn_hidden_size,)

        # Skip configurations the quantization scheme cannot support
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
        needs_quantization = quantized_compute or quantized_weight
        if needs_quantization and quantization is None:
            pytest.skip("Quantization scheme is not specified")
        if not needs_quantization and quantization is not None:
            pytest.skip("Quantization scheme is not used")

        # Random input and incoming gradient (reference tensors unused here)
        _, x_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
        )
        _, dy_test = make_reference_and_test_tensors(
            in_shape,
            quantization=quantization,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Build norm -> FFN up-projection -> GELU -> FFN down-projection
        recipe = make_recipe(quantization)
        norm_cls = te_ops.LayerNorm if normalization == "LayerNorm" else te_ops.RMSNorm
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            forward = te_ops.Sequential(
                norm_cls(
                    hidden_size,
                    eps=layernorm_epsilon,
                    device=device,
                    dtype=dtype,
                ),
                te_ops.Linear(
                    hidden_size,
                    ffn_hidden_size,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                ),
                te_ops.GELU(),
                te_ops.Linear(
                    ffn_hidden_size,
                    hidden_size,
                    bias=bias,
                    device=device,
                    dtype=dtype,
                ),
            )

        # Forward and backward must complete without raising
        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
            y_test = forward(x_test)
        y_test.backward(dy_test)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment