Commit 227961e6 authored by Hua Huang, committed by GitHub

[JAX] Distinguish the reasons why fp8 / mxfp8 is not supported in unit test (#1873)



Distinguish the reason why fp8 is not supported from the reason why mxfp8 is not supported, so each skipped test reports the message that actually applies to it.
Signed-off-by: Hua Huang <huah@nvidia.com>
parent 40a30a5f
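
The change itself is mechanical: the single shared `reason` variable is split into `fp8_unsupported_reason` and `mxfp8_unsupported_reason`, so a test skipped for a missing FP8 capability no longer reports the MXFP8 message left over from the last probe. Below is a minimal, self-contained sketch of the pattern; the `is_available()` probe is hypothetical and stands in for `helper.is_fp8_available()` from the actual test file:

import pytest

# Hypothetical capability probe standing in for helper.is_fp8_available();
# it returns (supported, reason_if_not), mirroring the interface used in the diff.
def is_available(feature: str):
    supported = {"fp8": True, "mxfp8": False}.get(feature, False)
    reason = "" if supported else f"{feature} is not supported on this device"
    return supported, reason

# Keep one reason string per capability: reusing a single `reason` variable
# would retain only the result of the last probe, so every skip message would
# be the MXFP8 one regardless of which capability is actually missing.
is_fp8_supported, fp8_unsupported_reason = is_available("fp8")
is_mxfp8_supported, mxfp8_unsupported_reason = is_available("mxfp8")

@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
def test_requires_fp8():
    ...

@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
def test_requires_mxfp8():
    ...
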
@@ -57,8 +57,8 @@ GEMM_CASES = [
 FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
 LN_CASES = [(256, 128), (128, 256)]
 DTYPES = [jnp.bfloat16, jnp.float32]
-is_fp8_supported, reason = helper.is_fp8_available()
-is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
+is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
+is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
 supported_scaling_modes = []
 """ Find supported scaling modes"""
@@ -209,7 +209,7 @@ class TestActivation:
 assert_allclose(prim_out, ref_out, dtype=x.dtype)
 assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
 @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
 @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -240,7 +240,7 @@ class TestActivation:
 assert_allclose(prim_out, ref_out, dtype=output_type)
 assert_allclose(prim_grad, ref_grad, dtype=output_type)
-@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
 @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
 @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
 @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -270,7 +270,7 @@ class TestActivation:
 assert_bitwise_scaled_tensors(te_output, jax_output)
-@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
 @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
 @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
 @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -391,7 +391,7 @@ class TestNorm:
 n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
 )
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 # No Norm FWD E5M2 in TE backend
 @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
 @pytest_parametrize_wrapper(
@@ -506,7 +506,7 @@ class TestNorm:
 if norm_type == "layernorm":
 assert_allclose(mu, ref_mu, dtype=inp_dtype)
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 # No Norm FWD E5M2 in TE backend
 @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
 @pytest_parametrize_wrapper(
@@ -542,7 +542,7 @@ class TestNorm:
 q_layout=q_layout,
 )
-@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
 @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
 def test_norm_forward_with_block_scaling_fp8(
 self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
@@ -591,7 +591,7 @@ QUANTIZATION_INPUT_DTYPE = {
 }
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
 @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@@ -638,7 +638,7 @@ class TestQuantize:
 assert_bitwise_scaled_tensors(te_output, jax_output)
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
 @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@@ -692,7 +692,7 @@ class TestGroupedQuantize:
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 class TestFusedQuantize:
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
 @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
 @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@@ -793,7 +793,7 @@ class TestFusedQuantize:
 q_layout=QuantizeLayout.ROWWISE,
 )
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
 @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
 @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@@ -817,7 +817,7 @@ class TestFusedQuantize:
 q_layout=q_layout,
 )
-@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
 @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
 @pytest_parametrize_wrapper(
 "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
@@ -886,7 +886,7 @@ class TestDense:
 assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
 @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
 @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@@ -928,7 +928,7 @@ class TestDense:
 assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
 assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
 @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
 @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@@ -992,7 +992,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
 class TestFusedDense:
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
 @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
 @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@@ -1077,7 +1077,7 @@ class TestFusedDense:
 if beta is not None:
 assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
 @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
 @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -1284,7 +1284,7 @@ class TestGroupedDense:
 prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
 self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
 @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
 @pytest_parametrize_wrapper("layout", ["NN"])
@@ -1360,7 +1360,7 @@ class TestGroupedDense:
 assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
 assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest.mark.parametrize(
 "fwd_bwd_dtype",
 [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
...