Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
......@@ -79,6 +79,8 @@ using fp8e8m0 = uint8_t;
using int8 = int8_t;
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif
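// Note: __nv_fp4x2_e2m1 packs two e2m1 values into one byte and __nv_fp4x4_e2m1
// packs four into two bytes (from CUDA's fp4 header); the NVFP4 paths below rely
// on this 4-bit storage granularity.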
template <typename T>
......@@ -240,7 +242,9 @@ class Tensor {
float scale() const {
if(scale_cpu_data_) {
NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!");
NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
|| (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING),
"Invalid scaling_mode!");
to_cpu();
return *scale_cpu_data_;
} else {
......@@ -254,6 +258,8 @@ class Tensor {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
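// NVFP4 stores its per-block scale factors as FP8 (E4M3), unlike the E8M0
// scales used by MXFP8, hence the kFloat8E4M3 check here.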
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
......@@ -267,6 +273,8 @@ class Tensor {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
......@@ -321,10 +329,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127;
constexpr uint32_t FP32_MANTISSA_BITS = 23;
// [128,4] rowwise and [4,128] colwise alignment requirement
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
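// Integer ceiling division, e.g. divide_round_up(10, 4) == 3.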
inline size_t divide_round_up(const size_t N, const size_t M) {
return (N - 1 + M) / M;
......@@ -473,13 +481,15 @@ void compareResults(const std::string &name, const float test, const float ref,
double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
size_t N, float mismatch_rate_tol = 0.);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
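// Generalized from compare_e8m0_scaling_factors: templating on T lets the same
// comparison handle both uint8 (E8M0) and fp8e4m3 (NVFP4) scale factors.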
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
size_t& mismatches_num,
const size_t scale_diff_abs_tolerance = 0,
const double abs_tolerable_mismatches_limit = 0,
const double rel_tolerable_mismatches_limit = 0);
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
const size_t block_size_rows, const size_t block_size_cols);
......@@ -501,6 +511,7 @@ const std::string& caseName(InputsFillCase type);
extern std::vector<DType> all_fp_types;
bool isFp8Type(DType type);
bool isFp4Type(DType type);
int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
......@@ -578,7 +589,7 @@ constexpr int32_t blackwellComputeCapability = 100;
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
printf("dtype: %d\n", static_cast<int>(dtype)); \
NVTE_ERROR("Invalid type MARKED TEST."); \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
......@@ -597,7 +608,7 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
NVTE_ERROR("Invalid type MARKED TEST 2."); \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
......@@ -605,7 +616,7 @@ constexpr int32_t blackwellComputeCapability = 100;
using namespace transformer_engine; \
SWITCH_FP4_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type MARKED TEST 3."); \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
......@@ -630,5 +641,5 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
NVTE_ERROR("Invalid type MARKED TEST 4."); \
NVTE_ERROR("Invalid type."); \
}
......@@ -69,6 +69,34 @@ bool IsMulticastSupported(int device_id) {
return supported;
}
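// Encodes the SM version as major * 10 + minor, e.g. 90 for Hopper (SM 9.0)
// and 89 for Ada (SM 8.9).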
int GetDeviceComputeCapability(int device_id) {
int major{};
int minor{};
CHECK_CU(cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device_id));
CHECK_CU(cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device_id));
return major * 10 + minor;
}
template <typename T>
bool IsDTypeSupported(int /* device_id */) {
return true;
}
template <>
bool IsDTypeSupported<test::fp8e5m2>(int device_id) {
return GetDeviceComputeCapability(device_id) >= 89;
}
template <>
bool IsDTypeSupported<test::fp8e4m3>(int device_id) {
return GetDeviceComputeCapability(device_id) >= 89;
}
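// C++17 fold expression over the parameter pack: true only if every listed
// dtype is supported on the device.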
template <typename... Ts>
bool AllDTypesSupported(int device_id) {
return (IsDTypeSupported<Ts>(device_id) && ...);
}
template <typename T>
std::vector<T> CopyMatrix(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize,
size_t nsize, size_t ld) {
......@@ -161,6 +189,9 @@ class CommGemmFixure : public ::testing::TestWithParam<Params> {
template <typename AType, typename BType, typename DType, typename BiasType>
void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) {
if (!AllDTypesSupported<AType, BType, DType, BiasType>(rank_))
GTEST_SKIP() << "FP8 is not supported on device " << rank_;
cudaStream_t stream{};
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
......
......@@ -17,14 +17,6 @@ from utils import assert_allclose, is_devices_enough
def generate_configs():
configs = []
if is_devices_enough(2):
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2")
)
if is_devices_enough(4):
configs.append(
pytest.param(
......@@ -32,10 +24,17 @@ def generate_configs():
(2, 2),
("dp", "tpsp"),
MeshResource(dp_resource="dp", tpsp_resource="tpsp"),
id=f"n4_dp2_tp2",
id="n4_dp2_tp2",
)
)
if is_devices_enough(2):
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2"),
)
return configs
......
......@@ -40,11 +40,13 @@ from transformer_engine.jax.quantize import (
QuantizerFactory,
QuantizeLayout,
noop_quantizer_set,
should_use_rht,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.common import recipe
GEMM_CASES = [
(256, 256, 512),
......@@ -56,16 +58,23 @@ GEMM_CASES = [
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
supported_scaling_modes = []
# TODO(Phuong): remove unnecessary pytest skips
is_fp8_supported, fp8_unsupported_reason = helper.is_scaling_mode_supported(
ScalingMode.DELAYED_TENSOR_SCALING
)
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_scaling_mode_supported(
ScalingMode.MXFP8_1D_SCALING
)
is_fp4_supported, fp4_unsupported_reason = helper.is_scaling_mode_supported(
ScalingMode.NVFP4_1D_SCALING
)
""" Find supported scaling modes"""
if is_fp8_supported:
supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
if is_mxfp8_supported:
supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)
supported_scaling_modes = helper.get_supported_scaling_modes()
non_fp4_supported_scaling_modes = [s for s in supported_scaling_modes if not s.is_nvfp4_scaling]
supported_recipes = helper.get_supported_quantization_recipes()
supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes]
def is_shape_supported_by_mxfp8(input_shape):
......@@ -83,12 +92,13 @@ def assert_bitwise_scaled_tensors(
a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
if not precise_comparison:
if not precise_comparison and not a.scaling_mode.is_nvfp4_scaling:
assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
return
assert a.scaling_mode == b.scaling_mode
assert a.scale_inv.dtype == b.scale_inv.dtype
assert a.data_layout == b.data_layout
if a.scaling_mode.is_tensor_scaling():
# Assert in dq_dtype as some unfused codepaths have an intermediate cast
# to an input dtype which reduces precision compared to everything in fp32
......@@ -96,6 +106,16 @@ def assert_bitwise_scaled_tensors(
elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# Compare MXFP8 scales as uint8
assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
elif a.scaling_mode.is_nvfp4_scaling:
assert_allclose(a.amax, b.amax)
assert_allclose(a.scale_inv, b.scale_inv)
if not precise_comparison:
mismatch = a.data != b.data
mismatch_fraction = jnp.mean(mismatch.astype(jnp.float32))
assert (
mismatch_fraction < 0.05
), f"Mismatch fraction {mismatch_fraction} is too high"
return
else:
raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
assert_allclose(a.data, b.data)
......@@ -170,6 +190,7 @@ ALL_ACTIVATION_TYPES = [
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
]
ACTIVATION_TYPES = {
......@@ -182,17 +203,21 @@ ACTIVATION_TYPES = {
class TestActivation:
def ref_act(self, x, activation_type):
return _jax_act_lu(x, activation_type).data
def ref_act(self, x, activation_type, act_params):
return _jax_act_lu(x, activation_type, act_params=act_params).data
def value_n_grad_ref_func(self, x, activation_type):
def value_n_grad_ref_func(self, x, activation_type, act_params):
jitted_reference = jit(
value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
value_and_grad(
lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
)
)
return jitted_reference(x)
def primitive_func(self, inputs, activation_type, quantizer):
out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
def primitive_func(self, inputs, activation_type, quantizer, act_params):
out = activation(
inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
)
return jnp.mean(out)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
......@@ -209,12 +234,20 @@ class TestActivation:
x = jnp.repeat(x, len(activation_type), axis=-2)
value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
......@@ -234,7 +267,8 @@ class TestActivation:
self.activation_type = activation_type
value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
value_and_grad(self.primitive_func, (0,)),
static_argnums=(1, 3),
)
quantizer = QuantizerFactory.create(
......@@ -242,9 +276,21 @@ class TestActivation:
q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE,
)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(
x, activation_type, quantizer, act_params
)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=output_type)
assert_allclose(prim_grad, ref_grad, dtype=output_type)
......@@ -273,10 +319,18 @@ class TestActivation:
q_dtype=output_type,
q_layout=q_layout,
)
te_output = tex.act_lu(x, activation_type, te_quantizer)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)
assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
......@@ -296,10 +350,18 @@ class TestActivation:
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
)
output = tex.act_lu(x, activation_type, quantizer)
ref_out = self.ref_act(x, activation_type)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
output = tex.act_lu(x, activation_type, quantizer, act_params)
ref_out = self.ref_act(x, activation_type, act_params)
assert_dequantized_scaled_tensor(output, ref_out)
......@@ -561,10 +623,24 @@ class TestNorm:
)
QUANTIZE_OUTPUT_DTYPES = {
QUANTIZE_OUTPUT_FP8_DTYPES = {
"L0": [jnp.float8_e4m3fn],
"L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
QUANTIZE_OUTPUT_DTYPES = {
test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn]
for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}
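# Pair each q_dtype with the scaling mode at the same index and keep only the
# combinations that the scaling mode reports as compatible.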
QUANTIZE_QDTYPE_AND_SCALING_MODES = {
test_level: [
(q_dtype, scaling_mode)
for q_dtype, scaling_mode in zip(
QUANTIZE_OUTPUT_FP8_DTYPES[test_level], supported_scaling_modes
)
if q_dtype in scaling_mode.get_compatible_q_dtypes()
]
for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
((32, 64), -1),
......@@ -573,8 +649,7 @@ ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
((32, 256, 128), -1),
((32, 256, 128), -2),
((64, 32, 32, 256), -1),
((64, 32, 32, 256), -2),
((64, 32, 32, 256), -3),
((8192, 2, 4096), -2),
]
QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
......@@ -594,18 +669,38 @@ QUANTIZATION_INPUT_DTYPE = {
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float4_e2m1fn])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
"q_layout",
[
QuantizeLayout.ROWWISE,
QuantizeLayout.COLWISE,
QuantizeLayout.ROWWISE_COLWISE,
],
)
class TestQuantize:
"""
Purely quantization-related tests that will always test on a wider set of types and shapes
"""
def _skip_for_fp4(self, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
"""Temporary hack to skip unsupported FP4 cases until we implement them"""
if q_dtype not in scaling_mode.get_compatible_q_dtypes():
pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")
return
# HACK: FIXME TODO(jberchtold)
row = reduce(operator.mul, input_shape[flatten_axis:], 1)
col = reduce(operator.mul, input_shape[:flatten_axis], 1)
will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
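# RHT here refers to the random Hadamard transform applied before NVFP4
# quantization; the fused kernel path requires tile-aligned shapes, hence the
# row % 64 / col % 128 check below.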
if will_use_rht and (row % 64 != 0 or col % 128 != 0):
pytest.skip("Unfused RHT is not supported currently, skipping")
def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
key = jax.random.PRNGKey(0)
# Quantizer is created once as some quantization approaches use state from
# previous iterations (e.g. delayed scaling)
......@@ -615,6 +710,68 @@ class TestQuantize:
q_layout=q_layout,
)
if scaling_mode.is_nvfp4_scaling:
if in_dtype != jnp.bfloat16:
pytest.skip("NVFP4 scaling only supported with bfloat16 input dtype currently")
return
q_func = _jax_quantize
# For NVFP4 scaling, the maximum possible error for a single value can be high
# between the dequantized and original tensors. To ensure quantization and
# dequantization operate correctly without requiring a very high tolerance for
# all values, we instead test that quantizing the dequantized tensor is bitwise
# identical to the original quantized tensor.
x = jax.random.uniform(key, input_shape, in_dtype) * 10
q1 = q_func(x, quantizer=quantizer, flatten_axis=flatten_axis)
dq_rowwise = None
dq_colwise = None
if isinstance(q1, ScaledTensor1x):
dq = q1.dequantize()
if q1.is_colwise:
dq_colwise = dq
else:
dq_rowwise = dq
elif isinstance(q1, ScaledTensor2x):
dq_rowwise = q1.rowwise_tensor.dequantize()
dq_colwise = q1.colwise_tensor.dequantize()
else:
raise ValueError(f"Unsupported output type {type(q1)}")
# We only compare Q-DQ for the same quantization layout. If we for example QDQ
# rowwise, then re-quantize colwise, the error will be larger and may not be
# bitwise identical to the original colwise quantization.
if dq_rowwise is not None:
assert (
dq_rowwise.shape == x.shape
), f"dq_rowwise shape {dq_rowwise.shape} != x shape {x.shape}"
q2_rowwise = q_func(dq_rowwise, quantizer=quantizer, flatten_axis=flatten_axis)
q2_rowwise = (
q2_rowwise
if isinstance(q2_rowwise, ScaledTensor1x)
else q2_rowwise.rowwise_tensor
)
q1_rowwise = q1 if isinstance(q1, ScaledTensor1x) else q1.rowwise_tensor
assert_bitwise_scaled_tensors(q1_rowwise, q2_rowwise)
if dq_colwise is not None:
# Since this is for NVFP4, we assume colwise has the T layout and transpose
# here to get back to the original shape
flatten_axis = flatten_axis + len(input_shape) if flatten_axis < 0 else flatten_axis
colwise_flatten_axis = len(input_shape) - flatten_axis
dq_colwise = jnp.transpose(
dq_colwise,
(*range(colwise_flatten_axis, dq_colwise.ndim), *range(colwise_flatten_axis)),
)
assert (
dq_colwise.shape == x.shape
), f"dq_colwise shape {dq_colwise.shape} != x shape {x.shape}"
q2_colwise = q_func(dq_colwise, quantizer=quantizer, flatten_axis=flatten_axis)
q2_colwise = (
q2_colwise
if isinstance(q2_colwise, ScaledTensor1x)
else q2_colwise.colwise_tensor
)
q1_colwise = q1 if isinstance(q1, ScaledTensor1x) else q1.colwise_tensor
assert_bitwise_scaled_tensors(q1_colwise, q2_colwise)
assert (
dq_rowwise is not None or dq_colwise is not None
), "At least one of rowwise or colwise dq must be not None"
return
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
x = jax.random.uniform(key, input_shape, in_dtype)
......@@ -622,9 +779,33 @@ class TestQuantize:
scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
assert_dequantized_scaled_tensor(scaled_tensor, x)
def _should_use_precise_comparison(
self, in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
):
# TODO(jberchtold): Remove this hack once we have a better solution to ensure
# bitwise identical results between TE and JAX RHT+quant implementations.
# Currently for certain shapes the quantized fp4 data differs by a small
# amount on <0.5% of the values.
RHT_SLIGHT_MISMATCH_SHAPES = [
((32, 256, 128), -1),
((64, 32, 32, 256), -1),
((8192, 2, 4096), -2),
]
if (
should_use_rht(scaling_mode, q_layout=q_layout)
and (input_shape, flatten_axis) in RHT_SLIGHT_MISMATCH_SHAPES
):
# TE fused RHT+quant and JAX RHT+quant have slight implementation differences
# which can lead to small numerical differences on certain shapes
return False
if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16:
# With NVFP4 scaling, TE kernels internally use bfloat16, so using a different
# input dtype can lead to small numerical differences compared to the JAX
# implementation
return False
return True
def test_quantize_bitwise(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
......@@ -635,15 +816,202 @@ class TestQuantize:
jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
try:
te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
assert_bitwise_scaled_tensors(te_output, jax_output)
except AssertionError as e:
if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
error_message = e.args[0]
if "RHT requires input to be bfloat16" in error_message:
# Successfully caught the expected error, early return from the test
return
raise e
assert_bitwise_scaled_tensors(
te_output,
jax_output,
precise_comparison=self._should_use_precise_comparison(
in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
),
)
def test_quantize_bitwise_jitted(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
)
jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))
jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
try:
te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
except AssertionError as e:
if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
error_message = e.args[0]
if "RHT requires input to be bfloat16" in error_message:
# Successfully caught the expected error, early return from the test
return
raise e
assert_bitwise_scaled_tensors(
te_output,
jax_output,
precise_comparison=self._should_use_precise_comparison(
in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
),
)
@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper(
"scaling_mode", [s for s in supported_scaling_modes if s.is_nvfp4_scaling]
)
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestStochasticRounding:
def _dequantize(self, scaled_tensor) -> list[jnp.ndarray]:
"""Dequantizes a ScaledTensor back to it's original jnp.ndarray form. This always returns an array of jnp.ndarrays, for ScaledTensor2x there will be two tensors, for ScaledTensor1x there will be one tensor."""
if isinstance(scaled_tensor, ScaledTensor1x):
dq = scaled_tensor.dequantize()
if scaled_tensor.data_layout == "T":
dq = jnp.transpose(
dq,
(
*range(scaled_tensor.flatten_axis, dq.ndim),
*range(scaled_tensor.flatten_axis),
),
)
return [dq]
elif isinstance(scaled_tensor, ScaledTensor2x):
[rowwise_dq] = self._dequantize(scaled_tensor.rowwise_tensor)
[colwise_dq] = self._dequantize(scaled_tensor.colwise_tensor)
return [rowwise_dq, colwise_dq]
raise ValueError(
"Unsupported ScaledTensor type, expected ScaledTensor but received"
f" {type(scaled_tensor)}"
)
def _sample_sr_qdq(
self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
) -> list[jnp.ndarray]:
"""Samples num_samples quantize-dequantize operations with stochastic rounding enabled and returns the dequantized tensors."""
dq_tensors = []
key = jax.random.PRNGKey(0)
for i in range(num_samples):
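# fold_in derives an independent key per sample so each quantization draws
# fresh stochastic-rounding noise.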
iter_key = jax.random.fold_in(key, i)
sr_rng_state = jax.random.randint(
iter_key, (4,), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
)
quantizer = QuantizerFactory.create(
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
stochastic_rounding_rng_state=sr_rng_state,
)
q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis)
iter_dq = self._dequantize(q_output)
dq_tensors.extend(iter_dq)
avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors), axis=0)
assert avg_sr_tensor.shape == inputs.shape, (
f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape"
f" {inputs.shape}"
)
sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs))
dq_var = jnp.var(jnp.stack(dq_tensors))
assert (
dq_var > 0
), "Variance of dequantized tensors is zero, stochastic rounding may not be working"
return dq_tensors
def _round_nearest(
self, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
) -> jnp.ndarray:
"""Quantizes and dequantizes the input tensor with round nearest quantization."""
quantizer = QuantizerFactory.create(
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
stochastic_rounding_rng_state=None,
)
q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis)
rn_dq = self._dequantize(q_output)[0]
return rn_dq
def _test_sr(
self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
) -> float:
"""Tests that the mean absolute error (MAE) of stochastic rounding is smaller than round nearest quantization over multiple samples."""
dq_tensors = self._sample_sr_qdq(
num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors).astype(jnp.float32), axis=0)
assert avg_sr_tensor.shape == inputs.shape, (
f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape"
f" {inputs.shape}"
)
round_nearest_tensor = self._round_nearest(
q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs))
rn_mae = jnp.mean(jnp.abs(round_nearest_tensor - inputs))
assert sr_mae < rn_mae, (
f"Mean absolute error of stochastic rounding ({sr_mae}) is not smaller than"
f" round nearest ({rn_mae})"
)
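# Rationale: stochastic rounding is unbiased (E[Q(x)] == x up to scale-factor
# error), so averaging many SR samples should track the input more closely
# than deterministic round-to-nearest.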
return sr_mae
def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
"""Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""
# HACK: FIXME TODO(jberchtold)
row = reduce(operator.mul, input_shape[flatten_axis:], 1)
col = reduce(operator.mul, input_shape[:flatten_axis], 1)
will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
if will_use_rht and (row % 64 != 0 or col % 128 != 0):
pytest.skip("Unfused RHT is not supported currently, skipping")
key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype)
NUM_SAMPLES = 10
te_mean_error = self._test_sr(
NUM_SAMPLES, tex.quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
jax_mean_error = self._test_sr(
NUM_SAMPLES, _jax_quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
......@@ -682,7 +1050,6 @@ class TestGroupedQuantize:
q_layout=q_layout,
n_groups=n_groups,
)
scaled_tensor = tex.grouped_quantize(
x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
)
......@@ -694,9 +1061,8 @@ class TestGroupedQuantize:
class TestFusedQuantize:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("out_dtype,scaling_mode", QUANTIZE_QDTYPE_AND_SCALING_MODES)
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
......@@ -734,6 +1100,7 @@ class TestFusedQuantize:
def _test_quantize_dact_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
......@@ -780,9 +1147,15 @@ class TestFusedQuantize:
assert_allclose(te_output.data, jax_output.data)
if is_dbias:
# TE kernels cast the intermediate results to the input dtype which reduces
# precision compared to the JAX implementation; for dbias this typically only
# affects bfloat16.
precise_comparison = not (
in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()
# TE kernels cast the intermediate results to the input dtype which reduces
# precision compared to the JAX implementation; for dbias this typically only
# affects bfloat16.
(in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
# Due to the amax dependency, current scaling is unfused. In TE we store the
# activation results in bf16, which reduces precision compared to the JAX
# implementation, which will implicitly promote intermediates to float32 when
# JIT'd. This only produces a tolerance issue when using squared_relu currently.
or (
activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")}
and in_dtype == jnp.bfloat16
and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
)
)
assert_allclose(
te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
......@@ -811,7 +1184,7 @@ class TestFusedQuantize:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
......@@ -837,7 +1210,7 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper(
"input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
......@@ -870,6 +1243,11 @@ valid_fp8_gemm_operand_types = [
(jnp.float8_e4m3fn, jnp.float8_e5m2),
]
supported_nvfp4_scaling_mode_pairs = [
(ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_1D_SCALING),
(ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING),
]
class TestDense:
def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
......@@ -911,7 +1289,7 @@ class TestDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
......@@ -945,6 +1323,40 @@ class TestDense:
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
# TODO(Phuong): add bitwise test
@pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("scaling_mode_pair", supported_nvfp4_scaling_mode_pairs)
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
@pytest_parametrize_wrapper("with_jax_gemm", [True, False])
def test_gemm_nvfp4(self, m, n, k, scaling_mode_pair, data_layout, with_jax_gemm):
x_uses_rht = scaling_mode_pair[0] == ScalingMode.NVFP4_1D_SCALING and data_layout[0] == "T"
w_uses_rht = scaling_mode_pair[1] == ScalingMode.NVFP4_1D_SCALING and data_layout[1] == "N"
if x_uses_rht != w_uses_rht:
# TODO(jberchtold): Ideally avoid a skip here and rewrite test setup to ensure both or neither use RHT
pytest.skip("RHT must be used for both or neither operand, skipping")
lhs_scaling_mode, rhs_scaling_mode = scaling_mode_pair
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
lhs_quantizer = QuantizerFactory.create(
scaling_mode=lhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
)
rhs_quantizer = QuantizerFactory.create(
scaling_mode=rhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
)
with use_jax_gemm(enabled=with_jax_gemm):
primitive_out = tex.gemm(
x,
w,
contracting_dims=contracting_dims,
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
def test_dense_grad_bf16(self, m, n, k):
data_layout = "NN"
......@@ -970,11 +1382,10 @@ class TestDense:
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("m,n,k", [(64, 128, 128)])
@pytest_parametrize_wrapper("recipe", supported_recipes)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
def test_dense_grad_fp8_and_fp4(self, m, n, k, recipe, with_jax_gemm):
data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
......@@ -995,14 +1406,9 @@ class TestDense:
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if recipe.delayed() else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
......@@ -1013,10 +1419,10 @@ class TestDense:
x, w, bias, data_layout
)
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)
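# assert_allclose picks tolerances based on the dtype argument, so outputs are
# checked at the forward quantizer precision and grads at the bwd (dgrad)
# quantizer precision.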
assert_allclose(primitive_out, ref_out, dtype=quantizer_set.x.q_dtype)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype)
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=quantizer_set.dgrad.q_dtype)
@pytest.fixture(name="random_inputs")
......@@ -1038,11 +1444,11 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
@pytest_parametrize_wrapper("recipe", supported_recipes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm):
"""
Test layernorm_dense VJP Rule
"""
......@@ -1059,12 +1465,7 @@ class TestFusedDense:
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
if norm_type == "layernorm":
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
......@@ -1099,7 +1500,7 @@ class TestFusedDense:
x, w, gamma, beta
)
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if recipe.delayed() else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
prim_out, (
......@@ -1109,22 +1510,22 @@ class TestFusedDense:
prim_beta_grad,
) = value_n_grad_prim_func(x, w, gamma, beta)
assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_out, ref_out, dtype=quantizer_set.x.q_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype)
assert_allclose(prim_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=quantizer_set.dgrad.q_dtype)
if beta is not None:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("recipe", supported_recipes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad(
self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
self, m, n, k, activation_type, recipe, norm_type, use_bias, with_jax_gemm
):
"""
Test layernorm_mlp VJP Rule
......@@ -1152,10 +1553,7 @@ class TestFusedDense:
quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2,
scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
fp8_recipe=recipe,
)
if norm_type == "layernorm":
......@@ -1202,7 +1600,7 @@ class TestFusedDense:
value_n_grad_prim_func = value_and_grad(prim_func, range(6))
value_n_grad_ref_func = value_and_grad(ref_func, range(6))
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if recipe.delayed() else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
prim_out, (
......@@ -1223,18 +1621,16 @@ class TestFusedDense:
ref_bias_2_grad,
) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
fwd_dtype = quantizer_sets[0].x.q_dtype
bwd_dtype = quantizer_sets[0].dgrad.q_dtype
assert_allclose(prim_out, ref_out, dtype=fwd_dtype)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=bwd_dtype)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=bwd_dtype)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=bwd_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=bwd_dtype)
if use_bias:
assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
if use_bias:
assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=bwd_dtype)
assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=bwd_dtype)
# E5M2 * E5M2 is not supported
......@@ -1317,21 +1713,29 @@ class TestGroupedDense:
lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
dtype, input_shape, layout
)
num_gemms = input_shape[0]
_ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))(
group_sizes,
num_gemms=num_gemms,
)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# jitting grouped_gemm
prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
prim_out = jax.jit(
tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
)(
lhs,
rhs,
group_sizes,
contracting_dims,
use_async_d2h_group_sizes=True,
)
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
fwd_dtype, bwd_dtype = fwd_bwd_dtype
......@@ -1412,7 +1816,7 @@ class TestGroupedDense:
"fwd_bwd_dtype",
[(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
fwd_dtype, bwd_dtype = fwd_bwd_dtype
dtype = jnp.bfloat16
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import unittest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from functools import partial
from distributed_test_base import generate_configs
from utils import assert_allclose, pytest_parametrize_wrapper
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax import autocast
from transformer_engine.jax.dense import dense
DTYPES = [jnp.bfloat16]
GEMM_INPUT_SHAPES = [[256, 128, 256]] # [batch, seq_len, hidden_in]
WEIGHT_SHAPES = [[256, 256]] # [hidden_in, hidden_out]
def _generate_inputs(input_shape, weight_shape, dtype):
"""Generate test inputs for GEMM operations"""
_, _, hidden_in = input_shape
hidden_in_w, hidden_out = weight_shape
assert hidden_in == hidden_in_w, f"Dimension mismatch: {hidden_in} != {hidden_in_w}"
bias_shape = (hidden_out,)
# Generate random inputs
x = random.normal(random.PRNGKey(1124), input_shape, dtype=dtype)
weight = random.normal(random.PRNGKey(2248), weight_shape, dtype=dtype) / jnp.sqrt(hidden_in_w)
bias = random.normal(random.PRNGKey(3372), bias_shape, dtype=dtype) / jnp.sqrt(hidden_out)
return x, weight, bias
def _get_sharding_for_gemm(mesh, mesh_resource, partition_layout="rowwise"):
"""Get sharding patterns for GEMM inputs and outputs"""
dp_axis = mesh_resource.dp_resource
tp_axis = mesh_resource.tpsp_resource
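# "rowwise" shards the contracting (hidden_in) dim across TP, so the GEMM
# produces partial sums that need a reduction; "colwise" shards the output
# (hidden_out) dim, so no reduction is needed.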
if partition_layout == "colwise":
x_spec = PartitionSpec(dp_axis, None, None)
weight_spec = PartitionSpec(None, tp_axis)
bias_spec = PartitionSpec(tp_axis)
output_spec = PartitionSpec(dp_axis, None, tp_axis)
elif partition_layout == "rowwise":
x_spec = PartitionSpec(dp_axis, None, tp_axis)
weight_spec = PartitionSpec(tp_axis, None)
bias_spec = PartitionSpec(None)
output_spec = PartitionSpec(dp_axis, None, None)
else:
raise ValueError(f"Invalid partition: {partition_layout}")
x_sharding = NamedSharding(mesh, x_spec)
weight_sharding = NamedSharding(mesh, weight_spec)
bias_sharding = NamedSharding(mesh, bias_spec)
output_sharding = NamedSharding(mesh, output_spec)
return x_sharding, weight_sharding, bias_sharding, output_sharding
@partial(jax.jit, static_argnames=("contracting_dims", "output_sharding"))
def _jitted_gemm(x, weight, bias, contracting_dims, output_sharding):
output = tex.gemm(
x,
weight,
bias=bias,
contracting_dims=contracting_dims,
fuse_bias=True,
)
if output_sharding is not None:
output = jax.lax.with_sharding_constraint(output, output_sharding)
return output
# TODO(Phuong):
# 1. Add supported recipes after FP4 is added
# 2. Add communication type/byte checks
class TestDistributedDense:
"""Test distributed GEMM without collective operations vs JAX dot"""
@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_configs(),
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES)
@pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES)
@pytest_parametrize_wrapper("partition", ["rowwise", "colwise"])
def test_distributed_gemm(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
dtype,
input_shape,
weight_shape,
partition,
):
"""Test TE GEMM against JAX dot with bf16 dtype"""
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
# Generate inputs
x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype)
# Get sharding patterns
x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm(
mesh, mesh_resource, partition_layout=partition
)
# Shard inputs
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension
with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
# TE GEMM result
te_result = _jitted_gemm(
x_sharded,
weight_sharded,
bias_sharded,
contracting_dims=contracting_dims,
output_sharding=output_sharding,
)
# JAX dot reference result
jax_result = (
jax.lax.dot_general(
x_sharded, weight_sharded, dimension_numbers=(contracting_dims, ((), ()))
)
+ bias_sharded
)
assert te_result.sharding == jax_result.sharding
# Ensure computation is complete
jax.block_until_ready(te_result)
jax.block_until_ready(jax_result)
# Gather results for comparison
gathered_te = jax.lax.with_sharding_constraint(
te_result, NamedSharding(mesh, PartitionSpec(None))
)
gathered_jax = jax.lax.with_sharding_constraint(
jax_result, NamedSharding(mesh, PartitionSpec(None))
)
# Compare results
assert_allclose(gathered_te, gathered_jax, dtype=dtype)
def _te_sum_dense(self, x, weight, bias, contracting_dims):
"""TE GEMM function for gradient testing"""
return jnp.sum(dense(x, weight, bias=bias, contracting_dims=contracting_dims))
def _jax_sum_dense(self, x, weight, bias, contracting_dims):
"""JAX dot function for gradient testing"""
result = (
jax.lax.dot_general(x, weight, dimension_numbers=(contracting_dims, ((), ()))) + bias
)
return jnp.sum(result)
@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_configs(),
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES)
@pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES)
@pytest_parametrize_wrapper("partition", ["rowwise", "colwise"])
def test_te_distributed_dense_grad(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
dtype,
input_shape,
weight_shape,
partition,
):
"""Test TE GEMM gradients against JAX dot gradients"""
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
# Generate inputs
x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype)
# Get sharding patterns
x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm(
mesh, mesh_resource, partition_layout=partition
)
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
contracting_dims = ((2,), (0,))
with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
# Test gradients w.r.t. all inputs
te_grad_func = jax.jit(
jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)),
static_argnames=("contracting_dims",),
)
jax_grad_func = jax.jit(
jax.value_and_grad(self._jax_sum_dense, argnums=(0, 1, 2)),
static_argnames=("contracting_dims",),
)
te_val, te_grads = te_grad_func(
x_sharded, weight_sharded, bias_sharded, contracting_dims
)
jax_val, jax_grads = jax_grad_func(
x_sharded, weight_sharded, bias_sharded, contracting_dims
)
# Compare forward pass
assert_allclose(te_val, jax_val, dtype=dtype)
# Compare gradients
for i, (te_grad, jax_grad) in enumerate(zip(te_grads, jax_grads)):
te_grad_spec = tuple(i for i in te_grad.sharding.spec if i is not None)
jax_grad_spec = tuple(i for i in jax_grad.sharding.spec if i is not None)
assert te_grad_spec == jax_grad_spec, f"Gradient sharding mismatch at te_grads[{i}]"
gathered_te_grad = jax.lax.with_sharding_constraint(
te_grad, NamedSharding(mesh, PartitionSpec(None))
)
gathered_jax_grad = jax.lax.with_sharding_constraint(
jax_grad, NamedSharding(mesh, PartitionSpec(None))
)
assert_allclose(
gathered_te_grad,
gathered_jax_grad,
dtype=dtype,
err_msg=f"Gradient mismatch for argument {i}",
)
if __name__ == "__main__":
unittest.main()
......@@ -9,7 +9,7 @@ import numpy as np
from utils import pytest_parametrize_wrapper, is_devices_enough
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
def generate_mesh_configs():
......@@ -26,10 +26,10 @@ def generate_mesh_configs():
class TestMeshResource(unittest.TestCase):
def test_fp8_autocast_with_mesh_resource(self):
def test_autocast_with_mesh_resource(self):
for mesh_config in generate_mesh_configs():
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
self.assertEqual(mesh_resource, global_mesh_resource())
......@@ -15,7 +15,7 @@ from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import pytest_parametrize_wrapper
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.common import recipe
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available
......@@ -66,20 +66,19 @@ class TestDistributedLayernorm:
self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
# TODO(Phuong): is_dp_enabled = dp mesh axis size > 1
is_dp_enabled = mesh_resource.dp_resource is not None
is_tpsp_enabled = mesh_resource.tpsp_resource is not None
assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta
# TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
allreduce_total_bytes = (
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
other_bytes = 0
if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
# loss, 1 FP32
allreduce_total_bytes = 4 if is_dp_enabled else 0
# dgamma and dbeta
weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes += weight_count * shape[-1] * jax_dtype.itemsize
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
allreduce=allreduce_total_bytes * int(is_dp_enabled or is_tpsp_enabled),
allgather=0,
other=0,
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
......@@ -134,7 +133,7 @@ class TestDistributedLayernorm:
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
......@@ -210,7 +209,7 @@ class TestDistributedLayernorm:
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import re
from typing import Callable, Sequence, Union, Optional
import pytest
......@@ -17,8 +18,12 @@ from utils import (
)
from transformer_engine.common import recipe
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.quantize import (
is_fp8_available,
ScalingMode,
get_quantize_config_with_recipe,
)
from transformer_engine.jax import autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.sharding import (
......@@ -33,22 +38,23 @@ from transformer_engine.jax.sharding import (
W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory
from transformer_engine.jax.quantize import (
QuantizerFactory,
get_supported_quantization_recipes,
is_scaling_mode_supported,
)
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
SUPPORTED_RECIPES = get_supported_quantization_recipes()
SUPPORTED_RECIPES = [pytest.param(r, id=r.__class__.__name__) for r in SUPPORTED_RECIPES]
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in]
INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
......@@ -59,19 +65,47 @@ LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64
INTERMEDIATE = 256
# Only test with FSDP and TPSP as DP is not used
def generate_fsdp_and_tpsp_configs():
configs = []
if is_devices_enough(4):
configs.append(
pytest.param(
[
4,
(2, 2),
("fsdp", "tpsp"),
MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
],
id="fsdp2_tpsp2",
)
)
if is_devices_enough(2):
configs.append(
[2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")]
pytest.param(
[
2,
(1, 2),
("fsdp", "tpsp"),
MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
],
id="fsdp1_tpsp2",
)
)
if is_devices_enough(4):
configs.append(
[4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")]
pytest.param(
[
2,
(2, 1),
("fsdp", "tpsp"),
MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
],
id="fsdp2_tpsp1",
),
)
return configs
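# Usage sketch (values are the ones produced above): each config unpacks as
#   device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
# e.g. 4, (2, 2), ("fsdp", "tpsp"),
#      MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")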
......@@ -113,6 +147,7 @@ class TestDistributedLayernormMLP:
layernorm_type: str = "rmsnorm",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
multi_gpus: bool = False,
quantization_recipe: recipe.Recipe = None,
) -> jnp.ndarray:
if multi_gpus:
......@@ -126,7 +161,9 @@ class TestDistributedLayernormMLP:
dot_1_input_axes = dot_2_input_axes = None
kernel_1_axes = kernel_2_axes = None
quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2, fp8_recipe=quantization_recipe
)
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return jnp.mean(
......@@ -154,7 +191,7 @@ class TestDistributedLayernormMLP:
use_bias,
input_shape,
dtype,
fp8_recipe,
quantization_recipe,
use_shardy,
with_jax_gemm,
):
......@@ -173,8 +210,10 @@ class TestDistributedLayernormMLP:
)
# Single GPU
with fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()
with autocast(
enabled=quantization_recipe is not None,
recipe=quantization_recipe,
mesh_resource=MeshResource(),
):
single_jitter = jax.jit(
value_and_grad_func,
......@@ -185,8 +224,10 @@ class TestDistributedLayernormMLP:
# Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
with mesh, autocast(
enabled=quantization_recipe is not None,
recipe=quantization_recipe,
mesh_resource=mesh_resource,
):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))
......@@ -226,11 +267,14 @@ class TestDistributedLayernormMLP:
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
fwd_test_type = bwd_test_type = dtype
if quantization_recipe is not None:
quantize_config = get_quantize_config_with_recipe(quantization_recipe)
fwd_test_type = quantize_config.FWD_DTYPE
bwd_test_type = quantize_config.BWD_DTYPE
if fwd_test_type == jnp.float16 and use_bias:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5)
assert_allclose(multi_fwd, single_fwd, atol=0.04, rtol=1.5)
else:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
......@@ -253,13 +297,12 @@ class TestDistributedLayernormMLP:
err_msg=f"multi_grads[{i}] is not close",
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad(
self,
......@@ -268,27 +311,28 @@ class TestDistributedLayernormMLP:
use_bias,
input_shape,
dtype,
fp8_recipe,
quantization_recipe,
with_jax_gemm,
):
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
quantization_recipe,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad_shardy(
self,
......@@ -297,18 +341,18 @@ class TestDistributedLayernormMLP:
use_bias,
input_shape,
dtype,
fp8_recipe,
quantization_recipe,
with_jax_gemm,
):
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe=fp8_recipe,
quantization_recipe=quantization_recipe,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
......@@ -321,7 +365,7 @@ class TestDistributedLayernormMLP:
input_shape,
dtype,
use_fp8,
fp8_recipe,
quantization_recipe,
use_shardy,
with_jax_gemm,
):
......@@ -330,14 +374,16 @@ class TestDistributedLayernormMLP:
layernorm_type = "rmsnorm"
rng = jax.random.PRNGKey(0)
subkeys = jax.random.split(rng, 2)
subkeys = jax.random.split(rng, 3)
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {"params": subkeys[1]}
init_rngs = {"params": subkeys[1], "sr_rng": subkeys[2]}
with use_jax_gemm(enabled=with_jax_gemm):
# Single GPUs
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
with autocast(
enabled=use_fp8, recipe=quantization_recipe, mesh_resource=MeshResource()
):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
intermediate_dim=INTERMEDIATE,
......@@ -346,15 +392,15 @@ class TestDistributedLayernormMLP:
)
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply(
params_single, x, deterministic=True
params_single, x, deterministic=True, rngs={"sr_rng": subkeys[2]}
)
# Multi GPUs
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
with mesh, autocast(
enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource
):
ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
......@@ -374,19 +420,20 @@ class TestDistributedLayernormMLP:
)
params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
params_sharded, x, deterministic=True
params_sharded, x, deterministic=True, rngs={"sr_rng": subkeys[2]}
)
# Make sure params values are the same
assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
# TODO(Phuong): check if these tolerance updates are still needed
atol = None
rtol = None
l40_tolerance_update = (
get_min_device_compute_capability() == 89
and fp8_recipe == recipe.DelayedScaling()
and use_fp8
and quantization_recipe.delayed()
and dtype == jnp.float16
and activation_type == ("gelu",)
)
......@@ -404,9 +451,10 @@ class TestDistributedLayernormMLP:
# within tolerance to the float32 ground truth.
jax_triton_gemm_precision_tolerance_update = (
with_jax_gemm
and isinstance(fp8_recipe, recipe.Float8CurrentScaling)
and dtype == jnp.bfloat16
and activation_type == ("gelu", "linear")
and quantization_recipe is not None
and (quantization_recipe.delayed() or quantization_recipe.float8_current_scaling())
and dtype in (jnp.bfloat16, jnp.float16)
and activation_type == ("gelu", "linear")
)
if jax_triton_gemm_precision_tolerance_update:
atol = 0.08
......@@ -430,22 +478,30 @@ class TestDistributedLayernormMLP:
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
quantization_recipe=None,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
quantization_recipe,
with_jax_gemm,
):
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp(
mesh_config,
activation_type,
......@@ -453,7 +509,7 @@ class TestDistributedLayernormMLP:
input_shape,
dtype,
use_fp8=True,
fp8_recipe=fp8_recipe,
quantization_recipe=quantization_recipe,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
......@@ -474,24 +530,30 @@ class TestDistributedLayernormMLP:
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
quantization_recipe=None,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
quantization_recipe,
with_jax_gemm,
):
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp(
mesh_config,
activation_type,
......@@ -499,7 +561,7 @@ class TestDistributedLayernormMLP:
input_shape,
dtype,
use_fp8=True,
fp8_recipe=fp8_recipe,
quantization_recipe=quantization_recipe,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
......@@ -15,7 +15,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax
DTYPES = [jnp.float16, jnp.bfloat16]
......@@ -102,7 +102,7 @@ class TestDistributedSoftmax:
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
with mesh, autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
......
......@@ -22,7 +22,7 @@ from jax import value_and_grad, jit
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import autocast
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
AttnBiasType,
......@@ -32,6 +32,7 @@ from transformer_engine.jax.attention import (
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
run_length_fill,
make_swa_mask,
SequenceDescriptor,
CPStrategy,
......@@ -172,15 +173,34 @@ def make_mask(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
)
# causal mask
if attn_mask_type.is_causal():
if attn_mask_type.is_bottom_right():
run_length_out_q = run_length_fill(segment_ids_q)
run_length_out_kv = run_length_fill(segment_ids_kv)
bottom_right_causal_mask = make_attention_mask(
run_length_out_q - segment_pos_q,
run_length_out_kv - segment_pos_kv,
jnp.less_equal,
)
inv_mask = combine_masks(bottom_right_causal_mask, inv_mask)
elif attn_mask_type.is_causal():
inv_causal_mask = make_attention_mask(
segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
)
inv_mask = combine_masks(inv_causal_mask, inv_mask)
# sliding window mask
inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
inv_swa_mask = (
make_swa_mask(
segment_pos_q,
segment_pos_kv,
window_size,
dtype=jnp.bool,
segment_ids_q=segment_ids_q,
segment_ids_kv=segment_ids_kv,
)
if attn_mask_type.is_bottom_right()
else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool_)
)
inv_mask = combine_masks(inv_mask, inv_swa_mask)
mask = jnp.logical_not(inv_mask)
return mask
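# Sketch of the bottom-right causal pattern built above (illustrative,
# single segment): query i may attend to kv position j iff
#   j <= i + (kv_len - q_len)
# e.g. q_len=2, kv_len=4 gives row 0 -> kv[0:3] and row 1 -> kv[0:4].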
......@@ -338,6 +358,16 @@ class FusedAttnRunner:
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")
if self.attn_mask_type.is_bottom_right():
if self.max_seqlen_q > self.max_seqlen_kv:
pytest.skip(
"BRCM requires a cross-attention pattern, i.e. max_seqlen_kv >= max_seqlen_q"
)
if self.attn_bias_type is not AttnBiasType.NO_BIAS:
pytest.skip(f"cuDNN does not support pre or post scale bias for BRCM")
if self.dropout_prob != 0.0:
pytest.skip(f"cuDNN does not support non-zero dropoouts for BRCM")
if self.qkv_layout.is_qkvpacked():
if self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
......@@ -526,7 +556,11 @@ class FusedAttnRunner:
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for SWA; otherwise cuDNN kernels don't support it
min_segment_len = None if self.window_size is None else self.seqlens_q
min_segment_len = None
if (
self.window_size is not None or self.attn_mask_type.is_bottom_right()
): # SWA or BRCM requires kv_len >= q_len
min_segment_len = self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
......@@ -737,7 +771,7 @@ class FusedAttnRunner:
],
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with self.mesh, autocast(mesh_resource=self.mesh_resource):
primitive_out = customcall_fused_dpa_jit(*customcall_args)
primitive_out = self.cp_inverse_reorder_fn(primitive_out)
......@@ -754,7 +788,7 @@ class FusedAttnRunner:
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
if self.coll_count_ref is not None:
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with self.mesh, autocast(mesh_resource=self.mesh_resource):
target_hlo = (
customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
)
......@@ -854,7 +888,7 @@ class FusedAttnRunner:
)
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with self.mesh, autocast(mesh_resource=self.mesh_resource):
primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
reference_out, reference_dgrad = jitted_reference(*args)
......@@ -925,7 +959,7 @@ class FusedAttnRunner:
)
if self.coll_count_ref is not None:
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with self.mesh, autocast(mesh_resource=self.mesh_resource):
target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
assert_equal_collectives(target_hlo, self.coll_count_ref)
......@@ -937,6 +971,9 @@ class FusedAttnRunner:
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
pytest.param(
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
),
],
)
@pytest.mark.parametrize(
......@@ -958,14 +995,14 @@ class FusedAttnRunner:
),
pytest.param(
2,
2048,
512,
1024,
12,
12,
64,
64,
jnp.bfloat16,
id="2-2048-1024-12-12-64-64-BF16-CROSS",
id="2-512-1024-12-12-64-64-BF16-CROSS",
),
pytest.param(
2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
......
......@@ -10,20 +10,27 @@ import jax.numpy as jnp
import numpy as np
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
from transformer_engine.common.recipe import (
DelayedScaling,
MXFP8BlockScaling,
Float8CurrentScaling,
NVFP4BlockScaling,
)
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax import autocast
from transformer_engine.jax.quantize import (
get_quantize_config,
is_fp8_available,
is_scaling_mode_supported,
ScalingMode,
update_collections,
TensorSource,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
class TestHelper(unittest.TestCase):
......@@ -52,14 +59,16 @@ class TestFP8Functions(unittest.TestCase):
def _check_default_state(self):
self.assertFalse(get_quantize_config().is_fp8_enabled())
def _compare_delay_scaling(self, ref, test):
self.assertTrue(ref.margin == test.margin)
self.assertTrue(ref.fp8_format == test.fp8_format)
self.assertTrue(ref.amax_history_len == test.amax_history_len)
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_delay_scaling(self, test):
self.assertEqual(get_quantize_config().MARGIN, test.margin)
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len)
self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo)
def _compare_current_scaling(self, test):
self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source),
......@@ -67,82 +76,100 @@ class TestFP8Functions(unittest.TestCase):
)
def _compare_mxfp8_scaling(self, test):
self.assertEqual(get_quantize_config().MARGIN, test.margin)
self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
)
def _compare_nvfp4_scaling(self, test):
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1])
for tensor_source in TensorSource:
target_scaling_mode = (
ScalingMode.NVFP4_2D_SCALING
if tensor_source == TensorSource.KERNEL
else ScalingMode.NVFP4_1D_SCALING
)
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_delayed_scaling(self):
def test_autocast_delayed_scaling(self):
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):
with autocast(enabled=False, recipe=DelayedScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._compare_delay_scaling(ds)
self._check_default_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._compare_delay_scaling(ds)
self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_current_scaling(self):
def test_autocast_current_scaling(self):
self._check_default_state()
with fp8_autocast(
enabled=False, fp8_recipe=Float8CurrentScaling(), mesh_resource=MeshResource()
):
with autocast(enabled=False, recipe=Float8CurrentScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_block_scaling(self):
def test_autocast_mxfp8_block_scaling(self):
self._check_default_state()
with fp8_autocast(
enabled=False, fp8_recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()
):
with autocast(enabled=False, recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
bs = MXFP8BlockScaling()
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
@unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
def test_autocast_nvfp4_block_scaling(self):
self._check_default_state()
with autocast(enabled=False, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
bs = NVFP4BlockScaling()
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._compare_nvfp4_scaling(bs)
self._check_default_state()
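# Note (mirrors _compare_nvfp4_scaling above): under NVFP4BlockScaling, kernel
# tensors are expected to use NVFP4_2D_SCALING while activation and gradient
# tensors use NVFP4_1D_SCALING.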
......@@ -28,7 +28,7 @@ from transformer_engine.jax.quantize import (
is_fp8_available,
update_collections,
TensorSource,
fp8_autocast,
autocast,
)
from transformer_engine.jax.sharding import MeshResource
......@@ -507,14 +507,14 @@ class BaseTester:
"""Test normal datatype forward"""
# Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
with autocast(enabled=False, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype)
def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward"""
# Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
with autocast(enabled=False, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -522,7 +522,7 @@ class BaseTester:
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test forward with fp8 enabled"""
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -530,7 +530,7 @@ class BaseTester:
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test backward with fp8 enabled"""
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
......
......@@ -1544,6 +1544,12 @@ def dtype_tols(
rtol = eps_relaxed
if atol is None:
atol = max(ulp, eps_relaxed)
# Manually set tols for nvfp4
if dtype == jnp.float4_e2m1fn:
atol = 0.05
rtol = 0.1
return {"rtol": rtol, "atol": atol}
......
......@@ -8,97 +8,30 @@ import logging
from contextlib import nullcontext
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_cu_seqlens_on_cp_rank,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize
import transformer_engine_torch as tex
from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch import (
autocast,
DotProductAttention,
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling
from utils import ModelConfig, compare_and_assert
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
def run_dpa_with_cp(
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
def generate_input_shapes(
qkv_format: str,
config: ModelConfig,
world_size: int,
kernel_backend: str,
):
"""Test DotProductAttention module with context parallelism"""
# args are passed as strings
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
print(f"[INFO] world_size:{world_size}, rank:{rank}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# create flash attn comm group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert (
world_size % 2 == 0
), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!"
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
)
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (
config.batch_size,
......@@ -191,35 +124,192 @@ def run_dpa_with_cp(
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
assert False, f"{qkv_format=} is not supported!"
return (
q_input_shape,
k_input_shape,
v_input_shape,
attn_output_shape,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
)
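# Usage sketch: callers unpack the result as
#   (q_input_shape, k_input_shape, v_input_shape, attn_output_shape,
#    cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded,
#   ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)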
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
def get_tols(config, dtype):
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
atol = 2.5e-2
rtol = 2.5e-2
else:
atol = 3.5e-2
rtol = 3.5e-2
rmse_tol = 0.01
elif dtype == "fp16":
atol = 5e-3
rtol = 5e-3
rmse_tol = 0.01
elif dtype == "fp8":
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.15
else:
assert False, f"{dtype=} is not supported!"
return atol, rtol, rmse_tol
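# Example: a GQA bf16 config (num_heads != num_gqa_groups) yields
#   get_tols(config, "bf16") -> (3.5e-2, 3.5e-2, 0.01)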
def run_dpa_with_cp(
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_bwd="True",
fp8_dpa="False",
fp8_mha="False",
scaling_mode="delayed",
f16_O="False",
log_level=logging.WARNING,
):
"""Test DotProductAttention module with context parallelism"""
logging.root.setLevel(log_level)
# set up environment variables and config
fp8_bwd = fp8_bwd == "True" and dtype == "fp8"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0"
fp8_dpa = fp8_dpa == "True" and dtype == "fp8"
fp8_mha = fp8_mha == "True" and dtype == "fp8"
f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True"
os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type=} is not supported!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
# set up distributed group
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# set up communication group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert world_size % 2 == 0, (
f"{cp_comm_type=} requires world_size % 2 == 0 as it assumes the a2a level has"
" cp_size = 2."
)
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
if scaling_mode == "delayed":
fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
if scaling_mode == "current":
fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
# instantiate attention module
core_attn = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
softmax_type=config.softmax_type,
).cuda()
if config.softmax_type != "vanilla":
core_attn.softmax_offset.requires_grad = True
# generate attention inputs
(
q_input_shape,
k_input_shape,
v_input_shape,
attn_output_shape,
cu_seqlens_q,
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
dout_orig = torch.clamp(
torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1
).cuda()
if scaling_mode == "delayed":
qkv_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
if scaling_mode == "current":
qkv_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device="cuda",
)
dout_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
device="cuda",
)
qkv_layout = "_".join([qkv_format] * 3)
q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]]
if fp8_mha:
q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer)
for x in [q, k, v]:
x.requires_grad = True
# create flash attention bias
if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
else:
bias = None
# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
############ run without CP ############
logging.info(f"[Rank {rank}] Run without context parallelism")
if dtype == "fp8":
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
else:
fp8_context = nullcontext()
with fp8_context:
# q, k, v, out in FP8; dout in F16
out = core_attn(
q,
k,
......@@ -230,16 +320,25 @@ def run_dpa_with_cp(
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
)
if fp8_mha:
if fp8_bwd and fp8_mha:
dout_fp8 = dout_quantizer(dout)
out.backward(dout_fp8)
else:
out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad
d_softmax_offset = None
if config.softmax_type != "vanilla":
d_softmax_offset = core_attn.softmax_offset.grad
############ run with CP ############
logging.info(f"[Rank {rank}] Run with context parallelism")
# run core_attn with CP
# set up inputs
q_, k_, v_, dout_, *rest = [
x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
x.clone().detach()
for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias])
]
bias_ = rest[0] if len(rest) else None
if qkv_format == "bshd" or qkv_format == "sbhd":
......@@ -269,6 +368,14 @@ def run_dpa_with_cp(
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]]
if scaling_mode == "delayed":
qkv_quantizer.scale.fill_(1.0)
qkv_quantizer.amax.fill_(0.0)
dout_quantizer.scale.fill_(1.0)
dout_quantizer.amax.fill_(0.0)
if fp8_mha:
q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer)
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(
......@@ -276,20 +383,25 @@ def run_dpa_with_cp(
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
# set up environment
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if config.softmax_type != "vanilla":
core_attn.softmax_offset.grad.zero_()
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
core_attn.fp8_initialized = False
core_attn.fp8_meta_tensors_initialized = False
fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
else:
fp8_context = nullcontext()
# run attention
with fp8_context:
# q, k, v, out in FP8; dout in F16
out_ = core_attn(
q_,
k_,
......@@ -300,24 +412,32 @@ def run_dpa_with_cp(
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
)
if fp8_mha:
if fp8_bwd and fp8_mha:
dout_fp8_ = dout_quantizer(dout_)
out_.backward(dout_fp8_)
else:
out_.backward(dout_)
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
d_softmax_offset_ = None
if config.softmax_type != "vanilla":
d_softmax_offset_ = core_attn.softmax_offset.grad.clone()
# get outputs
tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
if fp8_mha:
assert isinstance(out, Float8Tensor)
assert isinstance(out_, Float8Tensor)
out = out.dequantize()
out_ = out_.dequantize()
for x in [out_, q_.grad, k_.grad, v_.grad]:
assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x))
# compare results with and without CP
tensors_to_deq = [out, out_] if not fp8_bwd else tensors
for i, tensor in enumerate(tensors_to_deq):
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[4] = tensors_to_deq
for tensor in tensors:
assert torch.all(~torch.isnan(tensor))
assert torch.all(~torch.isinf(tensor))
out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors
############ compare results between CP and no-CP ############
if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [
x.view(
......@@ -326,17 +446,17 @@ def run_dpa_with_cp(
x.shape[seq_dim] // (2 * world_size),
*x.shape[(seq_dim + 1) :],
)
for x in [q.grad, k.grad, v.grad, out]
for x in [dq, dk, dv, out]
]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [
x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
for x in [q_.grad, k_.grad, v_.grad, out_]
for x in [dq_, dk_, dv_, out_]
]
elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
......@@ -373,56 +493,70 @@ def run_dpa_with_cp(
).item()
== 0
)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
elif dtype == "fp16":
tols = dict(atol=5e-3, rtol=5e-3)
elif dtype == "fp8":
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
else:
assert False, f"{dtype} is an unsupported dtype!"
def _rmse(a, b):
return torch.sqrt((a - b).square().mean()).item()
def _error(a, b):
if dtype != "fp8":
torch.testing.assert_close(a, b, **tols)
else:
try:
torch.testing.assert_close(a, b, **tols)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert (
rmse < rmse_tol * rmse_range
), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i]:
if qkv_format == "bshd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a[:, 0], b[:, 0])
_error(a[:, 1], b[:, 1])
compare_and_assert(
t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "sbhd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a[0], b[0])
_error(a[1], b[1])
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "thd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a, b)
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")
# destroy the distributed process group
dist.destroy_process_group()
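# Note: run_dpa_with_cp is intended to run once per rank under a distributed
# launcher (e.g. torchrun), which provides the RANK and WORLD_SIZE environment
# variables read above; the string-typed flags (fp8_mha="True", etc.) are
# passed through from the calling harness's CLI-style arguments.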
......
......@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
import logging
import math
import os
import sys
import pathlib
......@@ -11,13 +10,22 @@ from typing import Any, Dict, Tuple, Union
import pytest
import torch
from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention.dot_product_attention import (
from transformer_engine.pytorch import (
TransformerLayer,
autocast,
quantized_model_init,
DotProductAttention,
MultiheadAttention,
get_device_compute_capability,
Quantizer,
is_fp8_available,
is_bf16_available,
)
from transformer_engine.pytorch.attention.dot_product_attention import (
_attention_backends,
)
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils,
check_set_window_size,
......@@ -30,18 +38,14 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.quantized_tensor import (
Quantizer,
prepare_for_saving,
restore_from_saved,
)
......@@ -50,27 +54,35 @@ _current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
reset_rng_states,
compare_and_assert,
ModelConfig,
dtype_tols,
get_available_attention_backends,
)
# Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
# Check if hardware supports FP8
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
# Reset RNG seed and states
seed = 1234
# Reset RNG states
reset_rng_states()
# Reset FP8 global state manager
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
fp8.FP8GlobalStateManager.reset()
FP8GlobalStateManager.reset()
# Define F16 data types to test
param_types = [torch.float16]
if is_bf16_available():
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"base_1_0": ModelConfig(8, 128, 16, 64),
"base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
"base_2_0": ModelConfig(2, 2048, 24, 128),
......@@ -86,12 +98,6 @@ model_configs_base = {
}
param_types = [torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
......@@ -125,12 +131,12 @@ def test_dot_product_attention(
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
# Get backends
is_training = True
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
......@@ -141,7 +147,6 @@ def test_dot_product_attention(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
......@@ -227,6 +232,7 @@ def test_dot_product_attention(
is_training,
)
# Compare results
logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
......@@ -259,6 +265,85 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
"softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"),
"softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"),
"softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
"softmax_2_1": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
),
"softmax_2_2": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
),
"softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"),
"softmax_3_1": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one"
),
"softmax_3_2": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable"
),
"softmax_4_0": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal"
),
"softmax_4_1": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="causal",
softmax_type="off-by-one",
),
"softmax_4_2": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="causal",
softmax_type="learnable",
),
"softmax_5_0": ModelConfig(
2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal"
),
"softmax_5_1": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="padding_causal",
softmax_type="off-by-one",
),
"softmax_5_2": ModelConfig(
2,
2048,
64,
64,
num_gqa_groups=8,
window_size=(128, 0),
attn_mask_type="padding_causal",
softmax_type="learnable",
),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_configs", [model_configs_softmax])
@pytest.mark.parametrize("model", model_configs_softmax.keys())
def test_dpa_softmax(dtype, model_configs, model):
"""Test DotProductAttention module with different softmax types"""
test_dot_product_attention(
dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False
)
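# Note on the softmax_type values exercised above: "vanilla" is the standard
# softmax, "off-by-one" adds a constant 1 to the softmax denominator (a fixed
# attention sink), and "learnable" makes that offset a trainable parameter
# (see the softmax_offset gradient checks elsewhere in this file); the
# denominator detail is our reading, not asserted by this test.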
model_configs_mla = {
# TODO: FlashAttention on ROCm only supports MLA with head_dim_qk == head_dim_v
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
......@@ -290,7 +375,7 @@ def test_dpa_mla(dtype, model_configs, model):
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
"mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
"mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
......@@ -345,18 +430,16 @@ def test_dpa_mask(dtype, model_configs, model):
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
"bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
"bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped
"bias_1_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"
), # skipped
"bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),
"bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),
"bias_2_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
), # skipped
),
"bias_2_1": ModelConfig(
2,
128,
......@@ -365,10 +448,10 @@ model_configs_bias = {
max_seqlen_kv=256,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
), # skipped
),
"bias_2_2": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
), # skipped
),
"bias_2_3": ModelConfig(
2,
2048,
......@@ -377,13 +460,11 @@ model_configs_bias = {
max_seqlen_kv=4096,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
), # skipped
"bias_2_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"
), # skipped
),
"bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),
"bias_2_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
), # skipped
),
"bias_3_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
......@@ -401,14 +482,14 @@ model_configs_bias = {
max_seqlen_kv=4096,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
), # skipped
),
"bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
"bias_3_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
), # skipped
),
"bias_4_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
), # skipped
),
"bias_4_1": ModelConfig(
2,
128,
......@@ -417,10 +498,10 @@ model_configs_bias = {
max_seqlen_kv=256,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
), # skipped
),
"bias_4_2": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
), # skipped
),
"bias_4_3": ModelConfig(
2,
2048,
......@@ -429,10 +510,10 @@ model_configs_bias = {
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
), # skipped
),
"bias_4_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
), # skipped
),
"bias_4_5": ModelConfig(
2,
2048,
......@@ -441,7 +522,7 @@ model_configs_bias = {
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="alibi",
), # skipped
),
}
......@@ -455,7 +536,7 @@ def test_dpa_bias(dtype, model_configs, model):
model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
# test: ModelConfig(b, sq, hq, dqk)
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
"bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
......@@ -493,7 +574,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"swa_1_1": ModelConfig(2, 2048, 16, 64),
"swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
"swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
......@@ -533,7 +614,7 @@ def test_dpa_sliding_window(dtype, model_configs, model):
model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
# test: ModelConfig(b, sq, hq, dqk)
"alibi_1_0": ModelConfig(
2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
),
......@@ -587,7 +668,7 @@ qkv_layouts = [
model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"layout_0_0": ModelConfig(2, 128, 16, 64),
"layout_0_1": ModelConfig(
2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
......@@ -635,7 +716,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
"layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
......@@ -727,7 +808,6 @@ def _run_dot_product_attention(
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
# Set RNG and environment variables
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
......@@ -990,9 +1070,12 @@ def _run_dot_product_attention(
tp_group=None,
layer_number=1,
attention_type=config.attn_type,
softmax_type=config.softmax_type,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
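# Non-vanilla softmax types (e.g. off-by-one, learnable) carry a softmax_offset
# parameter; enable its gradient so it can be compared across backends below.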
if is_training and config.softmax_type != "vanilla":
block.softmax_offset.requires_grad = True
# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
......@@ -1027,12 +1110,14 @@ def _run_dot_product_attention(
)
if is_training:
out.backward(d_out)
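# All backends now return a 4-tuple of gradients; d_softmax_offset stays None
# unless training with a non-vanilla softmax type.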
d_softmax_offset = None
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
return out, (q.grad, k.grad, v.grad)
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None)
return out, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
......@@ -1061,18 +1146,18 @@ def _run_dot_product_attention(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
else:
return out_orig, (None, None, None)
return out_orig, (None, None, None, d_softmax_offset)
else:
if is_training:
return out, (q.grad, k.grad, v.grad)
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None)
return out, (None, None, None, d_softmax_offset)
model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
"te_1_1": ModelConfig(
4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
......@@ -1437,6 +1522,7 @@ def _run_transformer_layer(
model_configs_fp8_extra_state = {
# test: ModelConfig(b, sq, hq, dqk)
"large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
......@@ -1446,7 +1532,8 @@ model_configs_fp8_extra_state = {
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
def test_dpa_fp8_extra_state(model, dtype):
"""Test DotProductAttention module in FP8 with checkpointing"""
config = model_configs_fp8_extra_state[model]
# Test backend availability
is_training = True
......@@ -1460,9 +1547,9 @@ def test_sanity_attention_extra_state(model, dtype):
if not fused_attn_supported and not flash_attn_supported:
pytest.skip("No attention backend available.")
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(
outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)
......@@ -1484,7 +1571,8 @@ def test_sanity_attention_extra_state(model, dtype):
)
def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
"""Run DotProductAttention module in FP8 with checkpointing"""
steps = 10
path = "checkpoint.pt"
fp8_enabled = True
......@@ -1510,7 +1598,7 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
with quantized_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
......@@ -1527,7 +1615,7 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
block = get_model(dtype, config)
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
......@@ -1562,7 +1650,7 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
assert not param_grads, "param_grads should be empty at this point."
for i in range((steps + 1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
with autocast(enabled=fp8_enabled, recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
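Throughout this change, fp8_autocast and fp8_model_init are replaced by the renamed autocast and quantized_model_init entry points. A minimal sketch of the new spelling, assuming the renamed APIs keep the call signatures shown in the substitutions above and using te.Linear purely as an example module:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

# Parameters are created in quantized storage when enabled=True.
with te.quantized_model_init(enabled=True, recipe=fp8_recipe):
    layer = te.Linear(1024, 1024)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
# autocast replaces fp8_autocast; note the recipe= keyword (was fp8_recipe=).
with te.autocast(enabled=True, recipe=fp8_recipe):
    y = layer(x)
y.sum().backward()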
......@@ -1581,7 +1669,7 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"fp8_9": ModelConfig(2, 2048, 16, 128),
"fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
"fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
......@@ -1601,33 +1689,6 @@ qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
def _rmse(a, b):
return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert rmse < rmse_tol * rmse_range, (
name_a
+ " vs "
+ name_b
+ " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
)
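For reference, the fallback in _error above (now replaced by compare_and_assert) treats the comparison as passing when the RMSE stays below rmse_tol times the combined value range of the two tensors, so the tolerance scales with the data's dynamic range. A minimal sketch of that check, assuming only torch and math:

import math
import torch

def rmse_within_range_tol(a: torch.Tensor, b: torch.Tensor, rmse_tol: float) -> bool:
    # Range-relative check: the allowed RMSE grows with the spread of the data.
    rmse = math.sqrt(torch.mean((a.float() - b.float()) ** 2).item())
    value_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
    return rmse < rmse_tol * value_range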
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
......@@ -1638,22 +1699,44 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_mha_fp8_vs_f16(
dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
):
"""Test MultiHeadAttention module in FP8"""
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
# Test backend availability
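# DelayedScaling derives scaling factors from past amax history, while
# Float8CurrentScaling computes them from each tensor's current amax.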
if scaling_mode == "delayed":
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=True,
fp8_mha=True,
)
elif scaling_mode == "current":
fp8_recipe = recipe.Float8CurrentScaling(
fp8_format=recipe.Format.HYBRID,
fp8_dpa=True,
fp8_mha=True,
)
fp8_meta = {}
fp8_meta["recipe"] = fp8_recipe
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout=qkv_format.replace("hd", "h3d"),
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if fewer than two backends are available to compare
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Fewer than two backends to compare.")
if flash_attn_supported + fused_attn_supported < 1:
pytest.skip("No FP8 attention backend available.")
if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
......@@ -1671,7 +1754,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
os.environ["NVTE_FLASH_ATTN"] = "0"
......@@ -1679,20 +1762,21 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
if flash_attn_supported:
_error(
logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
......@@ -1700,8 +1784,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
atol,
rtol,
rmse_tol,
True,
)
_error(
logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
......@@ -1709,12 +1796,13 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
atol,
rtol,
rmse_tol,
True,
)
if is_training:
for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i]))
_error(
compare_and_assert(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
......@@ -1722,10 +1810,14 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
atol,
rtol,
rmse_tol,
True,
)
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
def _run_mha_fp8_vs_f16(
dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
):
"""Run MultiHeadAttention module in FP8"""
reset_rng_states()
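# Seeded dummy RNG tracker so attention dropout draws reproducible states.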
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
......@@ -1734,16 +1826,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_mha,
fp8_mha=fp8_mha,
)
with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe):
with quantized_model_init(enabled=fp8_mha, recipe=fp8_recipe):
rotary_pos_emb = None
if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
......@@ -1815,7 +1898,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
out_grad = tensor.view(*tensor.shape[:-2], -1)
with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe):
with autocast(enabled=fp8_mha, recipe=fp8_recipe):
out = mha(
hidden_states,
attn_mask_type=config.attn_mask_type,
......@@ -1851,7 +1934,9 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
"""Test DotProductAttention module in FP8"""
config = model_configs_fp8_vs_f16[model]
# TODO(cyang): think of another way to verify dropout results
......@@ -1866,16 +1951,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"
# Test backend availability
if scaling_mode == "delayed":
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=True,
)
elif scaling_mode == "current":
fp8_recipe = recipe.Float8CurrentScaling(
fp8_format=recipe.Format.HYBRID,
fp8_dpa=True,
)
fp8_meta = {}
fp8_meta["recipe"] = fp8_recipe
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout=qkv_layout,
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if flash_attn_supported + fused_attn_supported < 1:
pytest.skip("No FP8 attention backend available.")
if not fp8_dpa_bwd:
......@@ -1895,33 +1997,45 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
dtype, config, True, qkv_layout, is_training, fp8_recipe
)
if unfused_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training, fp8_recipe
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
dtype, config, True, qkv_layout, is_training, fp8_recipe
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
if config.dropout_p == 0.0:
# test cuDNN FP8 dropout: need an FP16/BF16 reference on Blackwell
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
dtype, config, False, qkv_layout, is_training
dtype, config, False, qkv_layout, is_training, fp8_recipe
)
atol = 5e-1
rtol = 5e-2
rmse_tol = 0.11
bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output"))
if flash_attn_supported:
_error(
logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
......@@ -1929,6 +2043,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol,
rtol,
rmse_tol,
True,
)
if unfused_attn_supported:
logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
unfused_attn_fwd_fp8,
fused_attn_fwd_f16,
"unfused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
True,
)
if is_training:
for i, _ in enumerate(fused_attn_bwd_f16):
logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
compare_and_assert(
unfused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"unfused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
True,
)
if config.dropout_p != 0.0:
# test cuDNN FP8 dropout
......@@ -1936,7 +2077,9 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
fused_attn_fwd_fp8 == 1
), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
else:
_error(
logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
......@@ -1944,11 +2087,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol,
rtol,
rmse_tol,
True,
)
if is_training:
for i, _ in enumerate(fused_attn_bwd_f16):
logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
_error(
compare_and_assert(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
......@@ -1956,11 +2100,13 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol,
rtol,
rmse_tol,
True,
)
os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe):
"""Run DotProductAttention module in FP8"""
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
......@@ -1969,16 +2115,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_dpa,
)
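# Strip digits from the first layout segment, e.g. "sbh3d" -> "sbhd".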
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
with fp8_model_init(enabled=fp8_dpa):
with quantized_model_init(enabled=fp8_dpa):
dpa = DotProductAttention(
config.num_heads,
config.head_dim_qk,
......@@ -2070,7 +2208,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
with autocast(enabled=fp8_dpa, recipe=fp8_recipe):
out = dpa(
inp[0],
inp[1],
......@@ -2083,6 +2221,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
fp8_output=fp8_dpa,
)
if is_training:
out.backward(out_grad)
......@@ -2093,7 +2232,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"fp8_1": ModelConfig(1, 512, 1, 64),
"fp8_2": ModelConfig(4, 512, 16, 64),
"fp8_3": ModelConfig(1, 2048, 1, 128),
......@@ -2148,7 +2287,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.13
_error(
compare_and_assert(
fused_attn_fwd_fp8,
unfused_attn_fwd_f16,
"fused_attn_fwd_fp8",
......@@ -2156,8 +2295,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
atol,
rtol,
rmse_tol,
True,
)
_error(
compare_and_assert(
fused_attn_bwd_fp8,
unfused_attn_bwd_f16,
"fused_attn_bwd_fp8",
......@@ -2165,6 +2305,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
atol,
rtol,
rmse_tol,
True,
)
......@@ -2208,7 +2349,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
)
mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with autocast(enabled=True, recipe=fp8_recipe):
out = mha(inp, cu_seqlens, config.max_seqlen_q)
out.backward(out_grad)
......@@ -2406,7 +2547,7 @@ class _custom_mha_fp8(torch.autograd.Function):
)
proj_dgrad = ctx.dO_quantizer(grad_output)
fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_s,
......
......@@ -6,27 +6,36 @@ import os
import subprocess
import sys
import pathlib
import logging
import pytest
import torch
from transformer_engine.pytorch.utils import (
from transformer_engine.pytorch import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig, get_available_attention_backends
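# Forward pytest's logging level to the spawned multi-GPU subprocesses.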
pytest_logging_level = logging.getLevelName(logging.root.level)
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
from torch.utils.cpp_extension import IS_HIP_EXTENSION
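# Run only a representative subset of configs, dtypes, and formats to keep runtime down.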
test_essential = True
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
......@@ -61,18 +70,31 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
return args
dtypes = ["bf16", "fp16"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"]
model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs}
dtypes = ["bf16"]
qkv_formats = ["sbhd", "thd"]
@pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
config = model_configs_flash_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
......@@ -90,6 +112,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
)
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
available_backends, *_ = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_layout="_".join([qkv_format] * 3),
)
flash_attn_supported, *_ = available_backends
if not flash_attn_supported:
pytest.skip("No attention backend available.")
subprocess.run(
get_bash_arguments(
......@@ -99,13 +130,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
log_level=pytest_logging_level,
),
check=True,
)
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(
......@@ -136,17 +168,42 @@ model_configs_fused_attn = {
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
"cp_4_0": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla"
), # GQA
"cp_4_1": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
), # GQA
"cp_4_2": ModelConfig(
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
), # GQA
}
dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
@pytest.mark.parametrize("fp8_mha", [False, True])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
@pytest.mark.parametrize("fp8_bwd", [True, False])
@pytest.mark.parametrize("fp8_mha", [True, False])
@pytest.mark.parametrize("fp8_dpa", [True, False])
@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"])
@pytest.mark.parametrize("f16_O", [True, False])
def test_cp_with_fused_attention(
dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O
):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
......@@ -157,8 +214,15 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
if dtype == "fp8" and get_device_compute_capability() < (9, 0):
pytest.skip("FP8 attention is only supported on sm90+!")
if dtype == "fp8" and not fp8_dpa and fp8_mha:
pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!")
if dtype != "fp8" and fp8_bwd:
pytest.skip("Only fp8 works with fp8_bwd=True!")
config = model_configs_fused_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
......@@ -186,19 +250,57 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
if dtype != "fp8" and fp8_mha:
pytest.skip("Only fp8 works with fp8_mha=True!")
if dtype != "fp8" and (fp8_mha or fp8_dpa):
pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!")
if dtype == "fp8" and not (fp8_mha or fp8_dpa):
pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!")
if dtype != "fp8" and scaling_mode is not None:
pytest.skip("Only fp8 works with scaling_mode != None!")
if dtype == "fp8" and scaling_mode is None:
pytest.skip("fp8 only works with scaling_mode != None!")
if (
dtype == "fp8"
and scaling_mode == "current"
and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"]
):
pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!")
if f16_O and (dtype != "fp8" or scaling_mode != "current"):
pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!")
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
if dtype == "fp8" and config.softmax_type != "vanilla":
pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!")
if config.softmax_type != "vanilla" and cp_comm_type != "a2a":
pytest.skip(
"CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
)
if config.softmax_type != "vanilla" and qkv_format == "thd":
pytest.skip(
"CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
)
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
fp8_meta = {}
fp8_meta["recipe"] = None
fp8_meta["local_recipes"] = []
fp8 = dtype == "fp8" and (fp8_dpa or fp8_mha)
if fp8 and scaling_mode == "delayed":
fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True)
fp8_meta["local_recipes"] = [DelayedScaling(fp8_dpa=True)]
if fp8 and scaling_mode == "current":
fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True)
fp8_meta["local_recipes"] = [
Float8CurrentScaling(fp8_dpa=True),
DelayedScaling(fp8_dpa=True),
]
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
qkv_layout="_".join([qkv_format] * 3),
window_size=config.window_size,
context_parallel=True,
fp8=fp8,
fp8_meta=fp8_meta,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
......@@ -212,7 +314,12 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
qkv_format=qkv_format,
kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type,
fp8_bwd=fp8_bwd,
fp8_dpa=fp8_dpa,
fp8_mha=fp8_mha,
scaling_mode=scaling_mode,
f16_O=f16_O,
log_level=pytest_logging_level,
),
check=True,
)
......@@ -5,7 +5,6 @@
"""Unit tests for context parallel utils."""
import torch
import unittest
from typing import Tuple
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_batch_on_this_cp_rank,
pad_thd_sequences_for_cp,
......
......@@ -14,20 +14,22 @@ import pytest
import torch
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
from transformer_engine.pytorch import fp8_autocast, fp8_model_init
from transformer_engine.pytorch.transformer import (
from transformer_engine.pytorch import (
make_graphed_callables,
autocast,
quantized_model_init,
TransformerLayer,
DotProductAttention,
InferenceParams,
is_bf16_available,
)
from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams
from transformer_engine.common import recipe
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils as fa_utils,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
_current_file = pathlib.Path(__file__).resolve()
......@@ -42,7 +44,7 @@ from utils import (
reset_rng_states()
param_types = [torch.float16]
if is_bf16_compatible():
if is_bf16_available():
param_types.append(torch.bfloat16)
model_configs_infer = {
......@@ -238,7 +240,7 @@ def get_model(
if module == "TransformerLayer":
hidden_size = config.head_dim_qk * config.num_heads
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe):
with quantized_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [
TransformerLayer(
hidden_size=hidden_size,
......@@ -261,7 +263,7 @@ def get_model(
for layer_number in range(1, num_layers + 1)
]
if module == "DotProductAttention":
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe):
with quantized_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [
DotProductAttention(
kv_channels=config.head_dim_qk,
......@@ -469,7 +471,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=False,
is_training=False,
fp8=is_fp8,
......@@ -560,9 +561,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
model[i],
sample_args,
num_warmup_iters=10,
fp8_enabled=is_fp8,
enabled=is_fp8,
sample_kwargs=sample_kwargs,
fp8_recipe=fp8_recipe,
recipe=fp8_recipe,
)
for i in range(num_layers)
]
......@@ -655,7 +656,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
if inference_params.is_paged:
inference_params.cache_manager.print_cache()
incremental_output = incremental_inputs
with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe):
with autocast(enabled=is_fp8, recipe=fp8_recipe):
for m in model:
incremental_output = m(
*incremental_output,
......
......@@ -16,7 +16,7 @@ import transformer_engine
import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch import is_fp8_available
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from test_numerics import (
......@@ -46,7 +46,8 @@ FEATURE_DIRS = None
all_boolean = [True, False]
TEST_NR = 0
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_available = is_fp8_available()
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
if tp_size is None:
......@@ -127,7 +128,7 @@ class AllGather(torch.autograd.Function):
def _run_forward_backward(x, model, parallel_mode=None, group=None):
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.autocast(enabled=True, recipe=FP8_RECIPE):
y = model(x)
y.requires_grad_(True)
......@@ -422,13 +423,13 @@ def test_log_expert_parallel(**kwargs):
) # data parallel
model = _init_model(weight, parallel_mode=None, name="linear1")
model1 = _init_model(weight, parallel_mode=None, name="linear2")
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.autocast(enabled=fp8_available, recipe=FP8_RECIPE):
y1 = model(x)
y2 = model1(x)
y = y1 + y2
y.sum().backward()
debug_api.step()
with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.autocast(enabled=fp8_available, recipe=FP8_RECIPE):
y = model(x)
if WORLD_RANK != 0:
y = y + model1(x)
......@@ -541,7 +542,7 @@ def test_per_tensor_scaling(
LOSS_MULTIPLIER = 100
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
with transformer_engine.pytorch.autocast(enabled=True, recipe=FP8_RECIPE):
y = model(x)
model.zero_grad()
if parallel_mode == "column":
......
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch import Float8Tensor, Float8Quantizer
import nvdlfw_inspect.api as debug_api
......