Commit f8c2af4c authored by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
......@@ -49,16 +49,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
return;
}
Tensor input("input", { N, H }, itype);
Tensor z("z", { N, H }, otype);
Tensor gamma("gamma", { H }, wtype);
Tensor beta("beta", { H }, wtype);
Tensor mu("mu", { N }, DType::kFloat32);
Tensor rsigma("rsigma", { N }, DType::kFloat32);
Tensor dz("dz", { N, H }, wtype);
Tensor dx("dx", { N, H }, itype);
Tensor dgamma("dgamma", { H }, wtype);
Tensor dbeta("dbeta", { H }, wtype);
Tensor input("input", std::vector<size_t>{ N, H }, itype);
Tensor z("z", std::vector<size_t>{ N, H }, otype);
Tensor gamma("gamma", std::vector<size_t>{ H }, wtype);
Tensor beta("beta", std::vector<size_t>{ H }, wtype);
Tensor mu("mu", std::vector<size_t>{ N }, DType::kFloat32);
Tensor rsigma("rsigma", std::vector<size_t>{ N }, DType::kFloat32);
Tensor dz("dz", std::vector<size_t>{ N, H }, wtype);
Tensor dx("dx", std::vector<size_t>{ N, H }, itype);
Tensor dgamma("dgamma", std::vector<size_t>{ H }, wtype);
Tensor dbeta("dbeta", std::vector<size_t>{ H }, wtype);
Tensor workspace_fwd, workspace_bwd;
fillUniform(&input);
......
......@@ -116,12 +116,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
DType wtype = TypeInfo<WeightType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", { N, H }, itype);
Tensor z("z", { N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING);
Tensor gamma("gamma", { H }, wtype);
Tensor beta("beta", { H }, wtype);
Tensor mu("mu", { N }, DType::kFloat32);
Tensor rsigma("rsigma", { N }, DType::kFloat32);
Tensor input("input", std::vector<size_t>{ N, H }, itype);
Tensor z("z", std::vector<size_t>{ N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING);
Tensor gamma("gamma", std::vector<size_t>{ H }, wtype);
Tensor beta("beta", std::vector<size_t>{ H }, wtype);
Tensor mu("mu", std::vector<size_t>{ N }, DType::kFloat32);
Tensor rsigma("rsigma", std::vector<size_t>{ N }, DType::kFloat32);
Tensor workspace;
......@@ -164,7 +164,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
nvte_enable_zero_centered_gamma_in_weight_dtype(false);
}
Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true);
Tensor dequantized_output("dequantized_output", std::vector<size_t>{ N, H }, DType::kFloat32, true, true);
dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training);
......
......@@ -58,8 +58,8 @@ void performTestQ(const size_t N) {
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", { N }, itype);
Tensor output("output", { N }, otype);
Tensor input("input", std::vector<size_t>{ N }, itype);
Tensor output("output", std::vector<size_t>{ N }, otype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
......@@ -89,8 +89,8 @@ void performTestDQ(const size_t N) {
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;
Tensor input("input", { N }, itype);
Tensor output("output", { N }, otype);
Tensor input("input", std::vector<size_t>{ N }, itype);
Tensor output("output", std::vector<size_t>{ N }, otype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
......
......@@ -37,8 +37,8 @@ void performTest(const size_t N, const size_t H) {
DType dtype = TypeInfo<Type>::dtype;
Tensor input("input", { N, H }, dtype);
Tensor output("output", { H, N }, dtype);
Tensor input("input", std::vector<size_t>{ N, H }, dtype);
Tensor output("output", std::vector<size_t>{ H, N }, dtype);
std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);
......
......@@ -783,8 +783,6 @@ void fillUniform(Tensor *t) {
template<typename InputEncoding, InputsFillCase Case>
void fillCase_special(Tensor *t) {
const size_t size = product(t->rowwise_shape());
const size_t rows = t->rowwise_shape().data[0];
const size_t cols = t->rowwise_shape().data[1];
if constexpr (Case == InputsFillCase::zeros) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
......@@ -804,9 +802,7 @@ void fillCase_special(Tensor *t) {
std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
InputType *data = t->rowwise_cpu_dptr<InputType>();
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
const size_t idx = i * cols + j;
for (size_t idx = 0; idx < size; ++idx) {
const bool is_negative = (dis_sign(t->gen()) < 0.0);
double val = dis(t->gen());
if (is_negative) {
......@@ -814,7 +810,6 @@ void fillCase_special(Tensor *t) {
}
data[idx] = static_cast<InputType>(val);
}
}
});
}
t->set_scale_inv(1.0);
......
......@@ -52,6 +52,7 @@ struct BytesToType<8> {
};
using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
......@@ -70,6 +71,7 @@ using fp8e8m0 = uint8_t;
template <typename T>
struct TypeInfo{
using types = std::tuple<byte,
int16,
int32,
int64,
fp32,
......
......@@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from functools import reduce
......@@ -18,11 +19,16 @@ from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import _jax_layernorm, _jax_rmsnorm
from transformer_engine.jax.cpp_extensions.normalization import (
_jax_layernorm,
_jax_rmsnorm,
is_norm_zero_centered_gamma_in_weight_dtype,
)
from transformer_engine.jax.cpp_extensions.quantization import (
_jax_quantize,
_jax_quantize_dbias,
)
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
DelayedScaleQuantizer,
......@@ -33,7 +39,7 @@ from transformer_engine.jax.quantize import (
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.dense import dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
......@@ -54,6 +60,7 @@ supported_scaling_modes = []
""" Find supported scaling modes"""
if is_fp8_supported:
supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
if is_mxfp8_supported:
supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)
......@@ -71,8 +78,19 @@ def is_shape_supported_by_mxfp8(input_shape):
def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
assert_allclose(a.data, b.data)
assert a.scaling_mode == b.scaling_mode
assert a.scale_inv.dtype == b.scale_inv.dtype
if a.scaling_mode.is_tensor_scaling():
# Assert in dq_dtype as some unfused codepaths have an intermediate cast
# to an input dtype which reduces precision compared to everything in fp32
assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# Compare MXFP8 scales as uint8
assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
else:
raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
assert_allclose(a.data, b.data)
elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
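As a hedged aside (standalone NumPy, not TE code), the two comparison regimes above can be illustrated directly: MXFP8 stores per-block scales as E8M0 exponent bytes, so bitwise equality of the uint8 view is the natural check, while tensor-scaling scale_inv values are floats compared within a tolerance tied to dq_dtype.

import numpy as np

# E8M0 exponent bytes (bias 127): 127 -> 2**0, 128 -> 2**1, 130 -> 2**3
scale_a = np.array([127, 128, 130], dtype=np.uint8)
scale_b = np.array([127, 128, 130], dtype=np.uint8)
assert (scale_a == scale_b).all()  # MXFP8 scales: exact bitwise match expected

scale_inv_a = np.float32(0.0312501)
scale_inv_b = np.float32(0.03125)
assert np.isclose(scale_inv_a, scale_inv_b, rtol=1e-3)  # tensor scaling: tolerance-based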
......@@ -159,7 +177,12 @@ class TestActivation:
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_act_grad_with_tensor_scaling_fp8(
self, random_inputs, activation_type, output_type, scaling_mode
):
x = random_inputs
x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
......@@ -170,7 +193,7 @@ class TestActivation:
)
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE,
)
......@@ -188,8 +211,11 @@ class TestActivation:
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_act_forward_with_delayed_scaling_fp8(
self, random_inputs, activation_type, output_type, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_act_forward_with_tensor_scaling_fp8(
self, random_inputs, activation_type, output_type, q_layout, scaling_mode
):
x = random_inputs
x = jnp.expand_dims(x, axis=-2)
......@@ -198,7 +224,7 @@ class TestActivation:
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
q_dtype=output_type,
q_layout=q_layout,
)
......@@ -335,8 +361,20 @@ class TestNorm:
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_norm_grad_with_delayed_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_norm_grad_with_tensor_scaling_fp8(
self,
n,
hidden,
norm_type,
zero_centered_gamma,
epsilon,
inp_dtype,
out_dtype,
q_layout,
scaling_mode,
):
"""
Test transformer_engine.jax.layernorm.layernorm
......@@ -345,9 +383,7 @@ class TestNorm:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=out_dtype,
q_layout=q_layout,
scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
)
self._test_norm_grad(
n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
......@@ -395,7 +431,41 @@ class TestNorm:
)
ref_mu = None
precise_comparison = True
if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# Reduce test precision: below this cuDNN version we don't use the fused norm for MXFP8 and instead
# do an unfused norm and quantize with an intermediate cast into in_dtype, which can reduce precision
precise_comparison = False
elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
# Larger tolerances as our JAX implementation _jax_*norm always uses float32 as the compute dtype
# for zero-centered gamma
precise_comparison = False
elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
# Current implementation of Current Tensor Scaling performs unfused layernorm and quantization
# and writes intermediate results into the input dtype, which will slightly reduce precision
# if the input dtype is not float32
precise_comparison = False
if precise_comparison:
assert_bitwise_scaled_tensors(output, ref_out)
else:
if isinstance(ref_out, ScaledTensor1x):
assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
elif isinstance(ref_out, ScaledTensor2x):
assert_allclose(
output.rowwise_tensor.dequantize(),
ref_out.rowwise_tensor.dequantize(),
dtype=out_dtype,
)
assert_allclose(
output.colwise_tensor.dequantize(),
ref_out.colwise_tensor.dequantize(),
dtype=out_dtype,
)
else:
pytest.fail("Unsupported output type")
assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
if norm_type == "layernorm":
assert_allclose(mu, ref_mu, dtype=inp_dtype)
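A hedged standalone sketch of why the unfused code paths above need looser tolerances: routing an intermediate result through a low-precision dtype (here bfloat16) before the final step loses bits relative to keeping everything in float32.

import jax.numpy as jnp

x = jnp.linspace(-1.0, 1.0, 8, dtype=jnp.float32)
rms = jnp.sqrt(jnp.mean(x * x) + 1e-6)
full_precision = x / rms                                       # everything kept in float32
via_bf16 = (x / rms).astype(jnp.bfloat16).astype(jnp.float32)  # intermediate cast, as in the unfused path
print(jnp.max(jnp.abs(full_precision - via_bf16)))             # nonzero: the round trip is lossy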
......@@ -406,8 +476,20 @@ class TestNorm:
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_norm_forward_with_delayed_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_norm_forward_with_tensor_scaling_fp8(
self,
n,
hidden,
norm_type,
zero_centered_gamma,
epsilon,
inp_dtype,
out_dtype,
q_layout,
scaling_mode,
):
if norm_type == "rmsnorm" and zero_centered_gamma is True:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
......@@ -420,7 +502,7 @@ class TestNorm:
epsilon=epsilon,
inp_dtype=inp_dtype,
out_dtype=out_dtype,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
q_layout=q_layout,
)
......@@ -447,17 +529,24 @@ QUANTIZE_OUTPUT_DTYPES = {
"L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
ALL_QUANTIZE_TEST_SHAPES = [
(32, 64),
(2, 64, 32),
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
((32, 64), -1),
((2, 64, 32), -1),
((2, 64, 32), -2),
((32, 256, 128), -1),
((32, 256, 128), -2),
((64, 32, 32, 256), -1),
((64, 32, 32, 256), -2),
((64, 32, 32, 256), -3),
]
QUANTIZE_TEST_SHAPES = {
QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
"L0": [
(32, 256, 128),
(64, 32, 32, 256),
((32, 64), -1),
((2, 64, 32), -1),
((2, 64, 32), -2),
],
"L2": ALL_QUANTIZE_TEST_SHAPES,
"L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}
QUANTIZATION_INPUT_DTYPE = {
......@@ -469,9 +558,8 @@ QUANTIZATION_INPUT_DTYPE = {
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
......@@ -524,12 +612,11 @@ class TestFusedQuantize:
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
def test_quantize_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
):
......@@ -538,6 +625,12 @@ class TestFusedQuantize:
):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
if (flatten_axis < 0 and flatten_axis + len(input_shape) <= 0) or flatten_axis == 0:
pytest.skip(
f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There"
" must be at least one axis on either side of the flatten_axis split."
)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
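A hedged helper (illustrative only, not part of the test suite) that captures the skip condition above: a flatten_axis is usable only if it leaves at least one dimension on each side of the split.

def is_valid_flatten_axis(shape, flatten_axis):
    axis = flatten_axis % len(shape)  # normalize a negative axis
    return 0 < axis < len(shape)      # at least one dim on each side of the split

assert is_valid_flatten_axis((2, 64, 32), -2)   # splits into (2,) and (64, 32)
assert not is_valid_flatten_axis((32, 64), -2)  # nothing would be left of the split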
......@@ -630,16 +723,19 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_quantize_dact_dbias_delayed_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_quantize_dact_dbias_tensor_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
):
self._test_quantize_dact_dbias(
in_dtype=in_dtype,
input_shape=input_shape,
out_dtype=out_dtype,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
activation_type=activation_type,
is_dbias=is_dbias,
q_layout=q_layout,
......@@ -830,7 +926,10 @@ class TestFusedDense:
Test layernorm_dense VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
......@@ -916,7 +1015,10 @@ class TestFusedDense:
Test layernorm_mlp VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
......@@ -1052,7 +1154,7 @@ fwd_bwd_dtypes = [
[jnp.float8_e5m2, jnp.float8_e4m3fn],
]
"""
@pytest_parametrize_wrapper(
"shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
)
......@@ -1267,3 +1369,4 @@ class TestGroupedDense:
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
"""
......@@ -34,6 +34,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
......@@ -76,6 +77,8 @@ class TestDistributedLayernorm:
other_bytes = 0
if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
other_bytes = 384 # required for small scale shapes that require padding
if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
)
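A hedged arithmetic sketch (illustrative numbers, not the test's actual accounting) of the adjustment above: with Float8CurrentScaling the expected all-reduce volume grows by exactly one element of the activation dtype for the amax reduction.

import numpy as np
import jax.numpy as jnp

def expected_allreduce_bytes(base_bytes, jax_dtype, current_scaling):
    # current scaling all-reduces one extra amax value in the activation dtype
    return base_bytes + (np.dtype(jax_dtype).itemsize if current_scaling else 0)

print(expected_allreduce_bytes(8, jnp.bfloat16, current_scaling=True))   # 10
print(expected_allreduce_bytes(8, jnp.bfloat16, current_scaling=False))  # 8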
......
......@@ -41,6 +41,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
......@@ -217,37 +218,10 @@ class TestDistributedLayernormMLP:
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
)
else:
is_gated = len(activation_type) > 1
rtol = None
atol = None
if is_gated:
if dtype == jnp.bfloat16:
if i == 2:
rtol = 800
atol = 9e-2
if i == 4:
atol = 300
rtol = 1e-1
if dtype == jnp.float16:
if i == 1: # gamma
rtol = 200
atol = 1e-2
if i == 2:
rtol = 2000
atol = 7e-2
if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling(): # bias_1
# Accumulating dbias across a large tensor introduces a larger difference
rtol = 200
atol = 4e-2
if i == 4 and fp8_recipe == recipe.DelayedScaling():
rtol = 2200
atol = 9e-2
assert_allclose(
multi_grads[i],
single_grads[i],
dtype=dtype,
rtol=rtol,
atol=atol,
err_msg=f"multi_grads[{i}] is not close",
)
......
......@@ -10,47 +10,22 @@ import jax.numpy as jnp
import numpy as np
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
from transformer_engine.jax.quantize import (
QuantizeConfig,
is_fp8_available,
ScalingMode,
update_collections,
)
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
class TestQuantizeConfig(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_initialize(self):
margin = 5.0
fp8_format = FP8Format.E4M3
amax_history_len = 10
QuantizeConfig.initialize(
margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
)
self.assertEqual(
QuantizeConfig.MARGIN,
margin,
f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
f" but got {QuantizeConfig.MARGIN}.",
)
self.assertEqual(
QuantizeConfig.FP8_FORMAT,
fp8_format,
f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {QuantizeConfig.FP8_FORMAT}.",
)
self.assertEqual(
QuantizeConfig.AMAX_HISTORY_LEN,
amax_history_len,
f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
)
QuantizeConfig.finalize()
class TestHelper(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_update_collections(self):
......@@ -61,19 +36,19 @@ class TestQuantizeConfig(unittest.TestCase):
"test1": original_val,
"test2": original_val,
}
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
updated_state = update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val)
original_state = flax.core.frozen_dict.FrozenDict(original_state)
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
updated_state = update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val)
class TestFP8Functions(unittest.TestCase):
def _check_defult_state(self):
def _check_default_state(self):
self.assertFalse(QuantizeConfig.is_fp8_enabled())
def _compare_delay_scaling(self, ref, test):
......@@ -82,35 +57,92 @@ class TestFP8Functions(unittest.TestCase):
self.assertTrue(ref.amax_history_len == test.amax_history_len)
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_current_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
def _compare_mxfp8_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast(self):
def test_fp8_autocast_delayed_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state()
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
self.assertFalse(QuantizeConfig.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())
self._check_default_state()
self._check_defult_state()
self._check_default_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_defult_state()
self._check_default_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_defult_state()
self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_current_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
self._check_default_state()
self._check_default_state()
cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
self._check_default_state()
self._check_default_state()
bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state()
self._check_default_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
......@@ -130,4 +162,4 @@ class TestFP8Functions(unittest.TestCase):
self._compare_delay_scaling(get_delayed_scaling(), ds)
self.assertEqual(sr, global_mesh_resource())
self._check_defult_state()
self._check_default_state()
......@@ -13,7 +13,6 @@ import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import combine_masks
from jax import lax, vmap
from jax import nn as jax_nn
......@@ -97,16 +96,16 @@ def combine_biases(*masks: Optional[Array]):
return mask
def parameterize_by_test_level(param_dict: dict, id_prefix: str = ""):
def get_parameters_for_test_level(param_dict: dict):
"""
Takes an input dictionary of parameters keyed by test type "L0", etc.
Returns a list of pytest parameters to be used in a parameterized test for the current test type
Returns the parameters for the test level specified in the environment variable
"""
DEFAULT_TEST_LEVEL = "L0"
test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)
if test_level not in param_dict:
raise ValueError("Unsupported test level")
return values_to_named_params(param_dict[test_level], id_prefix)
return param_dict[test_level]
def value_to_test_name_str(value):
......@@ -139,14 +138,18 @@ def pytest_parametrize_wrapper(param_name, param_values):
A wrapper for pytest.mark.parametrize to allow for automatic
naming of tests based on the parameter values.
"""
id_prefix = param_name
if isinstance(param_values, dict):
param_values = parameterize_by_test_level(param_values, id_prefix=param_name)
elif "," not in param_name:
param_values = values_to_named_params(param_values, id_prefix=id_prefix)
# If the values are split into a dictionary of test-levels, e.g. "L0", etc.,
# unwrap the selected level before proceeding.
param_values = get_parameters_for_test_level(param_values)
if "," not in param_name:
# Multi-parameterize annotations are not supported in this wrapper
# and are just a passthrough to default pytest.mark.parametrize.
# E.g. @pytest_parametrize_wrapper("a,b", ((a_value1, b_value1), (a_value2, b_value2)))
# will be passed through to pytest.mark.parametrize as-is without pytest.param ids.
param_values = values_to_named_params(param_values, id_prefix=param_name)
# Currently comma separated parameters in one parametrize call aren't supported for automatic naming
# and will just be passed through with default pytest names
def decorator(func):
return pytest.mark.parametrize(param_name, param_values)(func)
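A hedged usage sketch of the wrapper above (assumes it is called from this utils module): a dict keyed by test level is first narrowed to the level selected by NVTE_JAX_UNITTEST_LEVEL (default "L0"), and a single-name parameter then gets auto-generated pytest ids.

SHAPES = {
    "L0": [(32, 64)],
    "L2": [(32, 64), (2, 64, 32), (32, 256, 128)],
}

@pytest_parametrize_wrapper("shape", SHAPES)
def test_shape_rank(shape):
    assert len(shape) in (2, 3)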
......@@ -312,16 +315,22 @@ class DenseGeneral(nn.Module):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
kernel = self.param(
"kernel",
nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
kernel_param_shape,
self.dtype,
)
kernel = jnp.asarray(kernel, input_dtype)
kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
bias = self.param(
"bias",
nn.with_logical_partitioning(self.bias_init, self.bias_axes),
self.features,
self.dtype,
)
bias = bias.astype(input_dtype)
else:
......@@ -418,9 +427,9 @@ class MlpBlock(nn.Module):
) # Broadcast along length.
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
x = nn.with_logical_constraint(x, ("length", "batch", "mlp"))
else:
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
x = nn.with_logical_constraint(x, ("batch", "length", "mlp"))
output = DenseGeneral(
inputs.shape[-1],
dtype=self.dtype,
......@@ -684,21 +693,13 @@ class MultiHeadAttention(nn.Module):
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
if self.transpose_batch_sequence:
query = nn_partitioning.with_sharding_constraint(
query, ("length", "batch", "heads", "kv")
)
key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
value = nn_partitioning.with_sharding_constraint(
value, ("length", "batch", "heads", "kv")
)
query = nn.with_logical_constraint(query, ("length", "batch", "heads", "kv"))
key = nn.with_logical_constraint(key, ("length", "batch", "heads", "kv"))
value = nn.with_logical_constraint(value, ("length", "batch", "heads", "kv"))
else:
query = nn_partitioning.with_sharding_constraint(
query, ("batch", "length", "heads", "kv")
)
key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
value = nn_partitioning.with_sharding_constraint(
value, ("batch", "length", "heads", "kv")
)
query = nn.with_logical_constraint(query, ("batch", "length", "heads", "kv"))
key = nn.with_logical_constraint(key, ("batch", "length", "heads", "kv"))
value = nn.with_logical_constraint(value, ("batch", "length", "heads", "kv"))
if decode:
# Detect if we're initializing by absence of existing cache data.
......@@ -805,9 +806,9 @@ class MultiHeadAttention(nn.Module):
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
x = nn.with_logical_constraint(x, ("length", "batch", "joined_kv"))
else:
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
x = nn.with_logical_constraint(x, ("batch", "length", "joined_kv"))
# Back to the original inputs dimensions.
......@@ -853,8 +854,11 @@ class LayerNorm(nn.Module):
input_dtype = x.dtype
features = x.shape[-1]
scale = nn_partitioning.param_with_axes(
"scale", self.scale_init, (features,), self.dtype, axes=("embed",)
scale = self.param(
"scale",
nn.with_logical_partitioning(self.scale_init, ("embed",)),
(features,),
self.dtype,
)
x_ = x.astype(jnp.float32)
if self.layernorm_type == "layernorm":
......@@ -862,8 +866,11 @@ class LayerNorm(nn.Module):
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes(
"ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
bias = self.param(
"ln_bias",
nn.with_logical_partitioning(self.bias_init, ("embed",)),
(features,),
self.dtype,
)
bias = jnp.asarray(bias, input_dtype)
......@@ -972,12 +979,11 @@ class RelativePositionBiases(nn.Module):
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
relative_attention_bias = nn_partitioning.param_with_axes(
relative_attention_bias = self.param(
"rel_embedding",
self.embedding_init,
nn.with_logical_partitioning(self.embedding_init, ("heads", "relpos_buckets")),
(self.num_heads, self.num_buckets),
jnp.float32,
axes=("heads", "relpos_buckets"),
)
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
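For reference, a hedged minimal Flax sketch (a standalone toy module, not TE code) of the pattern these hunks migrate to: self.param combined with nn.with_logical_partitioning replaces flax.linen.partitioning.param_with_axes while attaching the same logical axis names to the parameter.

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyDense(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
            (x.shape[-1], self.features),
            jnp.float32,
        )
        return x @ kernel

params = TinyDense().init(jax.random.PRNGKey(0), jnp.ones((2, 4)))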
......@@ -1555,14 +1561,16 @@ def sync_params_values(dst, src, transformations, sep="/"):
"""
src_values = {}
for key, value in jax.tree_util.tree_leaves_with_path(src):
normalized_key = sep.join(x.key for x in key)
# Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
src_values[normalized_key] = value
flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
synced_dst_values = []
for key, value in flatten_dst:
normalized_key = sep.join(x.key for x in key)
# Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
if normalized_key in transformations:
corresponding_src_key = transformations[normalized_key]
else:
......
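A hedged standalone illustration of the key filtering in sync_params_values above: JAX tree paths mix DictKey entries (which expose a .key field) with other key types such as SequenceKey (which exposes .idx), so joining only the entries that have .key yields clean "a/b"-style parameter names.

import jax

tree = {"params": {"kernel": [1.0, 2.0]}}
for path, leaf in jax.tree_util.tree_leaves_with_path(tree):
    name = "/".join(p.key for p in path if hasattr(p, "key"))
    print(name, leaf)  # prints "params/kernel 1.0" and "params/kernel 2.0"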
......@@ -16,6 +16,7 @@ import torch.distributed as dist
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
Format,
Recipe,
)
......@@ -26,6 +27,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
def _get_raw_data(quantized_tensor):
......@@ -34,6 +36,14 @@ def _get_raw_data(quantized_tensor):
assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
return quantized_tensor._data
elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
assert hasattr(
quantized_tensor, "_rowwise_data"
), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
assert (
quantized_tensor._rowwise_data.dtype == torch.uint8
), "Float8BlockwiseQTensor _rowwise_data must be uint8"
return quantized_tensor._rowwise_data
else:
raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
......@@ -435,15 +445,15 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
......@@ -539,12 +549,13 @@ def _test_zero_1(dp_group):
def quantization_recipe(quantization) -> Recipe:
"""Quantization recipe setup"""
fp8_format = Format.HYBRID
if quantization == "fp8":
return DelayedScaling(
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)
return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
elif quantization == "fp8_cs":
return Float8CurrentScaling()
return Float8CurrentScaling(fp8_format=fp8_format)
elif quantization == "fp8_block":
return Float8BlockScaling(fp8_format=fp8_format)
else:
raise ValueError(f"Unsupported quantization: {quantization}")
......@@ -568,15 +579,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
......@@ -593,7 +604,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)
for _ in range(100):
for i in range(100):
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
w_fp8.main_grad.zero_()
w.main_grad.zero_()
......@@ -654,7 +665,9 @@ def main(argv=None, namespace=None):
dist.init_process_group(**dist_init_kwargs)
parser = argparse.ArgumentParser()
parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"])
parser.add_argument(
"--quantization", type=str, default=None, choices=["fp8", "fp8_cs", "fp8_block"]
)
args = parser.parse_args(argv, namespace)
dp_group = dist.new_group(backend="nccl")
......
......@@ -21,7 +21,11 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.module.base import (
fill_userbuffers_buffer_for_all_gather,
get_cublas_workspace_size_bytes,
)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
......@@ -57,7 +61,11 @@ def _parse_args(argv=None, namespace=None):
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
"--quantization",
type=str.lower,
default="none",
choices=["none", "fp8", "mxfp8"],
help="Quantization recipe",
)
parser.add_argument(
"--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM."
......@@ -155,9 +163,9 @@ def _parse_args(argv=None, namespace=None):
if opts.atomic:
warnings.warn("Atomic GEMM is not supported with bulk overlap.")
opts.atomic = False
if opts.fp8:
if opts.quantization != "none":
warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
opts.fp8 = False
opts.quantization = "none"
elif opts.comm_type == tex.CommOverlapType.AG:
if opts.atomic:
setattr(opts, "atomic_rs_p2p", opts.p2p)
......@@ -165,8 +173,11 @@ def _parse_args(argv=None, namespace=None):
if opts.atomic:
if not te.fp8.check_fp8_support():
assert not opts.fp8, "Atomic GEMM is only supported in FP8."
opts.fp8 = True
assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
opts.quantization = "fp8"
if opts.fp8_output:
assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute."
return opts
......@@ -303,7 +314,11 @@ def _main(opts):
inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
outer_size = reduce(operator.mul, inp_shape[:-1], 1)
buffer_dtype = torch.bfloat16
if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG:
if (
opts.quantization != "none"
and not opts.bulk_overlap
and opts.comm_type == tex.CommOverlapType.AG
):
buffer_dtype = torch.uint8
ub_obj = (
tex.CommOverlapP2P(
......@@ -450,6 +465,8 @@ def _main(opts):
inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable
ref2_g = torch.matmul(inp2_g, ker2_g)
# Initialize quantizers
with_quantized_compute = opts.quantization != "none"
inp_quantizer = None
ker_quantizer = None
out_quantizer = None
......@@ -457,7 +474,7 @@ def _main(opts):
inp2_quantizer = None
ker2_quantizer = None
out2_quantizer = None
if opts.fp8:
if opts.quantization == "fp8":
# Structure to maintain amax and scale/scale_inv information for the kernel and input
num_gemms = 6 if ub_obj2 is not None else 3
fp8_dtype = tex.DType.kFloat8E4M3
......@@ -502,11 +519,23 @@ def _main(opts):
out2_quantizer = Float8Quantizer(
fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
)
elif opts.quantization == "mxfp8":
fp8_dtype = tex.DType.kFloat8E4M3
inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
ker_quantizer = MXFP8Quantizer(fp8_dtype)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
elif ub_obj2 is not None:
inp2_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
ker2_quantizer = MXFP8Quantizer(fp8_dtype)
# Cast input to Float8Tensor
# Quantize tensors
if with_quantized_compute:
# Quantize input tensor
inp_fp8 = inp_quantizer(inp)
# Cast kernel to Float8Tensor
# Quantize kernel tensor
kernel_t_fp8 = ker_quantizer(kernel_t)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp)
......@@ -543,31 +572,40 @@ def _main(opts):
)
# Set up comm/compute buffers
ag_out = None
rs_out = None
rs_out2 = None
if opts.comm_type == tex.CommOverlapType.AG:
if opts.bulk_overlap:
ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True)
ag_out, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
bulk_inp,
bulk_inp_quantizer,
tp_group,
)
gemm_inp = inp
else:
ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True)
gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size())
ag_out, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
inp_fp8 if with_quantized_compute else inp,
inp_quantizer,
tp_group,
)
gemm_inp = ag_out
if ub_obj2 is not None:
if opts.fp8 and opts.fp8_output:
ub_obj2.set_buffer_params(out_quantizer)
rs_out2 = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
else:
if opts.bulk_overlap:
ub_obj.copy_into_buffer(
bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False
)
if opts.fp8:
ub_obj.set_buffer_params(bulk_inp_quantizer)
elif opts.fp8 and opts.fp8_output:
ub_obj.set_buffer_params(out_quantizer)
gemm_inp = inp_fp8 if opts.fp8 else inp
if opts.quantization == "none":
ub_obj.copy_into_buffer(bulk_inp, local_chunk=False)
if opts.quantization == "fp8":
ub_obj.copy_into_buffer(bulk_inp_fp8._data, local_chunk=False)
elif opts.quantization == "mxfp8":
ub_obj.copy_into_buffer(bulk_inp_fp8._rowwise_data, local_chunk=False)
gemm_inp = inp_fp8 if with_quantized_compute else inp
rs_out = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
......@@ -626,7 +664,7 @@ def _main(opts):
if opts.use_cuda_graphs:
# Trace the CUDA graph first
g = torch.cuda.CUDAGraph()
if opts.fp8:
if with_quantized_compute:
if ub_obj is None:
with torch.cuda.graph(g):
all_outputs = _fp8_gemm()
......@@ -646,7 +684,7 @@ def _main(opts):
else:
for i in range(total_iters):
if opts.fp8:
if with_quantized_compute:
start_events[i].record()
all_outputs = _fp8_gemm()
end_events[i].record()
......@@ -691,10 +729,22 @@ def _main(opts):
output_info = ""
if opts.comm_type == tex.CommOverlapType.AG:
# Bulk overlap AG output is already gathered
test_out = ub_obj.get_buffer(bulk_inp_quantizer, False)
test_out = ag_out
if bulk_inp_quantizer is None:
test_out = ub_obj.get_buffer(False)
else:
test_out = Float8Tensor(
shape=test_out.shape,
dtype=torch.bfloat16,
data=ub_obj.get_buffer(False),
fp8_scale=bulk_inp_quantizer.scale,
fp8_dtype=bulk_inp_quantizer.dtype,
quantizer=bulk_inp_quantizer,
)
else:
# Bulk overlap RS output needs to be gathered
out_local = ub_obj.get_buffer(bulk_inp_quantizer, True)
out_local = ub_obj.get_buffer(True)
output_info += f"rs_output: {list(out_local.shape)} | "
test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]
......@@ -765,8 +815,8 @@ def _main(opts):
m = torch.argmax(diff)
abs_err = diff[m].item()
rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
rtol = 0.125 if opts.fp8 else 0.02
atol = 0.0625 if opts.fp8 else 0.001
rtol = 0.02 if opts.quantization == "none" else 0.125
atol = 0.001 if opts.quantization == "none" else 0.0625
if rel_err > rtol and abs_err > atol:
numerics_failed = True
numerics_info = (
......
......@@ -17,7 +17,12 @@ import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Format,
MXFP8BlockScaling,
)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
......@@ -163,7 +168,7 @@ def _parse_args(argv=None, namespace=None):
"--quantization",
type=str.lower,
default="none",
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"],
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
help="Quantization recipe",
)
parser.add_argument(
......@@ -414,6 +419,8 @@ def _train(opts):
)
elif opts.quantization == "fp8_current_scaling":
fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
elif opts.quantization == "mxfp8":
fp8_recipe = MXFP8BlockScaling()
# Prepare random input tensors
test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
......
......@@ -174,7 +174,7 @@ def _get_tolerances(dtype):
if dtype == torch.bfloat16:
return {"rtol": 1.6e-2, "atol": 1e-5}
if dtype == torch.float32:
return {"rtol": 1.3e-6, "atol": 1e-5}
return {"rtol": 1.3e-6, "atol": 4e-5}
raise ValueError(f"Unsupported dtype ({dtype})")
......
......@@ -15,6 +15,9 @@ if torch.cuda.device_count() < 2:
pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(2, torch.cuda.device_count())
......@@ -28,8 +31,10 @@ def _run_test(quantization):
assert result.returncode == 0
@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"])
@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", "fp8_block"])
def test_cast_master_weights_to_fp8(quantization):
if not fp8_available:
if quantization in ("fp8", "fp8_cs") and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "fp8_block" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
_run_test(quantization)
......@@ -21,6 +21,7 @@ if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
RNG_SEED: int = 42
SEQ_LENGTH: int = 1024
......@@ -56,7 +57,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
torch._dynamo.reset()
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
......@@ -72,10 +73,11 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
if bulk:
test_cmd.append("--bulk-overlap")
else:
if fp8:
if not fp8_available:
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
test_cmd.append(f"--quantization={quantization}")
if p2p:
test_cmd.append("--p2p")
if atomic:
......@@ -114,8 +116,10 @@ def _run_layer_with_overlap(
test_cmd.append("--overlap-rs-dgrad")
if fp8:
if not fp8_available:
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
test_cmd.append("--fp8")
test_cmd.append(f"--quantization={quantization}")
......@@ -137,51 +141,34 @@ def _run_layer_with_overlap(
raise AssertionError(result.stderr.decode())
@pytest.mark.parametrize(
"fp8",
(False, True),
ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "],
)
def test_split_all_gather_overlaps(fp8):
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
def test_split_all_gather_overlaps(quantization):
"""
Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("AG", False, True, False, fp8)
_run_gemm_with_overlap("AG", False, True, False, quantization)
@pytest.mark.parametrize(
"fp8,p2p",
[
(False, False),
(False, True),
(True, False),
(True, True),
],
ids=[
" BF16 - PIPELINE ",
" BF16 - RING-EXCHANGE ",
" FP8 - PIPELINE ",
" FP8 - RING-EXCHANGE ",
],
)
def test_split_reduce_scatter_overlaps(fp8, p2p):
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
@pytest.mark.parametrize("p2p", (False, True))
def test_split_reduce_scatter_overlaps(quantization, p2p):
"""
Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("RS", False, p2p, False, fp8)
_run_gemm_with_overlap("RS", False, p2p, False, quantization)
@pytest.mark.parametrize(
"comm_type, fp8, connections",
"comm_type, quantization, connections",
[
("AG", False, 1),
("RS", False, 1),
("RS", True, 1),
("AG", False, 8),
("RS", False, 8),
("RS", True, 8),
("AG", "none", 1),
("RS", "none", 1),
("RS", "fp8", 1),
("AG", "none", 8),
("RS", "none", 8),
("RS", "fp8", 8),
],
ids=[
"ALL-GATHER - BF16 - 1 connections",
......@@ -192,7 +179,7 @@ def test_split_reduce_scatter_overlaps(fp8, p2p):
"REDUCE-SCATTER - FP8 - 8 connections",
],
)
def test_bulk_overlaps(comm_type, fp8, connections):
def test_bulk_overlaps(comm_type, quantization, connections):
"""
Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
"""
......@@ -203,10 +190,10 @@ def test_bulk_overlaps(comm_type, fp8, connections):
" 9.0 (HOPPER ARCH)."
)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
_run_gemm_with_overlap(comm_type, True, False, False, fp8)
_run_gemm_with_overlap(comm_type, True, False, False, quantization)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
else:
_run_gemm_with_overlap(comm_type, True, False, False, fp8)
_run_gemm_with_overlap(comm_type, True, False, False, quantization)
@pytest.mark.parametrize(
......@@ -258,15 +245,7 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
......@@ -286,15 +265,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
)
),
ids=[
f" {te.Linear.__name__} - ROW-PARALLEL ",
f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
f"{te.Linear.__name__}-row_tensor_parallel",
f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS",
f"{te.LayerNormLinear.__name__}-row_tensor_parallel",
f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS",
]
+ [
" " + " - ".join(test_name_parts) + " "
"-".join(test_name_parts)
for test_name_parts in zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
......@@ -302,12 +281,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
],
)
def test_layers_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization
layer_type,
linear_parallel_mode,
overlap_rs_dgrad,
quantization,
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization)
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization)
@pytest.mark.parametrize(
......@@ -354,22 +336,11 @@ def test_multi_layer_with_overlap_bf16(
@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
)
@pytest.mark.parametrize(
"num_layers",
(2,),
ids=[
" 2 layers ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
......@@ -381,7 +352,7 @@ def test_multi_layer_with_overlap_bf16(
)
),
ids=[
" " + " - ".join(test_name_parts) + " "
"-".join(test_name_parts)
for test_name_parts in zip(
[te.TransformerLayer.__name__ for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"],
......@@ -389,11 +360,11 @@ def test_multi_layer_with_overlap_bf16(
],
)
def test_multi_layer_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers
)
......@@ -19,7 +19,6 @@ import torch
import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
......@@ -27,6 +26,8 @@ from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
......@@ -36,6 +37,13 @@ from utils import dtype_tols, str_to_dtype
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
if mxfp8_available:
quantization_list.append("mxfp8")
# Check if there are multiple GPUs
if torch.cuda.device_count() < 2:
......@@ -51,7 +59,7 @@ class ModelConfig:
num_heads: int
head_dim: int
dtype: torch.dtype
fp8: bool
quantization: Optional[str]
@property
def hidden_size(self):
......@@ -129,11 +137,15 @@ def make_reference_and_test_tensors(
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor
if test_is_fp8:
test = Float8Tensor.to_float8(ref)
else:
test = ref.to(device=test_device, dtype=test_dtype)
if test.data_ptr() == ref.data_ptr():
if test_is_fp8:
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
# Make sure reference and test tensors represent exact same values
......@@ -145,6 +157,21 @@ def make_reference_and_test_tensors(
return ref, test
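A minimal sketch of the Float8Quantizer round-trip that make_reference_and_test_tensors relies on, assuming a CUDA device and an arbitrary shape; the quantizer construction mirrors the code above, and only the quantize/dequantize plumbing is the point:
# Sketch only: shape, scale value, and device are placeholder assumptions.
quantizer = Float8Quantizer(
    scale=torch.ones(1, dtype=torch.float32, device="cuda"),
    amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
hi_prec = torch.rand(16, 32, dtype=torch.bfloat16, device="cuda")
fp8_test = quantizer(hi_prec)       # quantized tensor holding FP8 data
roundtrip = fp8_test.dequantize()   # back to high precision (lossy)
assert isinstance(fp8_test, QuantizedTensor)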
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
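A hedged usage sketch of make_recipe, with placeholder layer sizes: the returned recipe is threaded through both te.fp8_model_init and te.fp8_autocast, as _test_linear does below. The BasicLinear argument order and device/dtype keywords are assumptions, not part of this diff:
recipe = make_recipe("fp8")  # or "mxfp8", or None for high precision
with te.fp8_model_init(enabled=recipe is not None, recipe=recipe):
    # Assumed signature: in_features, out_features, plus device/dtype kwargs.
    linear = te_ops.BasicLinear(64, 64, device="cuda", dtype=torch.bfloat16)
x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with te.fp8_autocast(enabled=recipe is not None, fp8_recipe=recipe):
    y = linear(x)
y.backward(torch.ones_like(y))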
def _test_linear(
*,
model_config: ModelConfig,
......@@ -155,7 +182,8 @@ def _test_linear(
weight_requires_grad: bool = True,
) -> None:
dtype = model_config.dtype
fp8_compute = model_config.fp8
quantization = model_config.quantization
quantized_compute = quantization is not None
# Distributed process group
process_group = world_group()
......@@ -175,14 +203,19 @@ def _test_linear(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
test_is_fp8=quantized_compute,
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
......@@ -198,9 +231,11 @@ def _test_linear(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -265,21 +300,15 @@ def _test_linear(
x_test.requires_grad_()
# Implementation with fusible operation
with te.fp8_model_init(enabled=fp8_compute):
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
ops = []
linear_op = None
bias_op = None
if tensor_parallel_mode == "column":
userbuffers_options = {}
if not weight_requires_grad:
if fp8_compute:
userbuffers_options["comm_name"] = "fc1"
else:
# There is a correctness bug with overlapping
# dgrad reduce-scatter with dgrad GEMM. Fall back
# to overlapping dgrad reduce-scatter with wgrad
# GEMM, even though wgrad isn't needed.
userbuffers_options["comm_name"] = "qkv"
else:
userbuffers_options["comm_name"] = "qkv"
linear_op = te_ops.BasicLinear(
......@@ -322,7 +351,7 @@ def _test_linear(
bias_op.bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=fp8_compute):
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
......@@ -338,7 +367,7 @@ def _test_linear(
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
if quantized_compute:
tols = dtype_tols(
model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight)
......@@ -370,7 +399,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None:
for test_config in itertools.product(
(False, True), # bias
("column", "row"), # tensor_parallel_mode
(False, True), # weight_requires_grad
(True, False), # weight_requires_grad
):
if rank == 0:
print(f"Running _test_linear with {test_config=}")
......@@ -390,19 +419,15 @@ if torch.cuda.device_count() > 1:
@pytest.mark.parametrize("world_size", _world_sizes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("quantization", quantization_list)
def test_fuser_ops_with_userbuffers(
*,
world_size: int,
dtype: torch.dtype = torch.bfloat16,
fp8: bool,
quantization: Optional[str],
) -> None:
"""Launch parallel job and run tests"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Parallel job launcher
command = []
if tex.ubuf_built_with_mpi():
......@@ -424,8 +449,8 @@ def test_fuser_ops_with_userbuffers(
str(dtype),
)
)
if fp8:
command.append("--fp8")
if quantization is not None:
command.extend(("--quantization", quantization))
# Environment
env = dict(os.environ)
......@@ -445,12 +470,12 @@ def main() -> None:
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
parser.add_argument("--sequence-length", type=int, default=32)
parser.add_argument("--sequence-length", type=int, default=256)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--num-heads", type=int, default=16)
parser.add_argument("--head-dim", type=int, default=32)
parser.add_argument("--head-dim", type=int, default=256)
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--fp8", action="store_true")
parser.add_argument("--quantization", type=str, default=None)
args = parser.parse_args()
# Run parallel tests if needed
......@@ -463,14 +488,17 @@ def main() -> None:
num_heads=args.num_heads,
head_dim=args.head_dim,
dtype=str_to_dtype(args.dtype),
fp8=args.fp8,
quantization=args.quantization,
)
# Initialize Userbuffers
group = world_group() # Initialize NCCL
bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
userbuffer_configs = {
"fc1_dgrad": {"method": "pipeline"}, # Overlap dgrad RS with dgrad GEMM
"fc1_dgrad": {
"method": "ring_exchange",
"fp8_buf": False,
}, # Overlap dgrad RS with dgrad GEMM
}
te.module.base.initialize_ub(
[
......@@ -478,7 +506,7 @@ def main() -> None:
model_config.num_heads * model_config.head_dim,
],
torch.distributed.get_world_size(group),
use_fp8=model_config.fp8,
use_fp8=model_config.quantization is not None,
dtype=model_config.dtype,
bootstrap_backend=bootstrap_backend,
ub_cfgs=userbuffer_configs,
......
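For reference, the two fc1_dgrad overlap settings this hunk toggles between, written out side by side (both values appear in the file above; this is only a restatement, since the rendered diff does not mark which line is old and which is new):
# Previous configuration: pipelined overlap of the dgrad reduce-scatter.
old_userbuffer_configs = {"fc1_dgrad": {"method": "pipeline"}}
# New configuration: ring-exchange overlap with the FP8 communication buffer disabled.
new_userbuffer_configs = {"fc1_dgrad": {"method": "ring_exchange", "fp8_buf": False}}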
......@@ -2,12 +2,16 @@
#
# See LICENSE for license information.
import os, sys, logging
import os
import sys
import logging
from contextlib import nullcontext
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_cu_seqlens_on_cp_rank,
)
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import functools
import logging
import math
import os
from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union, Optional
from contextlib import contextmanager
......@@ -16,26 +13,22 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention import (
from transformer_engine.pytorch.attention.dot_product_attention import (
DotProductAttention,
MultiheadAttention,
_attention_backends,
)
from transformer_engine.pytorch.dot_product_attention.utils import (
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils,
get_attention_backend,
check_set_window_size,
AttentionParams,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
QKVLayout,
fused_attn_bwd,
fused_attn_fwd,
)
......@@ -50,9 +43,7 @@ from transformer_engine.pytorch.utils import (
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine_torch import NVTE_Fused_Attn_Backend
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
......