Unverified commit fe31af80, authored by jberchtold-nvidia, committed by GitHub

[JAX] Fix failing L2 JAX unit tests (#1735)



* Fix L2 tests in test_custom_call_compute.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix test_helper.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Address comments
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 5bee81e2
@@ -20,6 +20,7 @@ FAILED_CASES=""
+pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
 pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 : ${TE_PATH:=/opt/transformerengine}
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"
@@ -30,10 +31,9 @@ python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/py
 NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
 pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
 pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
-pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
 # Make the encoder tests run-to-run deterministic so that CI results are stable
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
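As an aside, the same determinism flag can also be applied when running these tests outside the CI script. A hedged sketch (the flag is the one exported above; the only assumption is the standard requirement that XLA_FLAGS be set before JAX is first imported):

```python
import os

# Append --xla_gpu_deterministic_ops so GPU reductions are run-to-run
# deterministic; this must happen before JAX initializes XLA.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_deterministic_ops"

import jax  # noqa: E402 -- imported after setting XLA_FLAGS on purpose
```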
@@ -78,12 +78,17 @@ def is_shape_supported_by_mxfp8(input_shape):
 def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
     if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
         assert a.scaling_mode == b.scaling_mode
-        assert a.scale_inv.dtype == b.scale_inv.dtype
-        if a.scale_inv.dtype == jnp.float8_e8m0fnu:
+        if a.scaling_mode.is_tensor_scaling():
+            # Assert in dq_dtype, as some unfused codepaths have an intermediate cast
+            # to an input dtype, which reduces precision compared to keeping everything in fp32
+            assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
+        elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            # Compare MXFP8 scales as uint8
             assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
         else:
-            assert_allclose(a.scale_inv, b.scale_inv)
+            raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
         assert_allclose(a.data, b.data)
     elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
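For context on the MXFP8 branch: `float8_e8m0fnu` is a pure 8-bit exponent format with no sign or mantissa bits, so a bitwise comparison of the underlying bytes is the meaningful equality check for scale factors. A minimal sketch, assuming a recent JAX that exposes ml_dtypes' `float8_e8m0fnu` (the helper name is illustrative, not from the repo):

```python
import jax.numpy as jnp

def e8m0_scales_bitwise_equal(a, b):
    # .view reinterprets the single byte per element without converting values,
    # so this compares the raw exponent bits exactly.
    return bool(jnp.all(a.view(jnp.uint8) == b.view(jnp.uint8)))

s1 = jnp.array([0.5, 1.0, 2.0], dtype=jnp.float8_e8m0fnu)
s2 = jnp.array([0.5, 1.0, 2.0], dtype=jnp.float8_e8m0fnu)
assert e8m0_scales_bitwise_equal(s1, s2)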
@@ -524,17 +529,24 @@ QUANTIZE_OUTPUT_DTYPES = {
     "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
 }

-ALL_QUANTIZE_TEST_SHAPES = [
-    (32, 64),
-    (2, 64, 32),
+ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
+    ((32, 64), -1),
+    ((2, 64, 32), -1),
+    ((2, 64, 32), -2),
+    ((32, 256, 128), -1),
+    ((32, 256, 128), -2),
+    ((64, 32, 32, 256), -1),
+    ((64, 32, 32, 256), -2),
+    ((64, 32, 32, 256), -3),
 ]

-QUANTIZE_TEST_SHAPES = {
+QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
     "L0": [
-        (32, 256, 128),
-        (64, 32, 32, 256),
+        ((32, 64), -1),
+        ((2, 64, 32), -1),
+        ((2, 64, 32), -2),
     ],
-    "L2": ALL_QUANTIZE_TEST_SHAPES,
+    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
 }

 QUANTIZATION_INPUT_DTYPE = {
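A hedged note on why shape and flatten_axis are now parameterized together rather than as a cross-product: a valid split needs at least one axis on each side, so the set of usable flatten_axis values depends on the rank of each shape. Pairing them avoids generating combinations that would immediately be skipped. An illustrative helper (not from the repo) that reproduces exactly the pairs listed above:

```python
# Illustrative only: enumerate the valid negative flatten axes for a shape.
def valid_flatten_axes(shape):
    # flatten_axis splits shape into shape[:flatten_axis] and shape[flatten_axis:];
    # both halves must be non-empty, so -len(shape) < flatten_axis <= -1.
    return list(range(-len(shape) + 1, 0))

print(valid_flatten_axes((32, 64)))           # [-1]
print(valid_flatten_axes((2, 64, 32)))        # [-2, -1]
print(valid_flatten_axes((64, 32, 32, 256)))  # [-3, -2, -1]
```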
@@ -546,9 +558,8 @@ QUANTIZATION_INPUT_DTYPE = {
 @pytest.mark.skipif(not is_fp8_supported, reason=reason)
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
-@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
+@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
 @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
 @pytest_parametrize_wrapper(
     "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
 )
@@ -601,12 +612,11 @@ class TestFusedQuantize:
 @pytest.mark.skipif(not is_fp8_supported, reason=reason)
 @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-@pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
+@pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
 @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
 @pytest_parametrize_wrapper(
     "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
 )
-@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
 def test_quantize_dbias(
     self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
 ):
@@ -615,6 +625,12 @@ class TestFusedQuantize:
         ):
             pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

+        if (flatten_axis < 0 and flatten_axis + len(input_shape) <= 0) or flatten_axis == 0:
+            pytest.skip(
+                f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There"
+                " must be at least one axis on either side of the flatten_axis split."
+            )
+
         key = jax.random.PRNGKey(0)
         input = jax.random.uniform(key, input_shape, in_dtype)
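For intuition, flatten_axis marks the point at which a higher-rank tensor is collapsed to 2D before row-wise/column-wise FP8 scaling is applied, which is why both halves of the split must be non-empty. A rough sketch under that assumption (`flatten_for_quantization` is hypothetical, not the repo's actual helper):

```python
import math
import jax.numpy as jnp

def flatten_for_quantization(x, flatten_axis):
    # Collapse the axes before and after the split point so that
    # row-/column-wise scaling sees an ordinary 2D matrix.
    lead = math.prod(x.shape[:flatten_axis])
    trail = math.prod(x.shape[flatten_axis:])
    return x.reshape(lead, trail)

x = jnp.zeros((2, 64, 32))
print(flatten_for_quantization(x, -1).shape)  # (128, 32)
print(flatten_for_quantization(x, -2).shape)  # (2, 2048)
```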
@@ -48,7 +48,7 @@ class TestHelper(unittest.TestCase):
 class TestFP8Functions(unittest.TestCase):
-    def _check_defult_state(self):
+    def _check_default_state(self):
         self.assertFalse(QuantizeConfig.is_fp8_enabled())

     def _compare_delay_scaling(self, ref, test):
@@ -70,82 +70,79 @@ class TestFP8Functions(unittest.TestCase):
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_delayed_scaling(self):
         QuantizeConfig.finalize()  # Ensure this test is not affected by previous tests.
-        self._check_defult_state()
+        self._check_default_state()

         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
             self.assertFalse(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())

-        self._check_defult_state()
+        self._check_default_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

-        self._check_defult_state()
+        self._check_default_state()

         ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

-        self._check_defult_state()
+        self._check_default_state()

     @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
     def test_fp8_autocast_current_scaling(self):
         QuantizeConfig.finalize()  # Ensure this test is not affected by previous tests.
-        self._check_defult_state()
+        self._check_default_state()

         with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
             self.assertFalse(QuantizeConfig.is_fp8_enabled())
             self._compare_current_scaling(Float8CurrentScaling())

-        self._check_defult_state()
+        self._check_default_state()

         cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
         with fp8_autocast(enabled=True, fp8_recipe=cs):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_current_scaling(cs)

-        self._check_defult_state()
+        self._check_default_state()

         cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
         with fp8_autocast(enabled=True, fp8_recipe=cs):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_current_scaling(cs)

-        self._check_defult_state()
+        self._check_default_state()

     @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
     def test_fp8_autocast_mxfp8_scaling(self):
         QuantizeConfig.finalize()  # Ensure this test is not affected by previous tests.
-        self._check_defult_state()
+        self._check_default_state()

         with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
             self.assertFalse(QuantizeConfig.is_fp8_enabled())
             self._compare_mxfp8_scaling(MXFP8BlockScaling())

-        self._check_defult_state()
+        self._check_default_state()

         bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
         with fp8_autocast(enabled=True, fp8_recipe=bs):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_mxfp8_scaling(bs)

-        self._check_defult_state()
+        self._check_default_state()

         bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
         with fp8_autocast(enabled=True, fp8_recipe=bs):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_mxfp8_scaling(bs)

-        self._check_defult_state()
+        self._check_default_state()

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_with_sharding_resource(self):
         QuantizeConfig.finalize()  # Ensure this test is not affected by previous tests.
-        self._check_defult_state()
+        self._check_default_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
@@ -165,4 +162,4 @@ class TestFP8Functions(unittest.TestCase):
             self._compare_delay_scaling(get_delayed_scaling(), ds)
             self.assertEqual(sr, global_mesh_resource())

-        self._check_defult_state()
+        self._check_default_state()
@@ -97,16 +97,16 @@ def combine_biases(*masks: Optional[Array]):
     return mask

-def parameterize_by_test_level(param_dict: dict, id_prefix: str = ""):
+def get_parameters_for_test_level(param_dict: dict):
     """
     Takes an input dictionary of parameters keyed by test level ("L0", etc.).
-    Returns a list of pytest parameters to be used in a parameterized test for the current test type
+    Returns the parameters for the test level specified in the NVTE_JAX_UNITTEST_LEVEL
+    environment variable.
     """
     DEFAULT_TEST_LEVEL = "L0"
     test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)
     if test_level not in param_dict:
         raise ValueError("Unsupported test level")
-    return values_to_named_params(param_dict[test_level], id_prefix)
+    return param_dict[test_level]
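For context, a hedged sketch of how this helper behaves (the parameter dictionary below is illustrative, not from the test suite): with NVTE_JAX_UNITTEST_LEVEL unset it falls back to the "L0" subset, and with NVTE_JAX_UNITTEST_LEVEL=L2 it returns the fuller "L2" list.

```python
import os

# Illustrative parameter dictionary keyed by test level.
params = {
    "L0": [(32, 64)],
    "L2": [(32, 64), (2, 64, 32), (32, 256, 128)],
}

os.environ["NVTE_JAX_UNITTEST_LEVEL"] = "L2"
assert get_parameters_for_test_level(params) == params["L2"]

del os.environ["NVTE_JAX_UNITTEST_LEVEL"]  # falls back to the L0 default
assert get_parameters_for_test_level(params) == params["L0"]
```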
def value_to_test_name_str(value):
@@ -139,14 +139,18 @@ def pytest_parametrize_wrapper(param_name, param_values):
     A wrapper for pytest.mark.parametrize to allow for automatic
     naming of tests based on the parameter values.
     """
-    id_prefix = param_name
     if isinstance(param_values, dict):
-        param_values = parameterize_by_test_level(param_values, id_prefix=param_name)
-    elif "," not in param_name:
-        param_values = values_to_named_params(param_values, id_prefix=id_prefix)
+        # If the values are split into a dictionary of test levels, e.g. "L0", etc.,
+        # unwrap the selected level before proceeding.
+        param_values = get_parameters_for_test_level(param_values)

+    if "," not in param_name:
+        # Multi-parameter annotations, e.g.
+        # @pytest_parametrize_wrapper("a,b", ((a_value1, b_value1), (a_value2, b_value2))),
+        # are not supported by the automatic naming in this wrapper and are passed
+        # through to pytest.mark.parametrize as-is, without pytest.param ids.
+        param_values = values_to_named_params(param_values, id_prefix=param_name)

     def decorator(func):
         return pytest.mark.parametrize(param_name, param_values)(func)
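A quick usage sketch of the wrapper as changed above (the test bodies are hypothetical): single-name parameters get generated ids, dict-valued parameters are first unwrapped by test level, and comma-joined names fall back to pytest's default naming.

```python
# Dict values are unwrapped per test level, then named automatically.
@pytest_parametrize_wrapper("input_shape", {"L0": [(32, 64)], "L2": [(32, 64), (2, 64, 32)]})
def test_shape_rank(input_shape):
    assert len(input_shape) >= 2  # hypothetical assertion

# Comma-joined names are passed straight through to pytest.mark.parametrize.
@pytest_parametrize_wrapper("input_shape,flatten_axis", [((2, 64, 32), -1), ((2, 64, 32), -2)])
def test_shape_with_axis(input_shape, flatten_axis):
    assert -len(input_shape) < flatten_axis < 0  # hypothetical assertion
```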
@@ -230,7 +230,7 @@ class CurrentScaleQuantizer(Quantizer):
         compute_dtype = jnp.float32
         dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
-        amax = jnp.max(jnp.abs(x)).reshape((1,))
+        amax = jnp.max(jnp.abs(x)).reshape((1,)).astype(compute_dtype)
         fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
         scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
         scaled_x = x.astype(compute_dtype) * scale
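For context, the one-line change pins the dtype of amax: without the cast it inherits the input dtype (e.g., bfloat16), whereas the surrounding scale computation and any stored amax values are meant to live in the fp32 compute dtype. A minimal illustration of the dtype difference (illustrative values, not TE code):

```python
import jax.numpy as jnp

x = jnp.ones((4, 4), dtype=jnp.bfloat16)

amax_old = jnp.max(jnp.abs(x)).reshape((1,))                      # dtype follows x
amax_new = jnp.max(jnp.abs(x)).reshape((1,)).astype(jnp.float32)  # always fp32

print(amax_old.dtype)  # bfloat16
print(amax_new.dtype)  # float32
```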