Unverified Commit 8ec01e5e authored by Tim Moon, committed by GitHub

[JAX] Use FP8 tolerances in FP8 tests (#501)



* Use FP8 tolerances in JAX FP8 tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Programmatically compute expected floating point error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Loosen tolerance for MNIST test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 325bf911
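For context, a quick illustration (mine, not part of the commit) of why FP8 tests need their own tolerances: the default `np.testing.assert_allclose` tolerances (rtol=1e-5, atol=1e-8) can never pass for FP8 results, because FP8 machine epsilon is roughly six orders of magnitude larger than FP32's.

import jax.numpy as jnp

# Machine epsilon per dtype; FP8 values are far coarser than FP32,
# so FP32-style tolerances are hopeless for FP8 outputs.
print(jnp.finfo(jnp.float32).eps)        # ~1.19e-07 (2**-23)
print(jnp.finfo(jnp.bfloat16).eps)       # 0.0078125 (2**-7)
print(jnp.finfo(jnp.float8_e4m3fn).eps)  # 0.125     (2**-3)
print(jnp.finfo(jnp.float8_e5m2).eps)    # 0.25      (2**-2)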
@@ -280,7 +280,7 @@ class TestMNIST(unittest.TestCase):
         """Check if loss and accuracy match target"""
         desired_training_loss = 0.055
         desired_training_accuracy = 0.98
-        desired_test_loss = 0.035
+        desired_test_loss = 0.04
         desired_test_accuracy = 0.98
         assert actual[0] < desired_training_loss
         assert actual[1] > desired_training_accuracy
...
@@ -48,7 +48,7 @@ class TestFP8Dot:
         scale_inv = (1 / scale).reshape(1)
         y, new_amax = quantize(x, amax, scale, scale_inv, out_dtype=DType.kFloat8E4M3)
-        assert_allclose(new_amax, 3.0)
+        assert_allclose(new_amax, 3.0, rtol=0, atol=0)
         no_use = jnp.zeros(1, jnp.float32)
         z = dequantize(y,
@@ -57,7 +57,7 @@ class TestFP8Dot:
                        scale_inv,
                        fp8_dtype=DType.kFloat8E4M3,
                        out_dtype=DType.kFloat32)
-        assert_allclose(z, x, rtol=5e-2, atol=5e-2)
+        assert_allclose(z, x, dtype=DType.kFloat8E4M3)

     def test_compile_bf16(self):
         key = jax.random.PRNGKey(0)
@@ -117,10 +117,11 @@ class TestFP8Dot:
         fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
         fp8_gemm_pkg = FP8GemmPackage(1, a, [b], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                       fp8_metas_scale_inv)
-        primitive_out = fp8_dot(fp8_gemm_pkg, *_format2dtypes(None))
+        fwd_dtype, bwd_dtype = _format2dtypes(None)
+        primitive_out = fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype)
         ref_out = jnp.dot(a, b)
-        assert_allclose(primitive_out, ref_out)
+        assert_allclose(primitive_out, ref_out, dtype=fwd_dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize('m,n,k', GEMM_CASES)
@@ -154,7 +155,7 @@ class TestFP8Dot:
         ref_out = ref_out.astype(jnp.float32)
         primitive_out = primitive_out.astype(jnp.float32)
-        assert_allclose(primitive_out, ref_out)
+        assert_allclose(primitive_out, ref_out, dtype=compute_type[0])

     @pytest.mark.parametrize('m,n,k', GEMM_CASES)
     def test_grad_bf16(self, m, n, k):
@@ -162,6 +163,7 @@ class TestFP8Dot:
         subkeys = jax.random.split(key, 2)
         a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
         b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
+        fwd_dtype, bwd_dtype = _format2dtypes(None)

         def primitive_func(x, y):
             fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
@@ -171,7 +173,7 @@ class TestFP8Dot:
             fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.mean(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None)))
+            return jnp.mean(fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype))

         def ref_func(x, y):
             return jnp.mean(jnp.dot(x, y))
@@ -183,9 +185,9 @@ class TestFP8Dot:
         primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
         ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
-        assert_allclose(primitive_out, ref_out)
-        assert_allclose(primitive_a_grad, ref_a_grad)
-        assert_allclose(primitive_b_grad, ref_b_grad, atol=1e-5)
+        assert_allclose(primitive_out, ref_out, dtype=fwd_dtype)
+        assert_allclose(primitive_a_grad, ref_a_grad, dtype=bwd_dtype)
+        assert_allclose(primitive_b_grad, ref_b_grad, dtype=bwd_dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize('m,n,k', GEMM_CASES)
@@ -227,15 +229,16 @@ class TestFP8Dot:
         primitive_out, (primitive_a_grad,
                         primitive_b_grad) = value_n_grad_primitive_func(a, b, fp8_meta)
-        assert_allclose(primitive_out, ref_out)
-        assert_allclose(primitive_a_grad, ref_a_grad)
-        assert_allclose(primitive_b_grad, ref_b_grad)
+        assert_allclose(primitive_out, ref_out, dtype=compute_type[0])
+        assert_allclose(primitive_a_grad, ref_a_grad, dtype=compute_type[1])
+        assert_allclose(primitive_b_grad, ref_b_grad, dtype=compute_type[1])

     def test_contracting_dims_bf16(self):
         key = jax.random.PRNGKey(0)
         subkeys = jax.random.split(key, 2)
         a = jax.random.normal(subkeys[0], (32, 8, 16, 64), jnp.bfloat16)
         b = jax.random.normal(subkeys[1], (16, 64, 128), jnp.bfloat16)
+        fwd_dtype, bwd_dtype = _format2dtypes(None)

         def primitive_func(x, y):
             fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
@@ -245,7 +248,7 @@ class TestFP8Dot:
             fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.sum(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None), ((2, 3), (0, 1))))
+            return jnp.sum(fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype, ((2, 3), (0, 1))))

         def ref_func(x, y):
             return jnp.sum(lax.dot_general(x, y, dimension_numbers=(((2, 3), (0, 1)), ((), ()))))
@@ -255,9 +258,9 @@ class TestFP8Dot:
         primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
         ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
-        assert_allclose(primitive_out, ref_out)
-        assert_allclose(primitive_a_grad, ref_a_grad)
-        assert_allclose(primitive_b_grad, ref_b_grad)
+        assert_allclose(primitive_out, ref_out, dtype=fwd_dtype)
+        assert_allclose(primitive_a_grad, ref_a_grad, dtype=bwd_dtype)
+        assert_allclose(primitive_b_grad, ref_b_grad, dtype=bwd_dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
@@ -370,19 +373,19 @@ class TestFP8Dot:
         primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
                         primitive_k2_grad) = value_n_grad_primitive_func(a, s, k1, k2, fp8_meta)
-        assert_allclose(primitive_out, ref_out, rtol=1e-2)
+        assert_allclose(primitive_out, ref_out, dtype=compute_type[0])
         assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
                         jnp.asarray(ref_a_grad, np.float32),
-                        rtol=1e-2)
+                        dtype=compute_type[1])
         assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
                         jnp.asarray(ref_k1_grad, np.float32),
-                        rtol=1e-2)
+                        dtype=compute_type[1])
         assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
                         jnp.asarray(ref_k2_grad, np.float32),
-                        rtol=1e-2)
+                        dtype=compute_type[1])
         assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
                         jnp.asarray(ref_s_grad, np.float32),
-                        rtol=1e-2)
+                        dtype=compute_type[1])

     @pytest.fixture(name="random_inputs")
...
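The quantize/dequantize change above replaces the hand-picked 5e-2 tolerance with one derived from the FP8 dtype. A minimal sketch (mine, using plain jnp casts rather than the transformer_engine quantize/dequantize primitives) of the round-trip error this tolerates:

import jax.numpy as jnp
import numpy as np

# Values in [1, 2): for E4M3 the grid spacing here is eps = 2**-3, so a
# round-trip through FP8 is off by at most half a ULP, i.e. 0.0625.
x = jnp.linspace(1.0, 2.0, 257, dtype=jnp.float32)[:-1]
y = x.astype(jnp.float8_e4m3fn).astype(jnp.float32)  # round-trip cast

rel_err = np.abs(np.asarray(y - x)) / np.asarray(x)
assert float(rel_err.max()) <= float(jnp.finfo(jnp.float8_e4m3fn).eps) / 2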
@@ -3,8 +3,9 @@
 # See LICENSE for license information.

 import functools
+import math
 import operator
-from typing import Any, Callable, Tuple, Sequence, Union, Iterable, Optional
+from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional

 import jax
 import jax.numpy as jnp
@@ -15,6 +16,8 @@ from jax import lax, vmap
 from jax import nn as jax_nn
 from jax import random as jax_random

+from transformer_engine.jax.fp8 import DType as TEDType
+
 PRNGKey = Any
 Shape = Tuple[int, ...]
 DType = jnp.dtype
@@ -23,7 +26,6 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci
                       lax.Precision]]
 Initializer = Callable[[PRNGKey, Shape, DType], Array]

-
 def is_devices_enough(required):
     return len(jax.devices()) >= required
@@ -1008,15 +1010,98 @@ class DecoderLayer(nn.Module):
         return z

-def assert_allclose(actual,
-                    desired,
-                    rtol=1e-05,
-                    atol=1e-08,
-                    equal_nan=True,
-                    err_msg='',
-                    verbose=True):
+def assert_allclose(
+    actual: Array,
+    desired: Array,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+    dtype: Optional[Union[DType, TEDType, np.dtype, str]] = None,
+    **kwargs,
+) -> None:
+    """Check if two tensors are close.
+
+    Args:
+        actual: test tensor.
+        desired: reference tensor.
+        rtol: relative tolerance (default: based on `dtype`).
+        atol: absolute tolerance (default: based on `dtype`).
+        dtype: data type or data type name (default: inferred from
+            `actual`).
+        **kwargs: keyword arguments to pass to np.testing.assert_allclose.
+
+    """
+
+    # Infer data type if needed
+    if dtype is None:
+        if isinstance(actual, float):
+            dtype = "float32"
+        else:
+            dtype = actual.dtype
+
+    # Determine tolerances
+    tols = dict()
+    if rtol is None or atol is None:
+        tols = dtype_tols(dtype)
+    if rtol is not None:
+        tols["rtol"] = rtol
+    if atol is not None:
+        tols["atol"] = atol
+
     # Cast tensors to fp32
     if not isinstance(actual, float):
         actual = actual.astype(jnp.float32)
     if not isinstance(desired, float):
         desired = desired.astype(jnp.float32)
-    np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
+
+    # Check if tensors are close
+    np.testing.assert_allclose(actual, desired, **tols, **kwargs)
+
+
+def dtype_tols(
+    dtype: Union[DType, TEDType, np.dtype],
+    reference_value: float = 1.0,
+) -> Dict[str, float]:
+    """Expected numerical tolerance for a data type.
+
+    Args:
+        dtype: data type.
+        reference_value: reference value (default: 1).
+
+    Returns:
+        Dictionary with "rtol" and "atol" as keys.
+
+    """
+
+    # Convert to JAX dtype if needed
+    if isinstance(dtype, TEDType):
+        dtype = {
+            TEDType.kByte: jnp.uint8,
+            TEDType.kInt32: jnp.int32,
+            TEDType.kInt64: jnp.int64,
+            TEDType.kFloat32: jnp.float32,
+            TEDType.kFloat16: jnp.float16,
+            TEDType.kBFloat16: jnp.bfloat16,
+            TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
+            TEDType.kFloat8E5M2: jnp.float8_e5m2,
+        }[dtype]
+    elif isinstance(dtype, np.dtype):
+        dtype = jnp.dtype(dtype)
+
+    # Expect bit-wise accuracy for integer dtypes
+    if not jnp.issubdtype(dtype, jnp.floating):
+        return dict(rtol=0, atol=0)
+
+    # Estimate floating-point error
+    finfo = jnp.finfo(dtype)
+    eps_relaxed = math.pow(finfo.eps, 2/3)
+    with jax.default_device(jax.devices("cpu")[0]):
+        if isinstance(reference_value, (float, int)):
+            reference_value = jnp.array(reference_value, dtype=dtype)
+        else:
+            reference_value = reference_value.astype(dtype)
+        spacing_high = jnp.nextafter(reference_value, finfo.max) - reference_value
+        spacing_low = reference_value - jnp.nextafter(reference_value, finfo.min)
+        ulp = max(spacing_high.item(), spacing_low.item())
+    return dict(
+        rtol=eps_relaxed,
+        atol=max(ulp, eps_relaxed),
+    )
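To make the new helpers concrete, a hedged usage sketch (the `utils` import path is illustrative, and it assumes a JAX build whose `nextafter` handles float8, which the tests above already rely on). The tolerances scale as eps**(2/3), floored by one ULP at the reference value:

import jax.numpy as jnp
from utils import assert_allclose, dtype_tols  # hypothetical import path

print(dtype_tols(jnp.float32))        # {'rtol': ~2.4e-05, 'atol': ~2.4e-05}
print(dtype_tols(jnp.bfloat16))       # {'rtol': ~3.9e-02, 'atol': ~3.9e-02}
print(dtype_tols(jnp.float8_e4m3fn))  # {'rtol': 0.25, 'atol': 0.25}

x = jnp.ones((8,), jnp.float32)
y = x * 1.01  # 1% off: within bf16/fp8 tolerances, not fp32's

assert_allclose(y, x, dtype=jnp.bfloat16)            # passes: tols ~0.039
assert_allclose(y, x, dtype=jnp.float8_e4m3fn)       # passes: tols 0.25
assert_allclose(y, x, rtol=0.02, dtype=jnp.float32)  # explicit rtol wins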