Unverified Commit 992ba01d authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Fix current scaling test_helper.py and enable test_helper.py in L0 (#1990)



Fix current scaling test_helper.py and enable test_helper.py in L0
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 4296b7d0
...@@ -25,7 +25,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" ...@@ -25,7 +25,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
: ${XML_LOG_DIR:=/logs} : ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR" mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
......
...@@ -58,7 +58,6 @@ class TestFP8Functions(unittest.TestCase): ...@@ -58,7 +58,6 @@ class TestFP8Functions(unittest.TestCase):
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_current_scaling(self, test): def _compare_current_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING) self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
...@@ -91,7 +90,7 @@ class TestFP8Functions(unittest.TestCase): ...@@ -91,7 +90,7 @@ class TestFP8Functions(unittest.TestCase):
self._check_default_state() self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_current_scaling(self): def test_fp8_autocast_current_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state() self._check_default_state()
...@@ -101,14 +100,14 @@ class TestFP8Functions(unittest.TestCase): ...@@ -101,14 +100,14 @@ class TestFP8Functions(unittest.TestCase):
self._check_default_state() self._check_default_state()
cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3) cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs): with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs) self._compare_current_scaling(cs)
self._check_default_state() self._check_default_state()
cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID) cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs): with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs) self._compare_current_scaling(cs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment