Commit f8c2af4c authored by yuguo's avatar yuguo
Browse files

Merge commit '1d903f5e' of...

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
...@@ -49,16 +49,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -49,16 +49,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
return; return;
} }
Tensor input("input", { N, H }, itype); Tensor input("input", std::vector<size_t>{ N, H }, itype);
Tensor z("z", { N, H }, otype); Tensor z("z", std::vector<size_t>{ N, H }, otype);
Tensor gamma("gamma", { H }, wtype); Tensor gamma("gamma", std::vector<size_t>{ H }, wtype);
Tensor beta("beta", { H }, wtype); Tensor beta("beta", std::vector<size_t>{ H }, wtype);
Tensor mu("mu", { N }, DType::kFloat32); Tensor mu("mu", std::vector<size_t>{ N }, DType::kFloat32);
Tensor rsigma("rsigma", { N }, DType::kFloat32); Tensor rsigma("rsigma", std::vector<size_t>{ N }, DType::kFloat32);
Tensor dz("dz", { N, H }, wtype); Tensor dz("dz", std::vector<size_t>{ N, H }, wtype);
Tensor dx("dx", { N, H }, itype); Tensor dx("dx", std::vector<size_t>{ N, H }, itype);
Tensor dgamma("dgamma", { H }, wtype); Tensor dgamma("dgamma", std::vector<size_t>{ H }, wtype);
Tensor dbeta("dbeta", { H }, wtype); Tensor dbeta("dbeta", std::vector<size_t>{ H }, wtype);
Tensor workspace_fwd, workspace_bwd; Tensor workspace_fwd, workspace_bwd;
fillUniform(&input); fillUniform(&input);
......
...@@ -116,12 +116,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -116,12 +116,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
DType wtype = TypeInfo<WeightType>::dtype; DType wtype = TypeInfo<WeightType>::dtype;
DType otype = TypeInfo<OutputType>::dtype; DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", { N, H }, itype); Tensor input("input", std::vector<size_t>{ N, H }, itype);
Tensor z("z", { N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING); Tensor z("z", std::vector<size_t>{ N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING);
Tensor gamma("gamma", { H }, wtype); Tensor gamma("gamma", std::vector<size_t>{ H }, wtype);
Tensor beta("beta", { H }, wtype); Tensor beta("beta", std::vector<size_t>{ H }, wtype);
Tensor mu("mu", { N }, DType::kFloat32); Tensor mu("mu", std::vector<size_t>{ N }, DType::kFloat32);
Tensor rsigma("rsigma", { N }, DType::kFloat32); Tensor rsigma("rsigma", std::vector<size_t>{ N }, DType::kFloat32);
Tensor workspace; Tensor workspace;
...@@ -164,7 +164,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, ...@@ -164,7 +164,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
nvte_enable_zero_centered_gamma_in_weight_dtype(false); nvte_enable_zero_centered_gamma_in_weight_dtype(false);
} }
Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true); Tensor dequantized_output("dequantized_output", std::vector<size_t>{ N, H }, DType::kFloat32, true, true);
dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training); dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training);
......
...@@ -58,8 +58,8 @@ void performTestQ(const size_t N) { ...@@ -58,8 +58,8 @@ void performTestQ(const size_t N) {
DType itype = TypeInfo<InputType>::dtype; DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype; DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", { N }, itype); Tensor input("input", std::vector<size_t>{ N }, itype);
Tensor output("output", { N }, otype); Tensor output("output", std::vector<size_t>{ N }, otype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N); std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
...@@ -89,8 +89,8 @@ void performTestDQ(const size_t N) { ...@@ -89,8 +89,8 @@ void performTestDQ(const size_t N) {
DType itype = TypeInfo<InputType>::dtype; DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype; DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", { N }, itype); Tensor input("input", std::vector<size_t>{ N }, itype);
Tensor output("output", { N }, otype); Tensor output("output", std::vector<size_t>{ N }, otype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N); std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
......
...@@ -37,8 +37,8 @@ void performTest(const size_t N, const size_t H) { ...@@ -37,8 +37,8 @@ void performTest(const size_t N, const size_t H) {
DType dtype = TypeInfo<Type>::dtype; DType dtype = TypeInfo<Type>::dtype;
Tensor input("input", { N, H }, dtype); Tensor input("input", std::vector<size_t>{ N, H }, dtype);
Tensor output("output", { H, N }, dtype); Tensor output("output", std::vector<size_t>{ H, N }, dtype);
std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H); std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);
......
...@@ -783,8 +783,6 @@ void fillUniform(Tensor *t) { ...@@ -783,8 +783,6 @@ void fillUniform(Tensor *t) {
template<typename InputEncoding, InputsFillCase Case> template<typename InputEncoding, InputsFillCase Case>
void fillCase_special(Tensor *t) { void fillCase_special(Tensor *t) {
const size_t size = product(t->rowwise_shape()); const size_t size = product(t->rowwise_shape());
const size_t rows = t->rowwise_shape().data[0];
const size_t cols = t->rowwise_shape().data[1];
if constexpr (Case == InputsFillCase::zeros) { if constexpr (Case == InputsFillCase::zeros) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
...@@ -804,16 +802,13 @@ void fillCase_special(Tensor *t) { ...@@ -804,16 +802,13 @@ void fillCase_special(Tensor *t) {
std::uniform_real_distribution<> dis_sign(-1.0, 1.0); std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
InputType *data = t->rowwise_cpu_dptr<InputType>(); InputType *data = t->rowwise_cpu_dptr<InputType>();
for (size_t i = 0; i < rows; ++i) { for (size_t idx = 0; idx < size; ++idx) {
for (size_t j = 0; j < cols; ++j) { const bool is_negative = (dis_sign(t->gen()) < 0.0);
const size_t idx = i * cols + j; double val = dis(t->gen());
const bool is_negative = (dis_sign(t->gen()) < 0.0); if (is_negative) {
double val = dis(t->gen()); val = -val;
if (is_negative) {
val = -val;
}
data[idx] = static_cast<InputType>(val);
} }
data[idx] = static_cast<InputType>(val);
} }
}); });
} }
......
...@@ -52,6 +52,7 @@ struct BytesToType<8> { ...@@ -52,6 +52,7 @@ struct BytesToType<8> {
}; };
using byte = uint8_t; using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t; using int32 = int32_t;
using int64 = int64_t; using int64 = int64_t;
using fp32 = float; using fp32 = float;
...@@ -70,6 +71,7 @@ using fp8e8m0 = uint8_t; ...@@ -70,6 +71,7 @@ using fp8e8m0 = uint8_t;
template <typename T> template <typename T>
struct TypeInfo{ struct TypeInfo{
using types = std::tuple<byte, using types = std::tuple<byte,
int16,
int32, int32,
int64, int64,
fp32, fp32,
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
import pytest import pytest
from jax import jit, value_and_grad from jax import jit, value_and_grad
from functools import reduce from functools import reduce
...@@ -18,11 +19,16 @@ from transformer_engine.jax.layernorm import layernorm ...@@ -18,11 +19,16 @@ from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import _jax_layernorm, _jax_rmsnorm from transformer_engine.jax.cpp_extensions.normalization import (
_jax_layernorm,
_jax_rmsnorm,
is_norm_zero_centered_gamma_in_weight_dtype,
)
from transformer_engine.jax.cpp_extensions.quantization import ( from transformer_engine.jax.cpp_extensions.quantization import (
_jax_quantize, _jax_quantize,
_jax_quantize_dbias, _jax_quantize_dbias,
) )
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import ( from transformer_engine.jax.quantize import (
DelayedScaleQuantizer, DelayedScaleQuantizer,
...@@ -33,7 +39,7 @@ from transformer_engine.jax.quantize import ( ...@@ -33,7 +39,7 @@ from transformer_engine.jax.quantize import (
) )
from transformer_engine.jax.quantize import helper from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.dense import dense
from transformer_engine.jax.layernorm_dense import layernorm_dense from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
...@@ -54,6 +60,7 @@ supported_scaling_modes = [] ...@@ -54,6 +60,7 @@ supported_scaling_modes = []
""" Find supported scaling modes""" """ Find supported scaling modes"""
if is_fp8_supported: if is_fp8_supported:
supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
if is_mxfp8_supported: if is_mxfp8_supported:
supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)
...@@ -71,8 +78,19 @@ def is_shape_supported_by_mxfp8(input_shape): ...@@ -71,8 +78,19 @@ def is_shape_supported_by_mxfp8(input_shape):
def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x): if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
assert a.scaling_mode == b.scaling_mode
assert a.scale_inv.dtype == b.scale_inv.dtype
if a.scaling_mode.is_tensor_scaling():
# Assert in dq_dtype as some unfused codepaths have an intermediate cast
# to an input dtype which reduces precision compared to everything in fp32
assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# Compare MXFP8 scales as uint8
assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
else:
raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
assert_allclose(a.data, b.data) assert_allclose(a.data, b.data)
assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x): elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor) assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor) assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
...@@ -159,7 +177,12 @@ class TestActivation: ...@@ -159,7 +177,12 @@ class TestActivation:
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type): @pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_act_grad_with_tensor_scaling_fp8(
self, random_inputs, activation_type, output_type, scaling_mode
):
x = random_inputs x = random_inputs
x = jnp.expand_dims(x, axis=-2) x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2) x = jnp.repeat(x, len(activation_type), axis=-2)
...@@ -170,7 +193,7 @@ class TestActivation: ...@@ -170,7 +193,7 @@ class TestActivation:
) )
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, scaling_mode=scaling_mode,
q_dtype=output_type, q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE, q_layout=QuantizeLayout.ROWWISE,
) )
...@@ -188,8 +211,11 @@ class TestActivation: ...@@ -188,8 +211,11 @@ class TestActivation:
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
) )
def test_act_forward_with_delayed_scaling_fp8( @pytest_parametrize_wrapper(
self, random_inputs, activation_type, output_type, q_layout "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_act_forward_with_tensor_scaling_fp8(
self, random_inputs, activation_type, output_type, q_layout, scaling_mode
): ):
x = random_inputs x = random_inputs
x = jnp.expand_dims(x, axis=-2) x = jnp.expand_dims(x, axis=-2)
...@@ -198,7 +224,7 @@ class TestActivation: ...@@ -198,7 +224,7 @@ class TestActivation:
te_quantizer, jax_quantizer = QuantizerFactory.create( te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2, n_quantizers=2,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, scaling_mode=scaling_mode,
q_dtype=output_type, q_dtype=output_type,
q_layout=q_layout, q_layout=q_layout,
) )
...@@ -335,8 +361,20 @@ class TestNorm: ...@@ -335,8 +361,20 @@ class TestNorm:
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
) )
def test_norm_grad_with_delayed_scaling_fp8( @pytest_parametrize_wrapper(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_norm_grad_with_tensor_scaling_fp8(
self,
n,
hidden,
norm_type,
zero_centered_gamma,
epsilon,
inp_dtype,
out_dtype,
q_layout,
scaling_mode,
): ):
""" """
Test transformer_engine.jax.layernorm.layernorm Test transformer_engine.jax.layernorm.layernorm
...@@ -345,9 +383,7 @@ class TestNorm: ...@@ -345,9 +383,7 @@ class TestNorm:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!") pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
q_dtype=out_dtype,
q_layout=q_layout,
) )
self._test_norm_grad( self._test_norm_grad(
n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
...@@ -395,7 +431,41 @@ class TestNorm: ...@@ -395,7 +431,41 @@ class TestNorm:
) )
ref_mu = None ref_mu = None
assert_bitwise_scaled_tensors(output, ref_out) precise_comparison = True
if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# Reduce precision of test as we don't use fused norm below this version CuDNN for MXFP8 and instead
# do an unfused norm and quantize with an intermediate cast into in_dtype which can reduce precision
precise_comparison = False
elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
# Larger tolerances as our JAX implementation _jax_*norm uses the compute dtype float32
# for zero-centered gamma always
precise_comparison = False
elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
# Current implementation of Current Tensor Scaling performs unfused layernorm and quantization
# and writes intermediate results into the input dtype, which will slightly reduce precision
# if the input dtype is not float32
precise_comparison = False
if precise_comparison:
assert_bitwise_scaled_tensors(output, ref_out)
else:
if isinstance(ref_out, ScaledTensor1x):
assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
elif isinstance(ref_out, ScaledTensor2x):
assert_allclose(
output.rowwise_tensor.dequantize(),
ref_out.rowwise_tensor.dequantize(),
dtype=out_dtype,
)
assert_allclose(
output.colwise_tensor.dequantize(),
ref_out.colwise_tensor.dequantize(),
dtype=out_dtype,
)
else:
pytest.fail("Unsupported output type")
assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype) assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
if norm_type == "layernorm": if norm_type == "layernorm":
assert_allclose(mu, ref_mu, dtype=inp_dtype) assert_allclose(mu, ref_mu, dtype=inp_dtype)
...@@ -406,8 +476,20 @@ class TestNorm: ...@@ -406,8 +476,20 @@ class TestNorm:
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
) )
def test_norm_forward_with_delayed_scaling_fp8( @pytest_parametrize_wrapper(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_norm_forward_with_tensor_scaling_fp8(
self,
n,
hidden,
norm_type,
zero_centered_gamma,
epsilon,
inp_dtype,
out_dtype,
q_layout,
scaling_mode,
): ):
if norm_type == "rmsnorm" and zero_centered_gamma is True: if norm_type == "rmsnorm" and zero_centered_gamma is True:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!") pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
...@@ -420,7 +502,7 @@ class TestNorm: ...@@ -420,7 +502,7 @@ class TestNorm:
epsilon=epsilon, epsilon=epsilon,
inp_dtype=inp_dtype, inp_dtype=inp_dtype,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, scaling_mode=scaling_mode,
q_layout=q_layout, q_layout=q_layout,
) )
...@@ -447,17 +529,24 @@ QUANTIZE_OUTPUT_DTYPES = { ...@@ -447,17 +529,24 @@ QUANTIZE_OUTPUT_DTYPES = {
"L2": [jnp.float8_e4m3fn, jnp.float8_e5m2], "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
} }
ALL_QUANTIZE_TEST_SHAPES = [ ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
(32, 64), ((32, 64), -1),
(2, 64, 32), ((2, 64, 32), -1),
((2, 64, 32), -2),
((32, 256, 128), -1),
((32, 256, 128), -2),
((64, 32, 32, 256), -1),
((64, 32, 32, 256), -2),
((64, 32, 32, 256), -3),
] ]
QUANTIZE_TEST_SHAPES = { QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
"L0": [ "L0": [
(32, 256, 128), ((32, 64), -1),
(64, 32, 32, 256), ((2, 64, 32), -1),
((2, 64, 32), -2),
], ],
"L2": ALL_QUANTIZE_TEST_SHAPES, "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
} }
QUANTIZATION_INPUT_DTYPE = { QUANTIZATION_INPUT_DTYPE = {
...@@ -469,9 +558,8 @@ QUANTIZATION_INPUT_DTYPE = { ...@@ -469,9 +558,8 @@ QUANTIZATION_INPUT_DTYPE = {
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES) @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
) )
...@@ -524,12 +612,11 @@ class TestFusedQuantize: ...@@ -524,12 +612,11 @@ class TestFusedQuantize:
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES) @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
) )
@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
def test_quantize_dbias( def test_quantize_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
): ):
...@@ -538,6 +625,12 @@ class TestFusedQuantize: ...@@ -538,6 +625,12 @@ class TestFusedQuantize:
): ):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
if (flatten_axis < 0 and flatten_axis + len(input_shape) <= 0) or flatten_axis <= 0:
pytest.skip(
f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There"
" must be at least one axis on either side of the flatten_axis split."
)
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype) input = jax.random.uniform(key, input_shape, in_dtype)
...@@ -630,16 +723,19 @@ class TestFusedQuantize: ...@@ -630,16 +723,19 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
) )
def test_quantize_dact_dbias_delayed_scaling( @pytest_parametrize_wrapper(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_quantize_dact_dbias_tensor_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
): ):
self._test_quantize_dact_dbias( self._test_quantize_dact_dbias(
in_dtype=in_dtype, in_dtype=in_dtype,
input_shape=input_shape, input_shape=input_shape,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, scaling_mode=scaling_mode,
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
q_layout=q_layout, q_layout=q_layout,
...@@ -830,7 +926,10 @@ class TestFusedDense: ...@@ -830,7 +926,10 @@ class TestFusedDense:
Test layernorm_dense VJP Rule Test layernorm_dense VJP Rule
""" """
# No Norm FWD E5M2 in TE backend # No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!") pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm # zero_centered_gamma is already tested in TestNorm
...@@ -916,7 +1015,10 @@ class TestFusedDense: ...@@ -916,7 +1015,10 @@ class TestFusedDense:
Test layernorm_mlp VJP Rule Test layernorm_mlp VJP Rule
""" """
# No Norm FWD E5M2 in TE backend # No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!") pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm # zero_centered_gamma is already tested in TestNorm
...@@ -1052,7 +1154,7 @@ fwd_bwd_dtypes = [ ...@@ -1052,7 +1154,7 @@ fwd_bwd_dtypes = [
[jnp.float8_e5m2, jnp.float8_e4m3fn], [jnp.float8_e5m2, jnp.float8_e4m3fn],
] ]
"""
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
"shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]] "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
) )
...@@ -1267,3 +1369,4 @@ class TestGroupedDense: ...@@ -1267,3 +1369,4 @@ class TestGroupedDense:
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype) assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype) assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype) assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
"""
...@@ -34,6 +34,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) ...@@ -34,6 +34,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = [] SUPPORTED_RECIPES = []
if is_fp8_supported: if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported: if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
...@@ -76,6 +77,8 @@ class TestDistributedLayernorm: ...@@ -76,6 +77,8 @@ class TestDistributedLayernorm:
other_bytes = 0 other_bytes = 0
if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes: if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
other_bytes = 384 # required for small scale shapes that require padding other_bytes = 384 # required for small scale shapes that require padding
if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
return generate_collectives_count( return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
) )
......
...@@ -41,6 +41,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) ...@@ -41,6 +41,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = [] SUPPORTED_RECIPES = []
if is_fp8_supported: if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported: if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
...@@ -217,37 +218,10 @@ class TestDistributedLayernormMLP: ...@@ -217,37 +218,10 @@ class TestDistributedLayernormMLP:
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close" m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
) )
else: else:
is_gated = len(activation_type) > 1
rtol = None
atol = None
if is_gated:
if dtype == jnp.bfloat16:
if i == 2:
rtol = 800
atol = 9e-2
if i == 4:
atol = 300
rtol = 1e-1
if dtype == jnp.float16:
if i == 1: # gamma
rtol = 200
atol = 1e-2
if i == 2:
rtol = 2000
atol = 7e-2
if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling(): # bias_1
# Accumulating dbias across a large tensor introduces a larger difference
rtol = 200
atol = 4e-2
if i == 4 and fp8_recipe == recipe.DelayedScaling():
rtol = 2200
atol = 9e-2
assert_allclose( assert_allclose(
multi_grads[i], multi_grads[i],
single_grads[i], single_grads[i],
dtype=dtype, dtype=dtype,
rtol=rtol,
atol=atol,
err_msg=f"multi_grads[{i}] is not close", err_msg=f"multi_grads[{i}] is not close",
) )
......
...@@ -10,47 +10,22 @@ import jax.numpy as jnp ...@@ -10,47 +10,22 @@ import jax.numpy as jnp
import numpy as np import numpy as np
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo from transformer_engine.jax.quantize import (
QuantizeConfig,
is_fp8_available,
ScalingMode,
update_collections,
)
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
class TestQuantizeConfig(unittest.TestCase): class TestHelper(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_initialize(self):
margin = 5.0
fp8_format = FP8Format.E4M3
amax_history_len = 10
QuantizeConfig.initialize(
margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
)
self.assertEqual(
QuantizeConfig.MARGIN,
margin,
f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
f" but got {QuantizeConfig.MARGIN}.",
)
self.assertEqual(
QuantizeConfig.FP8_FORMAT,
fp8_format,
f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {QuantizeConfig.FP8_FORMAT}.",
)
self.assertEqual(
QuantizeConfig.AMAX_HISTORY_LEN,
amax_history_len,
f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
)
QuantizeConfig.finalize()
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_update_collections(self): def test_update_collections(self):
...@@ -61,19 +36,19 @@ class TestQuantizeConfig(unittest.TestCase): ...@@ -61,19 +36,19 @@ class TestQuantizeConfig(unittest.TestCase):
"test1": original_val, "test1": original_val,
"test2": original_val, "test2": original_val,
} }
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state) updated_state = update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val) self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val) self.assertEqual(updated_state["test2"], original_val)
original_state = flax.core.frozen_dict.FrozenDict(original_state) original_state = flax.core.frozen_dict.FrozenDict(original_state)
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state) updated_state = update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val) self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val) self.assertEqual(updated_state["test2"], original_val)
class TestFP8Functions(unittest.TestCase): class TestFP8Functions(unittest.TestCase):
def _check_defult_state(self): def _check_default_state(self):
self.assertFalse(QuantizeConfig.is_fp8_enabled()) self.assertFalse(QuantizeConfig.is_fp8_enabled())
def _compare_delay_scaling(self, ref, test): def _compare_delay_scaling(self, ref, test):
...@@ -82,35 +57,92 @@ class TestFP8Functions(unittest.TestCase): ...@@ -82,35 +57,92 @@ class TestFP8Functions(unittest.TestCase):
self.assertTrue(ref.amax_history_len == test.amax_history_len) self.assertTrue(ref.amax_history_len == test.amax_history_len)
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_current_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
def _compare_mxfp8_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast(self): def test_fp8_autocast_delayed_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state() self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()): with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
self.assertFalse(QuantizeConfig.is_fp8_enabled()) self._check_default_state()
self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())
self._check_defult_state() self._check_default_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds) self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_defult_state() self._check_default_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds) self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_defult_state() self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
self._check_default_state()
self._check_default_state()
cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
self._check_default_state()
self._check_default_state()
bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self): def test_fp8_autocast_with_sharding_resource(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state() self._check_default_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
...@@ -130,4 +162,4 @@ class TestFP8Functions(unittest.TestCase): ...@@ -130,4 +162,4 @@ class TestFP8Functions(unittest.TestCase):
self._compare_delay_scaling(get_delayed_scaling(), ds) self._compare_delay_scaling(get_delayed_scaling(), ds)
self.assertEqual(sr, global_mesh_resource()) self.assertEqual(sr, global_mesh_resource())
self._check_defult_state() self._check_default_state()
...@@ -13,7 +13,6 @@ import jax ...@@ -13,7 +13,6 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from flax import linen as nn from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import combine_masks from flax.linen.attention import combine_masks
from jax import lax, vmap from jax import lax, vmap
from jax import nn as jax_nn from jax import nn as jax_nn
...@@ -97,16 +96,16 @@ def combine_biases(*masks: Optional[Array]): ...@@ -97,16 +96,16 @@ def combine_biases(*masks: Optional[Array]):
return mask return mask
def parameterize_by_test_level(param_dict: dict, id_prefix: str = ""): def get_parameters_for_test_level(param_dict: dict):
""" """
Takes an input dictionary of parameters keyed by test type "L0", etc. Takes an input dictionary of parameters keyed by test type "L0", etc.
Returns a list of pytest parameters to be used in a parameterized test for the current test type Returns the parameters for the test level specified in the environment variable
""" """
DEFAULT_TEST_LEVEL = "L0" DEFAULT_TEST_LEVEL = "L0"
test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL) test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)
if test_level not in param_dict: if test_level not in param_dict:
raise ValueError("Unsupported test level") raise ValueError("Unsupported test level")
return values_to_named_params(param_dict[test_level], id_prefix) return param_dict[test_level]
def value_to_test_name_str(value): def value_to_test_name_str(value):
...@@ -139,14 +138,18 @@ def pytest_parametrize_wrapper(param_name, param_values): ...@@ -139,14 +138,18 @@ def pytest_parametrize_wrapper(param_name, param_values):
A wrapper for pytest.mark.parametrize to allow for automatic A wrapper for pytest.mark.parametrize to allow for automatic
naming of tests based on the parameter values. naming of tests based on the parameter values.
""" """
id_prefix = param_name
if isinstance(param_values, dict): if isinstance(param_values, dict):
param_values = parameterize_by_test_level(param_values, id_prefix=param_name) # If the values are split into a dictionary of test-levels, e.g. "L0", etc.,
elif "," not in param_name: # unwrap the selected level before proceeding.
param_values = values_to_named_params(param_values, id_prefix=id_prefix) param_values = get_parameters_for_test_level(param_values)
if "," not in param_name:
# Multi-parameterize annotations are not supported in this wrapper
# and are just a passthrough to default pytest.mark.parametrize.
# E.g. @pytest_parametrize_wrapper("a,b", ((a_value1, b_value1), (a_value2, b_value2)))
# will be passed through to pytest.mark.parametrize as-is without pytest.param ids.
param_values = values_to_named_params(param_values, id_prefix=param_name)
# Currently comma separated parameters in one parametrize call aren't supported for automatic naming
# and will just be passed through with default pytest names
def decorator(func): def decorator(func):
return pytest.mark.parametrize(param_name, param_values)(func) return pytest.mark.parametrize(param_name, param_values)(func)
...@@ -312,16 +315,22 @@ class DenseGeneral(nn.Module): ...@@ -312,16 +315,22 @@ class DenseGeneral(nn.Module):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features)) kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
kernel = nn_partitioning.param_with_axes( kernel = self.param(
"kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes "kernel",
nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
kernel_param_shape,
self.dtype,
) )
kernel = jnp.asarray(kernel, input_dtype) kernel = jnp.asarray(kernel, input_dtype)
kernel = jnp.reshape(kernel, kernel_shape) kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes( bias = self.param(
"bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes "bias",
nn.with_logical_partitioning(self.bias_init, self.bias_axes),
self.features,
self.dtype,
) )
bias = bias.astype(input_dtype) bias = bias.astype(input_dtype)
else: else:
...@@ -418,9 +427,9 @@ class MlpBlock(nn.Module): ...@@ -418,9 +427,9 @@ class MlpBlock(nn.Module):
) # Broadcast along length. ) # Broadcast along length.
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp")) x = nn.with_logical_constraint(x, ("length", "batch", "mlp"))
else: else:
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp")) x = nn.with_logical_constraint(x, ("batch", "length", "mlp"))
output = DenseGeneral( output = DenseGeneral(
inputs.shape[-1], inputs.shape[-1],
dtype=self.dtype, dtype=self.dtype,
...@@ -684,21 +693,13 @@ class MultiHeadAttention(nn.Module): ...@@ -684,21 +693,13 @@ class MultiHeadAttention(nn.Module):
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
query = nn_partitioning.with_sharding_constraint( query = nn.with_logical_constraint(query, ("length", "batch", "heads", "kv"))
query, ("length", "batch", "heads", "kv") key = nn.with_logical_constraint(key, ("length", "batch", "heads", "kv"))
) value = nn.with_logical_constraint(value, ("length", "batch", "heads", "kv"))
key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
value = nn_partitioning.with_sharding_constraint(
value, ("length", "batch", "heads", "kv")
)
else: else:
query = nn_partitioning.with_sharding_constraint( query = nn.with_logical_constraint(query, ("batch", "length", "heads", "kv"))
query, ("batch", "length", "heads", "kv") key = nn.with_logical_constraint(key, ("batch", "length", "heads", "kv"))
) value = nn.with_logical_constraint(value, ("batch", "length", "heads", "kv"))
key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
value = nn_partitioning.with_sharding_constraint(
value, ("batch", "length", "heads", "kv")
)
if decode: if decode:
# Detect if we're initializing by absence of existing cache data. # Detect if we're initializing by absence of existing cache data.
...@@ -805,9 +806,9 @@ class MultiHeadAttention(nn.Module): ...@@ -805,9 +806,9 @@ class MultiHeadAttention(nn.Module):
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv")) x = nn.with_logical_constraint(x, ("length", "batch", "joined_kv"))
else: else:
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv")) x = nn.with_logical_constraint(x, ("batch", "length", "joined_kv"))
# Back to the original inputs dimensions. # Back to the original inputs dimensions.
...@@ -853,8 +854,11 @@ class LayerNorm(nn.Module): ...@@ -853,8 +854,11 @@ class LayerNorm(nn.Module):
input_dtype = x.dtype input_dtype = x.dtype
features = x.shape[-1] features = x.shape[-1]
scale = nn_partitioning.param_with_axes( scale = self.param(
"scale", self.scale_init, (features,), self.dtype, axes=("embed",) "scale",
nn.with_logical_partitioning(self.scale_init, ("embed",)),
(features,),
self.dtype,
) )
x_ = x.astype(jnp.float32) x_ = x.astype(jnp.float32)
if self.layernorm_type == "layernorm": if self.layernorm_type == "layernorm":
...@@ -862,8 +866,11 @@ class LayerNorm(nn.Module): ...@@ -862,8 +866,11 @@ class LayerNorm(nn.Module):
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
y = (x_ - mean) * lax.rsqrt(var + self.epsilon) y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes( bias = self.param(
"ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",) "ln_bias",
nn.with_logical_partitioning(self.bias_init, ("embed",)),
(features,),
self.dtype,
) )
bias = jnp.asarray(bias, input_dtype) bias = jnp.asarray(bias, input_dtype)
...@@ -972,12 +979,11 @@ class RelativePositionBiases(nn.Module): ...@@ -972,12 +979,11 @@ class RelativePositionBiases(nn.Module):
num_buckets=self.num_buckets, num_buckets=self.num_buckets,
max_distance=self.max_distance, max_distance=self.max_distance,
) )
relative_attention_bias = nn_partitioning.param_with_axes( relative_attention_bias = self.param(
"rel_embedding", "rel_embedding",
self.embedding_init, nn.with_logical_partitioning(self.embedding_init, ("heads", "relpos_buckets")),
(self.num_heads, self.num_buckets), (self.num_heads, self.num_buckets),
jnp.float32, jnp.float32,
axes=("heads", "relpos_buckets"),
) )
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
...@@ -1555,14 +1561,16 @@ def sync_params_values(dst, src, transformations, sep="/"): ...@@ -1555,14 +1561,16 @@ def sync_params_values(dst, src, transformations, sep="/"):
""" """
src_values = {} src_values = {}
for key, value in jax.tree_util.tree_leaves_with_path(src): for key, value in jax.tree_util.tree_leaves_with_path(src):
normalized_key = sep.join(x.key for x in key) # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
src_values[normalized_key] = value src_values[normalized_key] = value
flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst) flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
synced_dst_values = [] synced_dst_values = []
for key, value in flatten_dst: for key, value in flatten_dst:
normalized_key = sep.join(x.key for x in key) # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
if normalized_key in transformations: if normalized_key in transformations:
corresponding_src_key = transformations[normalized_key] corresponding_src_key = transformations[normalized_key]
else: else:
......
...@@ -16,6 +16,7 @@ import torch.distributed as dist ...@@ -16,6 +16,7 @@ import torch.distributed as dist
from transformer_engine.common.recipe import ( from transformer_engine.common.recipe import (
DelayedScaling, DelayedScaling,
Float8CurrentScaling, Float8CurrentScaling,
Float8BlockScaling,
Format, Format,
Recipe, Recipe,
) )
...@@ -26,6 +27,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -26,6 +27,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer, Float8CurrentScalingQuantizer,
) )
from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
def _get_raw_data(quantized_tensor): def _get_raw_data(quantized_tensor):
...@@ -34,6 +36,14 @@ def _get_raw_data(quantized_tensor): ...@@ -34,6 +36,14 @@ def _get_raw_data(quantized_tensor):
assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute" assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8" assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
return quantized_tensor._data return quantized_tensor._data
elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
assert hasattr(
quantized_tensor, "_rowwise_data"
), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
assert (
quantized_tensor._rowwise_data.dtype == torch.uint8
), "Float8BlockwiseQTensor _rowwise_data must be uint8"
return quantized_tensor._rowwise_data
else: else:
raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}") raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
...@@ -435,15 +445,15 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): ...@@ -435,15 +445,15 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
preserve_high_precision_init_val=True, preserve_high_precision_init_val=True,
): ):
model_fp8 = nn.Sequential( model_fp8 = nn.Sequential(
te.Linear(128, 256, **linear_kwargs), te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs), te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs),
) )
# Create model with BF16 weights # Create model with BF16 weights
model = nn.Sequential( model = nn.Sequential(
te.Linear(128, 256, **linear_kwargs), te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs), te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs),
) )
...@@ -539,12 +549,13 @@ def _test_zero_1(dp_group): ...@@ -539,12 +549,13 @@ def _test_zero_1(dp_group):
def quantization_recipe(quantization) -> Recipe: def quantization_recipe(quantization) -> Recipe:
"""Quantization recipe setup""" """Quantization recipe setup"""
fp8_format = Format.HYBRID
if quantization == "fp8": if quantization == "fp8":
return DelayedScaling( return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)
elif quantization == "fp8_cs": elif quantization == "fp8_cs":
return Float8CurrentScaling() return Float8CurrentScaling(fp8_format=fp8_format)
elif quantization == "fp8_block":
return Float8BlockScaling(fp8_format=fp8_format)
else: else:
raise ValueError(f"Unsupported quantization: {quantization}") raise ValueError(f"Unsupported quantization: {quantization}")
...@@ -568,15 +579,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group): ...@@ -568,15 +579,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
preserve_high_precision_init_val=True, preserve_high_precision_init_val=True,
): ):
model_fp8 = nn.Sequential( model_fp8 = nn.Sequential(
te.Linear(128, 256, **linear_kwargs), te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs), te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs),
) )
# Create model with BF16 weights # Create model with BF16 weights
model = nn.Sequential( model = nn.Sequential(
te.Linear(128, 256, **linear_kwargs), te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs), te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs), te.Linear(256 * 3, 128, **linear_kwargs),
) )
...@@ -593,7 +604,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group): ...@@ -593,7 +604,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group) optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)
for _ in range(100): for i in range(100):
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
w_fp8.main_grad.zero_() w_fp8.main_grad.zero_()
w.main_grad.zero_() w.main_grad.zero_()
...@@ -654,7 +665,9 @@ def main(argv=None, namespace=None): ...@@ -654,7 +665,9 @@ def main(argv=None, namespace=None):
dist.init_process_group(**dist_init_kwargs) dist.init_process_group(**dist_init_kwargs)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"]) parser.add_argument(
"--quantization", type=str, default=None, choices=["fp8", "fp8_cs", "fp8_block"]
)
args = parser.parse_args(argv, namespace) args = parser.parse_args(argv, namespace)
dp_group = dist.new_group(backend="nccl") dp_group = dist.new_group(backend="nccl")
......
...@@ -21,7 +21,11 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION ...@@ -21,7 +21,11 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import (
fill_userbuffers_buffer_for_all_gather,
get_cublas_workspace_size_bytes,
)
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=FutureWarning)
...@@ -57,7 +61,11 @@ def _parse_args(argv=None, namespace=None): ...@@ -57,7 +61,11 @@ def _parse_args(argv=None, namespace=None):
) )
parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
parser.add_argument( parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." "--quantization",
type=str.lower,
default="none",
choices=["none", "fp8", "mxfp8"],
help="Quantization recipe",
) )
parser.add_argument( parser.add_argument(
"--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM." "--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM."
...@@ -155,9 +163,9 @@ def _parse_args(argv=None, namespace=None): ...@@ -155,9 +163,9 @@ def _parse_args(argv=None, namespace=None):
if opts.atomic: if opts.atomic:
warnings.warn("Atomic GEMM is not supported with bulk overlap.") warnings.warn("Atomic GEMM is not supported with bulk overlap.")
opts.atomic = False opts.atomic = False
if opts.fp8: if opts.quantization != "none":
warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.") warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
opts.fp8 = False opts.quantization = "none"
elif opts.comm_type == tex.CommOverlapType.AG: elif opts.comm_type == tex.CommOverlapType.AG:
if opts.atomic: if opts.atomic:
setattr(opts, "atomic_rs_p2p", opts.p2p) setattr(opts, "atomic_rs_p2p", opts.p2p)
...@@ -165,8 +173,11 @@ def _parse_args(argv=None, namespace=None): ...@@ -165,8 +173,11 @@ def _parse_args(argv=None, namespace=None):
if opts.atomic: if opts.atomic:
if not te.fp8.check_fp8_support(): if not te.fp8.check_fp8_support():
assert not opts.fp8, "Atomic GEMM is only supported in FP8." assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
opts.fp8 = True opts.quantization = "fp8"
if opts.fp8_output:
assert ops.quantization == "fp8", "FP8 output is only supported with FP8 compute."
return opts return opts
...@@ -303,7 +314,11 @@ def _main(opts): ...@@ -303,7 +314,11 @@ def _main(opts):
inp_shape = (opts.seq_length, opts.batch_size, hidden_size) inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
outer_size = reduce(operator.mul, inp_shape[:-1], 1) outer_size = reduce(operator.mul, inp_shape[:-1], 1)
buffer_dtype = torch.bfloat16 buffer_dtype = torch.bfloat16
if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG: if (
opts.quantization != "none"
and not opts.bulk_overlap
and opts.comm_type == tex.CommOverlapType.AG
):
buffer_dtype = torch.uint8 buffer_dtype = torch.uint8
ub_obj = ( ub_obj = (
tex.CommOverlapP2P( tex.CommOverlapP2P(
...@@ -450,6 +465,8 @@ def _main(opts): ...@@ -450,6 +465,8 @@ def _main(opts):
inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable
ref2_g = torch.matmul(inp2_g, ker2_g) ref2_g = torch.matmul(inp2_g, ker2_g)
# Initialize quantizers
with_quantized_compute = opts.quantization != "none"
inp_quantizer = None inp_quantizer = None
ker_quantizer = None ker_quantizer = None
out_quantizer = None out_quantizer = None
...@@ -457,7 +474,7 @@ def _main(opts): ...@@ -457,7 +474,7 @@ def _main(opts):
inp2_quantizer = None inp2_quantizer = None
ker2_quantizer = None ker2_quantizer = None
out2_quantizer = None out2_quantizer = None
if opts.fp8: if opts.quantization == "fp8":
# Structure to maintain amax and scale/scale_inv information for the kernel and input # Structure to maintain amax and scale/scale_inv information for the kernel and input
num_gemms = 6 if ub_obj2 is not None else 3 num_gemms = 6 if ub_obj2 is not None else 3
fp8_dtype = tex.DType.kFloat8E4M3 fp8_dtype = tex.DType.kFloat8E4M3
...@@ -502,11 +519,23 @@ def _main(opts): ...@@ -502,11 +519,23 @@ def _main(opts):
out2_quantizer = Float8Quantizer( out2_quantizer = Float8Quantizer(
fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
) )
elif opts.quantization == "mxfp8":
fp8_dtype = tex.DType.kFloat8E4M3
inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
ker_quantizer = MXFP8Quantizer(fp8_dtype)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
elif ub_obj2 is not None:
inp2_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
ker2_quantizer = MXFP8Quantizer(fp8_dtype)
# Quantize tensors
if with_quantized_compute:
# Cast input to Float8Tensor # Quantize input tensor
inp_fp8 = inp_quantizer(inp) inp_fp8 = inp_quantizer(inp)
# Cast kernel to Float8Tensor # Quantize kernel tensor
kernel_t_fp8 = ker_quantizer(kernel_t) kernel_t_fp8 = ker_quantizer(kernel_t)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp) bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp)
...@@ -543,31 +572,40 @@ def _main(opts): ...@@ -543,31 +572,40 @@ def _main(opts):
) )
# Set up comm/compute buffers # Set up comm/compute buffers
ag_out = None
rs_out = None rs_out = None
rs_out2 = None rs_out2 = None
if opts.comm_type == tex.CommOverlapType.AG: if opts.comm_type == tex.CommOverlapType.AG:
if opts.bulk_overlap: if opts.bulk_overlap:
ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True) ag_out, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
bulk_inp,
bulk_inp_quantizer,
tp_group,
)
gemm_inp = inp gemm_inp = inp
else: else:
ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True) ag_out, _ = fill_userbuffers_buffer_for_all_gather(
gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size()) ub_obj,
inp_fp8 if with_quantized_compute else inp,
inp_quantizer,
tp_group,
)
gemm_inp = ag_out
if ub_obj2 is not None: if ub_obj2 is not None:
if opts.fp8 and opts.fp8_output:
ub_obj2.set_buffer_params(out_quantizer)
rs_out2 = torch.empty( rs_out2 = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
) )
else: else:
if opts.bulk_overlap: if opts.bulk_overlap:
ub_obj.copy_into_buffer( if opts.quantization == "none":
bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False ub_obj.copy_into_buffer(bulk_inp, local_chunk=False)
) if opts.quantization == "fp8":
if opts.fp8: ub_obj.copy_into_buffer(bulk_inp_fp8._data, local_chunk=False)
ub_obj.set_buffer_params(bulk_inp_quantizer) elif opts.quantization == "mxfp8":
elif opts.fp8 and opts.fp8_output: ub_obj.copy_into_buffer(bulk_inp_fp8._rowwise_data, local_chunk=False)
ub_obj.set_buffer_params(out_quantizer)
gemm_inp = inp_fp8 if opts.fp8 else inp gemm_inp = inp_fp8 if with_quantized_compute else inp
rs_out = torch.empty( rs_out = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
) )
...@@ -626,7 +664,7 @@ def _main(opts): ...@@ -626,7 +664,7 @@ def _main(opts):
if opts.use_cuda_graphs: if opts.use_cuda_graphs:
# Trace the CUDA graph first # Trace the CUDA graph first
g = torch.cuda.CUDAGraph() g = torch.cuda.CUDAGraph()
if opts.fp8: if with_quantized_compute:
if ub_obj is None: if ub_obj is None:
with torch.cuda.graph(g): with torch.cuda.graph(g):
all_outputs = _fp8_gemm() all_outputs = _fp8_gemm()
...@@ -646,7 +684,7 @@ def _main(opts): ...@@ -646,7 +684,7 @@ def _main(opts):
else: else:
for i in range(total_iters): for i in range(total_iters):
if opts.fp8: if with_quantized_compute:
start_events[i].record() start_events[i].record()
all_outputs = _fp8_gemm() all_outputs = _fp8_gemm()
end_events[i].record() end_events[i].record()
...@@ -691,10 +729,22 @@ def _main(opts): ...@@ -691,10 +729,22 @@ def _main(opts):
output_info = "" output_info = ""
if opts.comm_type == tex.CommOverlapType.AG: if opts.comm_type == tex.CommOverlapType.AG:
# Bulk overlap AG output is already gathered # Bulk overlap AG output is already gathered
test_out = ub_obj.get_buffer(bulk_inp_quantizer, False) test_out = ag_out
if bulk_inp_quantizer is None:
test_out = ub_obj.get_buffer(False)
else:
test_out = Float8Tensor(
shape=test_out.shape,
dtype=torch.bfloat16,
data=ub_obj.get_buffer(False),
fp8_scale=bulk_inp_quantizer.scale,
fp8_dtype=bulk_inp_quantizer.dtype,
quantizer=bulk_inp_quantizer,
)
else: else:
# Bulk overlap RS output needs to be gathered # Bulk overlap RS output needs to be gathered
out_local = ub_obj.get_buffer(bulk_inp_quantizer, True) out_local = ub_obj.get_buffer(True)
output_info += f"rs_output: {list(out_local.shape)} | " output_info += f"rs_output: {list(out_local.shape)} | "
test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0] test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]
...@@ -765,8 +815,8 @@ def _main(opts): ...@@ -765,8 +815,8 @@ def _main(opts):
m = torch.argmax(diff) m = torch.argmax(diff)
abs_err = diff[m].item() abs_err = diff[m].item()
rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5) rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
rtol = 0.125 if opts.fp8 else 0.02 rtol = 0.02 if opts.quantization == "none" else 0.125
atol = 0.0625 if opts.fp8 else 0.001 atol = 0.001 if opts.quantization == "none" else 0.0625
if rel_err > rtol and abs_err > atol: if rel_err > rtol and abs_err > atol:
numerics_failed = True numerics_failed = True
numerics_info = ( numerics_info = (
......
...@@ -17,7 +17,12 @@ import torch ...@@ -17,7 +17,12 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Format,
MXFP8BlockScaling,
)
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=FutureWarning)
...@@ -163,7 +168,7 @@ def _parse_args(argv=None, namespace=None): ...@@ -163,7 +168,7 @@ def _parse_args(argv=None, namespace=None):
"--quantization", "--quantization",
type=str.lower, type=str.lower,
default="none", default="none",
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"], choices=["none", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
help="Quantization recipe", help="Quantization recipe",
) )
parser.add_argument( parser.add_argument(
...@@ -414,6 +419,8 @@ def _train(opts): ...@@ -414,6 +419,8 @@ def _train(opts):
) )
elif opts.quantization == "fp8_current_scaling": elif opts.quantization == "fp8_current_scaling":
fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format) fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
elif opts.quantization == "mxfp8":
fp8_recipe = MXFP8BlockScaling()
# Prepare random input tensors # Prepare random input tensors
test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True) test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
......
...@@ -174,7 +174,7 @@ def _get_tolerances(dtype): ...@@ -174,7 +174,7 @@ def _get_tolerances(dtype):
if dtype == torch.bfloat16: if dtype == torch.bfloat16:
return {"rtol": 1.6e-2, "atol": 1e-5} return {"rtol": 1.6e-2, "atol": 1e-5}
if dtype == torch.float32: if dtype == torch.float32:
return {"rtol": 1.3e-6, "atol": 1e-5} return {"rtol": 1.3e-6, "atol": 4e-5}
raise ValueError(f"Unsupported dtype ({dtype})") raise ValueError(f"Unsupported dtype ({dtype})")
......
...@@ -15,6 +15,9 @@ if torch.cuda.device_count() < 2: ...@@ -15,6 +15,9 @@ if torch.cuda.device_count() < 2:
pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.") pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
TEST_ROOT = Path(__file__).parent.resolve() TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(2, torch.cuda.device_count()) NUM_PROCS: int = min(2, torch.cuda.device_count())
...@@ -28,8 +31,10 @@ def _run_test(quantization): ...@@ -28,8 +31,10 @@ def _run_test(quantization):
assert result.returncode == 0 assert result.returncode == 0
@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"]) @pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", "fp8_block"])
def test_cast_master_weights_to_fp8(quantization): def test_cast_master_weights_to_fp8(quantization):
if not fp8_available: if quantization in ("fp8", "fp8_cs") and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if quantization == "fp8_block" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
_run_test(quantization) _run_test(quantization)
...@@ -21,6 +21,7 @@ if torch.cuda.device_count() < 2: ...@@ -21,6 +21,7 @@ if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.") pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
RNG_SEED: int = 42 RNG_SEED: int = 42
SEQ_LENGTH: int = 1024 SEQ_LENGTH: int = 1024
...@@ -56,7 +57,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" ...@@ -56,7 +57,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
torch._dynamo.reset() torch._dynamo.reset()
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8): def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = LAUNCH_CMD + [ test_cmd = LAUNCH_CMD + [
str(test_path), str(test_path),
...@@ -72,10 +73,11 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8): ...@@ -72,10 +73,11 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
if bulk: if bulk:
test_cmd.append("--bulk-overlap") test_cmd.append("--bulk-overlap")
else: else:
if fp8: if quantization == "fp8" and not fp8_available:
if not fp8_available: pytest.skip(reason_for_no_fp8)
pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available:
test_cmd.append("--fp8") pytest.skip(reason_for_no_mxfp8)
test_cmd.append(f"--quantization={quantization}")
if p2p: if p2p:
test_cmd.append("--p2p") test_cmd.append("--p2p")
if atomic: if atomic:
...@@ -114,8 +116,10 @@ def _run_layer_with_overlap( ...@@ -114,8 +116,10 @@ def _run_layer_with_overlap(
test_cmd.append("--overlap-rs-dgrad") test_cmd.append("--overlap-rs-dgrad")
if fp8: if fp8:
if not fp8_available: if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
test_cmd.append("--fp8") test_cmd.append("--fp8")
test_cmd.append(f"--quantization={quantization}") test_cmd.append(f"--quantization={quantization}")
...@@ -137,51 +141,34 @@ def _run_layer_with_overlap( ...@@ -137,51 +141,34 @@ def _run_layer_with_overlap(
raise AssertionError(result.stderr.decode()) raise AssertionError(result.stderr.decode())
@pytest.mark.parametrize( @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
"fp8", def test_split_all_gather_overlaps(quantization):
(False, True),
ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "],
)
def test_split_all_gather_overlaps(fp8):
""" """
Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm. te.cpp_extensions.fp8_gemm.
""" """
_run_gemm_with_overlap("AG", False, True, False, fp8) _run_gemm_with_overlap("AG", False, True, False, quantization)
@pytest.mark.parametrize( @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
"fp8,p2p", @pytest.mark.parametrize("p2p", (False, True))
[ def test_split_reduce_scatter_overlaps(quantization, p2p):
(False, False),
(False, True),
(True, False),
(True, True),
],
ids=[
" BF16 - PIPELINE ",
" BF16 - RING-EXCHANGE ",
" FP8 - PIPELINE ",
" FP8 - RING-EXCHANGE ",
],
)
def test_split_reduce_scatter_overlaps(fp8, p2p):
""" """
Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm. te.cpp_extensions.fp8_gemm.
""" """
_run_gemm_with_overlap("RS", False, p2p, False, fp8) _run_gemm_with_overlap("RS", False, p2p, False, quantization)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"comm_type, fp8, connections", "comm_type, quantization, connections",
[ [
("AG", False, 1), ("AG", "none", 1),
("RS", False, 1), ("RS", "none", 1),
("RS", True, 1), ("RS", "fp8", 1),
("AG", False, 8), ("AG", "none", 8),
("RS", False, 8), ("RS", "none", 8),
("RS", True, 8), ("RS", "fp8", 8),
], ],
ids=[ ids=[
"ALL-GATHER - BF16 - 1 connections", "ALL-GATHER - BF16 - 1 connections",
...@@ -192,7 +179,7 @@ def test_split_reduce_scatter_overlaps(fp8, p2p): ...@@ -192,7 +179,7 @@ def test_split_reduce_scatter_overlaps(fp8, p2p):
"REDUCE-SCATTER - FP8 - 8 connections", "REDUCE-SCATTER - FP8 - 8 connections",
], ],
) )
def test_bulk_overlaps(comm_type, fp8, connections): def test_bulk_overlaps(comm_type, quantization, connections):
""" """
Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
""" """
...@@ -203,10 +190,10 @@ def test_bulk_overlaps(comm_type, fp8, connections): ...@@ -203,10 +190,10 @@ def test_bulk_overlaps(comm_type, fp8, connections):
" 9.0 (HOPPER ARCH)." " 9.0 (HOPPER ARCH)."
) )
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
_run_gemm_with_overlap(comm_type, True, False, False, fp8) _run_gemm_with_overlap(comm_type, True, False, False, quantization)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
else: else:
_run_gemm_with_overlap(comm_type, True, False, False, fp8) _run_gemm_with_overlap(comm_type, True, False, False, quantization)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -258,15 +245,7 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d ...@@ -258,15 +245,7 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
@pytest.mark.parametrize( @pytest.mark.parametrize(
"quantization", "quantization",
["fp8_delayed_scaling", "fp8_current_scaling"], ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad", "layer_type,linear_parallel_mode,overlap_rs_dgrad",
...@@ -286,15 +265,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d ...@@ -286,15 +265,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
) )
), ),
ids=[ ids=[
f" {te.Linear.__name__} - ROW-PARALLEL ", f"{te.Linear.__name__}-row_tensor_parallel",
f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ", f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS",
f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ", f"{te.LayerNormLinear.__name__}-row_tensor_parallel",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ", f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS",
] ]
+ [ + [
" " + " - ".join(test_name_parts) + " " "-".join(test_name_parts)
for test_name_parts in zip( for test_name_parts in zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]), ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
...@@ -302,12 +281,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d ...@@ -302,12 +281,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
], ],
) )
def test_layers_with_overlap_fp8( def test_layers_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization layer_type,
linear_parallel_mode,
overlap_rs_dgrad,
quantization,
): ):
""" """
Test Transformer Engine layers with comm+GEMM overlap. Test Transformer Engine layers with comm+GEMM overlap.
""" """
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization) _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -354,22 +336,11 @@ def test_multi_layer_with_overlap_bf16( ...@@ -354,22 +336,11 @@ def test_multi_layer_with_overlap_bf16(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"quantization", "quantization",
["fp8_delayed_scaling", "fp8_current_scaling"], ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_layers", "num_layers",
(2,), (2,),
ids=[
" 2 layers ",
],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad", "layer_type,linear_parallel_mode,overlap_rs_dgrad",
...@@ -381,7 +352,7 @@ def test_multi_layer_with_overlap_bf16( ...@@ -381,7 +352,7 @@ def test_multi_layer_with_overlap_bf16(
) )
), ),
ids=[ ids=[
" " + " - ".join(test_name_parts) + " " "-".join(test_name_parts)
for test_name_parts in zip( for test_name_parts in zip(
[te.TransformerLayer.__name__ for _ in range(2)], [te.TransformerLayer.__name__ for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"], ["BULK DGRAD/WGRAD", "DGRAD+RS"],
...@@ -389,11 +360,11 @@ def test_multi_layer_with_overlap_bf16( ...@@ -389,11 +360,11 @@ def test_multi_layer_with_overlap_bf16(
], ],
) )
def test_multi_layer_with_overlap_fp8( def test_multi_layer_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers
): ):
""" """
Test Transformer Engine layers with comm+GEMM overlap. Test Transformer Engine layers with comm+GEMM overlap.
""" """
_run_layer_with_overlap( _run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers
) )
...@@ -19,7 +19,6 @@ import torch ...@@ -19,7 +19,6 @@ import torch
import transformer_engine import transformer_engine
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.ops._common import is_float8_tensor
...@@ -27,6 +26,8 @@ from transformer_engine.pytorch.ops.fused import ( ...@@ -27,6 +26,8 @@ from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear, UserbuffersBackwardLinear,
UserbuffersForwardLinear, UserbuffersForwardLinear,
) )
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions # Import utility functions
...@@ -36,6 +37,13 @@ from utils import dtype_tols, str_to_dtype ...@@ -36,6 +37,13 @@ from utils import dtype_tols, str_to_dtype
# Check if FP8 is supported # Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
if mxfp8_available:
quantization_list.append("mxfp8")
# Check if there are multiple GPUs # Check if there are multiple GPUs
if torch.cuda.device_count() < 2: if torch.cuda.device_count() < 2:
...@@ -51,7 +59,7 @@ class ModelConfig: ...@@ -51,7 +59,7 @@ class ModelConfig:
num_heads: int num_heads: int
head_dim: int head_dim: int
dtype: torch.dtype dtype: torch.dtype
fp8: bool quantization: Optional[str]
@property @property
def hidden_size(self): def hidden_size(self):
...@@ -129,12 +137,16 @@ def make_reference_and_test_tensors( ...@@ -129,12 +137,16 @@ def make_reference_and_test_tensors(
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor # Make copy of tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8: if test_is_fp8:
test = Float8Tensor.to_float8(ref) quantizer = Float8Quantizer(
else: scale=torch.ones(1, dtype=torch.float32, device=test_device),
test = ref.to(device=test_device, dtype=test_dtype) amax=torch.zeros(1, dtype=torch.float32, device=test_device),
if test.data_ptr() == ref.data_ptr(): fp8_dtype=tex.DType.kFloat8E4M3,
test = test.clone() )
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
# Make sure reference and test tensors represent exact same values # Make sure reference and test tensors represent exact same values
ref.copy_(test) ref.copy_(test)
...@@ -145,6 +157,21 @@ def make_reference_and_test_tensors( ...@@ -145,6 +157,21 @@ def make_reference_and_test_tensors(
return ref, test return ref, test
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_linear( def _test_linear(
*, *,
model_config: ModelConfig, model_config: ModelConfig,
...@@ -155,7 +182,8 @@ def _test_linear( ...@@ -155,7 +182,8 @@ def _test_linear(
weight_requires_grad: bool = True, weight_requires_grad: bool = True,
) -> None: ) -> None:
dtype = model_config.dtype dtype = model_config.dtype
fp8_compute = model_config.fp8 quantization = model_config.quantization
quantized_compute = quantization is not None
# Distributed process group # Distributed process group
process_group = world_group() process_group = world_group()
...@@ -175,14 +203,19 @@ def _test_linear( ...@@ -175,14 +203,19 @@ def _test_linear(
in_shape, in_shape,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=fp8_compute, test_is_fp8=quantized_compute,
) )
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors( w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features), (out_features, in_features),
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=fp8_compute, test_is_fp8=quantized_compute,
) )
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None b_ref, b_test = None, None
if bias: if bias:
if tensor_parallel_mode == "row": if tensor_parallel_mode == "row":
...@@ -198,9 +231,11 @@ def _test_linear( ...@@ -198,9 +231,11 @@ def _test_linear(
out_shape, out_shape,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
test_is_fp8=fp8_compute, test_is_fp8=quantized_compute,
requires_grad=False, requires_grad=False,
) )
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation # Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref) y_ref = torch.nn.functional.linear(x_ref, w_ref)
...@@ -265,21 +300,15 @@ def _test_linear( ...@@ -265,21 +300,15 @@ def _test_linear(
x_test.requires_grad_() x_test.requires_grad_()
# Implementation with fusible operation # Implementation with fusible operation
with te.fp8_model_init(enabled=fp8_compute): recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
ops = [] ops = []
linear_op = None linear_op = None
bias_op = None bias_op = None
if tensor_parallel_mode == "column": if tensor_parallel_mode == "column":
userbuffers_options = {} userbuffers_options = {}
if not weight_requires_grad: if not weight_requires_grad:
if fp8_compute: userbuffers_options["comm_name"] = "fc1"
userbuffers_options["comm_name"] = "fc1"
else:
# There is a correctness bug with overlapping
# dgrad reduce-scatter with dgrad GEMM. Fall back
# to overlapping dgrad reduce-scatter with wgrad
# GEMM, even though wgrad isn't needed.
userbuffers_options["comm_name"] = "qkv"
else: else:
userbuffers_options["comm_name"] = "qkv" userbuffers_options["comm_name"] = "qkv"
linear_op = te_ops.BasicLinear( linear_op = te_ops.BasicLinear(
...@@ -322,7 +351,7 @@ def _test_linear( ...@@ -322,7 +351,7 @@ def _test_linear(
bias_op.bias.copy_(b_test) bias_op.bias.copy_(b_test)
del w_test del w_test
del b_test del b_test
with te.fp8_autocast(enabled=fp8_compute): with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x_test) y_test = model(x_test)
y_test.backward(dy_test) y_test.backward(dy_test)
...@@ -338,7 +367,7 @@ def _test_linear( ...@@ -338,7 +367,7 @@ def _test_linear(
tols = dtype_tols(dtype) tols = dtype_tols(dtype)
if dtype == torch.float32: if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute: if quantized_compute:
tols = dtype_tols( tols = dtype_tols(
model[0].weight._fp8_dtype model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight) if is_float8_tensor(model[0].weight)
...@@ -370,7 +399,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None: ...@@ -370,7 +399,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None:
for test_config in itertools.product( for test_config in itertools.product(
(False, True), # bias (False, True), # bias
("column", "row"), # tensor_parallel_mode ("column", "row"), # tensor_parallel_mode
(False, True), # weight_requires_grad (True, False), # weight_requires_grad
): ):
if rank == 0: if rank == 0:
print(f"Running _test_linear with {test_config=}") print(f"Running _test_linear with {test_config=}")
...@@ -390,19 +419,15 @@ if torch.cuda.device_count() > 1: ...@@ -390,19 +419,15 @@ if torch.cuda.device_count() > 1:
@pytest.mark.parametrize("world_size", _world_sizes) @pytest.mark.parametrize("world_size", _world_sizes)
@pytest.mark.parametrize("fp8", (False, True)) @pytest.mark.parametrize("quantization", quantization_list)
def test_fuser_ops_with_userbuffers( def test_fuser_ops_with_userbuffers(
*, *,
world_size: int, world_size: int,
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
fp8: bool, quantization: Optional[str],
) -> None: ) -> None:
"""Launch parallel job and run tests""" """Launch parallel job and run tests"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Parallel job launcher # Parallel job launcher
command = [] command = []
if tex.ubuf_built_with_mpi(): if tex.ubuf_built_with_mpi():
...@@ -424,8 +449,8 @@ def test_fuser_ops_with_userbuffers( ...@@ -424,8 +449,8 @@ def test_fuser_ops_with_userbuffers(
str(dtype), str(dtype),
) )
) )
if fp8: if quantization is not None:
command.append("--fp8") command.extend(("--quantization", quantization))
# Environment # Environment
env = dict(os.environ) env = dict(os.environ)
...@@ -445,12 +470,12 @@ def main() -> None: ...@@ -445,12 +470,12 @@ def main() -> None:
# Parse command-line arguments # Parse command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--parallel", action="store_true", help="Run parallel tests") parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
parser.add_argument("--sequence-length", type=int, default=32) parser.add_argument("--sequence-length", type=int, default=256)
parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--num-heads", type=int, default=16) parser.add_argument("--num-heads", type=int, default=16)
parser.add_argument("--head-dim", type=int, default=32) parser.add_argument("--head-dim", type=int, default=256)
parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--fp8", action="store_true") parser.add_argument("--quantization", type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
# Run parallel tests if needed # Run parallel tests if needed
...@@ -463,14 +488,17 @@ def main() -> None: ...@@ -463,14 +488,17 @@ def main() -> None:
num_heads=args.num_heads, num_heads=args.num_heads,
head_dim=args.head_dim, head_dim=args.head_dim,
dtype=str_to_dtype(args.dtype), dtype=str_to_dtype(args.dtype),
fp8=args.fp8, quantization=args.quantization,
) )
# Initialize Userbuffers # Initialize Userbuffers
group = world_group() # Initialize NCCL group = world_group() # Initialize NCCL
bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl" bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
userbuffer_configs = { userbuffer_configs = {
"fc1_dgrad": {"method": "pipeline"}, # Overlap dgrad RS with dgrad GEMM "fc1_dgrad": {
"method": "ring_exchange",
"fp8_buf": False,
}, # Overlap dgrad RS with dgrad GEMM
} }
te.module.base.initialize_ub( te.module.base.initialize_ub(
[ [
...@@ -478,7 +506,7 @@ def main() -> None: ...@@ -478,7 +506,7 @@ def main() -> None:
model_config.num_heads * model_config.head_dim, model_config.num_heads * model_config.head_dim,
], ],
torch.distributed.get_world_size(group), torch.distributed.get_world_size(group),
use_fp8=model_config.fp8, use_fp8=model_config.quantization is not None,
dtype=model_config.dtype, dtype=model_config.dtype,
bootstrap_backend=bootstrap_backend, bootstrap_backend=bootstrap_backend,
ub_cfgs=userbuffer_configs, ub_cfgs=userbuffer_configs,
......
...@@ -2,12 +2,16 @@ ...@@ -2,12 +2,16 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import os, sys, logging import os
import sys
import logging
from contextlib import nullcontext from contextlib import nullcontext
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_cu_seqlens_on_cp_rank,
)
import transformer_engine_torch as tex import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_autocast
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
import functools
import logging import logging
import math import math
import os import os
from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union, Optional from typing import Any, Dict, List, Tuple, Union, Optional
from contextlib import contextmanager from contextlib import contextmanager
...@@ -16,26 +13,22 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION ...@@ -16,26 +13,22 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention import ( from transformer_engine.pytorch.attention.dot_product_attention import (
DotProductAttention, DotProductAttention,
MultiheadAttention,
_attention_backends, _attention_backends,
) )
from transformer_engine.pytorch.dot_product_attention.utils import ( from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils, FlashAttentionUtils,
get_attention_backend, get_attention_backend,
check_set_window_size, check_set_window_size,
AttentionParams, AttentionParams,
) )
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import ( from transformer_engine.pytorch.cpp_extensions.fused_attn import (
AttnBiasType,
AttnMaskType,
FusedAttnBackend, FusedAttnBackend,
QKVLayout,
fused_attn_bwd, fused_attn_bwd,
fused_attn_fwd, fused_attn_fwd,
) )
...@@ -50,9 +43,7 @@ from transformer_engine.pytorch.utils import ( ...@@ -50,9 +43,7 @@ from transformer_engine.pytorch.utils import (
) )
from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import NVTE_Fused_Attn_Backend
from transformer_engine.pytorch.tensor.quantized_tensor import ( from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer, Quantizer,
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
...@@ -1659,8 +1650,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP ...@@ -1659,8 +1650,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
) )
if is_training: if is_training:
out.backward(out_grad) out.backward(out_grad)
param_names = [] param_names = []
param_names.append("hidden_states.grad") param_names.append("hidden_states.grad")
...@@ -1910,8 +1901,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): ...@@ -1910,8 +1901,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
checkpoint_core_attention=False, checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type, core_attention_bias_type=config.attn_bias_type,
) )
if is_training: if is_training:
out.backward(out_grad) out.backward(out_grad)
if is_training: if is_training:
return out, (inp[0].grad, inp[1].grad, inp[2].grad) return out, (inp[0].grad, inp[1].grad, inp[2].grad)
...@@ -2024,7 +2015,7 @@ def _run_custom_mha_fp8(dtype, config, backend): ...@@ -2024,7 +2015,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda") mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
out = mha(inp, cu_seqlens, config.max_seqlen_q) out = mha(inp, cu_seqlens, config.max_seqlen_q)
out.backward(out_grad) out.backward(out_grad)
out = torch.load("out.pt") out = torch.load("out.pt")
dqkv = torch.load("dqkv.pt") dqkv = torch.load("dqkv.pt")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment