Unverified Commit 8ec01e5e authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[JAX] Use FP8 tolerances in FP8 tests (#501)



* Use FP8 tolerances in JAX FP8 tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Programmatically compute expected floating point error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Loosen tolerance for MNIST test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 325bf911
...@@ -280,7 +280,7 @@ class TestMNIST(unittest.TestCase): ...@@ -280,7 +280,7 @@ class TestMNIST(unittest.TestCase):
"""Check If loss and accuracy match target""" """Check If loss and accuracy match target"""
desired_traing_loss = 0.055 desired_traing_loss = 0.055
desired_traing_accuracy = 0.98 desired_traing_accuracy = 0.98
desired_test_loss = 0.035 desired_test_loss = 0.04
desired_test_accuracy = 0.098 desired_test_accuracy = 0.098
assert actual[0] < desired_traing_loss assert actual[0] < desired_traing_loss
assert actual[1] > desired_traing_accuracy assert actual[1] > desired_traing_accuracy
......
...@@ -48,7 +48,7 @@ class TestFP8Dot: ...@@ -48,7 +48,7 @@ class TestFP8Dot:
scale_inv = (1 / scale).reshape(1) scale_inv = (1 / scale).reshape(1)
y, new_amax = quantize(x, amax, scale, scale_inv, out_dtype=DType.kFloat8E4M3) y, new_amax = quantize(x, amax, scale, scale_inv, out_dtype=DType.kFloat8E4M3)
assert_allclose(new_amax, 3.0) assert_allclose(new_amax, 3.0, rtol=0, atol=0)
no_use = jnp.zeros(1, jnp.float32) no_use = jnp.zeros(1, jnp.float32)
z = dequantize(y, z = dequantize(y,
...@@ -57,7 +57,7 @@ class TestFP8Dot: ...@@ -57,7 +57,7 @@ class TestFP8Dot:
scale_inv, scale_inv,
fp8_dtype=DType.kFloat8E4M3, fp8_dtype=DType.kFloat8E4M3,
out_dtype=DType.kFloat32) out_dtype=DType.kFloat32)
assert_allclose(z, x, rtol=5e-2, atol=5e-2) assert_allclose(z, x, dtype=DType.kFloat8E4M3)
def test_compile_bf16(self): def test_compile_bf16(self):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
...@@ -117,10 +117,11 @@ class TestFP8Dot: ...@@ -117,10 +117,11 @@ class TestFP8Dot:
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, a, [b], fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_gemm_pkg = FP8GemmPackage(1, a, [b], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv) fp8_metas_scale_inv)
primitive_out = fp8_dot(fp8_gemm_pkg, *_format2dtypes(None)) fwd_dtype, bwd_dtype = _format2dtypes(None)
primitive_out = fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype)
ref_out = jnp.dot(a, b) ref_out = jnp.dot(a, b)
assert_allclose(primitive_out, ref_out) assert_allclose(primitive_out, ref_out, dtype=fwd_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES) @pytest.mark.parametrize('m,n,k', GEMM_CASES)
...@@ -154,7 +155,7 @@ class TestFP8Dot: ...@@ -154,7 +155,7 @@ class TestFP8Dot:
ref_out = ref_out.astype(jnp.float32) ref_out = ref_out.astype(jnp.float32)
primitive_out = primitive_out.astype(jnp.float32) primitive_out = primitive_out.astype(jnp.float32)
assert_allclose(primitive_out, ref_out) assert_allclose(primitive_out, ref_out, dtype=compute_type[0])
@pytest.mark.parametrize('m,n,k', GEMM_CASES) @pytest.mark.parametrize('m,n,k', GEMM_CASES)
def test_grad_bf16(self, m, n, k): def test_grad_bf16(self, m, n, k):
...@@ -162,6 +163,7 @@ class TestFP8Dot: ...@@ -162,6 +163,7 @@ class TestFP8Dot:
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16) b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
fwd_dtype, bwd_dtype = _format2dtypes(None)
def primitive_func(x, y): def primitive_func(x, y):
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM) fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
...@@ -171,7 +173,7 @@ class TestFP8Dot: ...@@ -171,7 +173,7 @@ class TestFP8Dot:
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv) fp8_metas_scale_inv)
return jnp.mean(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None))) return jnp.mean(fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype))
def ref_func(x, y): def ref_func(x, y):
return jnp.mean(jnp.dot(x, y)) return jnp.mean(jnp.dot(x, y))
...@@ -183,9 +185,9 @@ class TestFP8Dot: ...@@ -183,9 +185,9 @@ class TestFP8Dot:
primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b) primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b) ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
assert_allclose(primitive_out, ref_out) assert_allclose(primitive_out, ref_out, dtype=fwd_dtype)
assert_allclose(primitive_a_grad, ref_a_grad) assert_allclose(primitive_a_grad, ref_a_grad, dtype=bwd_dtype)
assert_allclose(primitive_b_grad, ref_b_grad, atol=1e-5) assert_allclose(primitive_b_grad, ref_b_grad, dtype=bwd_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES) @pytest.mark.parametrize('m,n,k', GEMM_CASES)
...@@ -227,15 +229,16 @@ class TestFP8Dot: ...@@ -227,15 +229,16 @@ class TestFP8Dot:
primitive_out, (primitive_a_grad, primitive_out, (primitive_a_grad,
primitive_b_grad) = value_n_grad_primitive_func(a, b, fp8_meta) primitive_b_grad) = value_n_grad_primitive_func(a, b, fp8_meta)
assert_allclose(primitive_out, ref_out) assert_allclose(primitive_out, ref_out, dtype=compute_type[0])
assert_allclose(primitive_a_grad, ref_a_grad) assert_allclose(primitive_a_grad, ref_a_grad, dtype=compute_type[1])
assert_allclose(primitive_b_grad, ref_b_grad) assert_allclose(primitive_b_grad, ref_b_grad, dtype=compute_type[1])
def test_contracting_dims_bf16(self): def test_contracting_dims_bf16(self):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
a = jax.random.normal(subkeys[0], (32, 8, 16, 64), jnp.bfloat16) a = jax.random.normal(subkeys[0], (32, 8, 16, 64), jnp.bfloat16)
b = jax.random.normal(subkeys[1], (16, 64, 128), jnp.bfloat16) b = jax.random.normal(subkeys[1], (16, 64, 128), jnp.bfloat16)
fwd_dtype, bwd_dtype = _format2dtypes(None)
def primitive_func(x, y): def primitive_func(x, y):
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM) fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
...@@ -245,7 +248,7 @@ class TestFP8Dot: ...@@ -245,7 +248,7 @@ class TestFP8Dot:
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv) fp8_metas_scale_inv)
return jnp.sum(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None), ((2, 3), (0, 1)))) return jnp.sum(fp8_dot(fp8_gemm_pkg, fwd_dtype, bwd_dtype, ((2, 3), (0, 1))))
def ref_func(x, y): def ref_func(x, y):
return jnp.sum(lax.dot_general(x, y, dimension_numbers=(((2, 3), (0, 1)), ((), ())))) return jnp.sum(lax.dot_general(x, y, dimension_numbers=(((2, 3), (0, 1)), ((), ()))))
...@@ -255,9 +258,9 @@ class TestFP8Dot: ...@@ -255,9 +258,9 @@ class TestFP8Dot:
primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b) primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b) ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
assert_allclose(primitive_out, ref_out) assert_allclose(primitive_out, ref_out, dtype=fwd_dtype)
assert_allclose(primitive_a_grad, ref_a_grad) assert_allclose(primitive_a_grad, ref_a_grad, dtype=bwd_dtype)
assert_allclose(primitive_b_grad, ref_b_grad) assert_allclose(primitive_b_grad, ref_b_grad, dtype=bwd_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024), @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
...@@ -370,19 +373,19 @@ class TestFP8Dot: ...@@ -370,19 +373,19 @@ class TestFP8Dot:
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad, primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad) = value_n_grad_primitive_func(a, s, k1, k2, fp8_meta) primitive_k2_grad) = value_n_grad_primitive_func(a, s, k1, k2, fp8_meta)
assert_allclose(primitive_out, ref_out, rtol=1e-2) assert_allclose(primitive_out, ref_out, dtype=compute_type[0])
assert_allclose(jnp.asarray(primitive_a_grad, np.float32), assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32), jnp.asarray(ref_a_grad, np.float32),
rtol=1e-2) dtype=compute_type[1])
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32), assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32), jnp.asarray(ref_k1_grad, np.float32),
rtol=1e-2) dtype=compute_type[1])
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32), assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32), jnp.asarray(ref_k2_grad, np.float32),
rtol=1e-2) dtype=compute_type[1])
assert_allclose(jnp.asarray(primitive_s_grad, np.float32), assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32), jnp.asarray(ref_s_grad, np.float32),
rtol=1e-2) dtype=compute_type[1])
@pytest.fixture(name="random_inputs") @pytest.fixture(name="random_inputs")
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
# See LICENSE for license information. # See LICENSE for license information.
import functools import functools
import math
import operator import operator
from typing import Any, Callable, Tuple, Sequence, Union, Iterable, Optional from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -15,6 +16,8 @@ from jax import lax, vmap ...@@ -15,6 +16,8 @@ from jax import lax, vmap
from jax import nn as jax_nn from jax import nn as jax_nn
from jax import random as jax_random from jax import random as jax_random
from transformer_engine.jax.fp8 import DType as TEDType
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
DType = jnp.dtype DType = jnp.dtype
...@@ -23,7 +26,6 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci ...@@ -23,7 +26,6 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci
lax.Precision]] lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array] Initializer = Callable[[PRNGKey, Shape, DType], Array]
def is_devices_enough(required): def is_devices_enough(required):
return len(jax.devices()) >= required return len(jax.devices()) >= required
...@@ -1008,15 +1010,98 @@ class DecoderLayer(nn.Module): ...@@ -1008,15 +1010,98 @@ class DecoderLayer(nn.Module):
return z return z
def assert_allclose(
    actual: Array,
    desired: Array,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    dtype: Optional[Union[DType, TEDType, np.dtype, str]] = None,
    **kwargs,
) -> None:
    """Check if two tensors are close.

    Args:
        actual: test tensor.
        desired: reference tensor.
        rtol: relative tolerance (default: based on `dtype`).
        atol: absolute tolerance (default: based on `dtype`).
        dtype: data type or data type name used to pick default
            tolerances (default: inferred from `actual`).
        **kwargs: keyword arguments to pass to
            np.testing.assert_allclose.
    """
    # Infer data type if needed; plain Python floats are treated as fp32
    if dtype is None:
        if isinstance(actual, float):
            dtype = "float32"
        else:
            dtype = actual.dtype

    # Determine tolerances: dtype-based defaults, overridden by any
    # explicit rtol/atol from the caller
    tols = dict()
    if rtol is None or atol is None:
        tols = dtype_tols(dtype)
    if rtol is not None:
        tols["rtol"] = rtol
    if atol is not None:
        tols["atol"] = atol

    # Cast tensors to fp32 so np.testing can handle them (e.g. bf16/fp8
    # are not natively supported by NumPy)
    if not isinstance(actual, float):
        actual = actual.astype(jnp.float32)
    if not isinstance(desired, float):
        desired = desired.astype(jnp.float32)

    # Check if tensors are close
    np.testing.assert_allclose(actual, desired, **tols, **kwargs)
def dtype_tols(
    dtype: Union[DType, TEDType, np.dtype],
    reference_value: float = 1.0,
) -> Dict[str, float]:
    """Expected numerical tolerance for a data type.

    Args:
        dtype: data type (Transformer Engine enum, NumPy dtype, or JAX
            dtype).
        reference_value: representative magnitude used to estimate the
            ULP (default: 1).

    Returns:
        Dictionary with "rtol" and "atol" as keys, suitable for passing
        to np.testing.assert_allclose.
    """
    # Normalize Transformer Engine enums and NumPy dtypes to JAX dtypes
    if isinstance(dtype, TEDType):
        dtype = {
            TEDType.kByte: jnp.uint8,
            TEDType.kInt32: jnp.int32,
            TEDType.kInt64: jnp.int64,
            TEDType.kFloat32: jnp.float32,
            TEDType.kFloat16: jnp.float16,
            TEDType.kBFloat16: jnp.bfloat16,
            TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
            TEDType.kFloat8E5M2: jnp.float8_e5m2,
        }[dtype]
    elif isinstance(dtype, np.dtype):
        dtype = jnp.dtype(dtype)

    # Integer results are expected to match bit-for-bit
    if not jnp.issubdtype(dtype, jnp.floating):
        return dict(rtol=0, atol=0)

    # Relaxed machine epsilon: eps**(2/3) lies between eps and sqrt(eps)
    finfo = jnp.finfo(dtype)
    eps_relaxed = math.pow(finfo.eps, 2 / 3)

    # Measure the ULP around the reference value (on CPU, so exotic
    # dtypes work without an accelerator)
    with jax.default_device(jax.devices("cpu")[0]):
        if isinstance(reference_value, (float, int)):
            ref = jnp.array(reference_value, dtype=dtype)
        else:
            ref = reference_value.astype(dtype)
        spacing_up = jnp.nextafter(ref, finfo.max) - ref
        spacing_down = ref - jnp.nextafter(ref, finfo.min)
        ulp = max(spacing_up.item(), spacing_down.item())

    return dict(
        rtol=eps_relaxed,
        atol=max(ulp, eps_relaxed),
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment