Unverified commit ed75c2b0, authored by Phuong Nguyen and committed by GitHub
Browse files

[JAX] Tighten Encoder Test tolerances (#1955)



Tighten encoder test tolerances
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 07afda98
...@@ -474,7 +474,7 @@ class TestEncoder(unittest.TestCase): ...@@ -474,7 +474,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self): def test_te_delayed_scaling_fp8(self):
...@@ -482,7 +482,7 @@ class TestEncoder(unittest.TestCase): ...@@ -482,7 +482,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling" self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self): def test_te_mxfp8(self):
...@@ -490,14 +490,14 @@ class TestEncoder(unittest.TestCase): ...@@ -490,14 +490,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling" self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self): def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP""" """Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True self.args.enable_sp = True
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self): def test_te_delayed_scaling_fp8_with_sp(self):
...@@ -506,7 +506,7 @@ class TestEncoder(unittest.TestCase): ...@@ -506,7 +506,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling" self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self): def test_te_mxfp8_with_sp(self):
...@@ -515,14 +515,14 @@ class TestEncoder(unittest.TestCase): ...@@ -515,14 +515,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling" self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self): def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
self.args.enable_shardy = True self.args.enable_shardy = True
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self): def test_te_delayed_scaling_fp8_shardy(self):
...@@ -531,7 +531,7 @@ class TestEncoder(unittest.TestCase): ...@@ -531,7 +531,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling" self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp_shardy(self): def test_te_delayed_scaling_fp8_with_sp_shardy(self):
...@@ -541,7 +541,7 @@ class TestEncoder(unittest.TestCase): ...@@ -541,7 +541,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling" self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf( @unittest.skipIf(
...@@ -553,7 +553,7 @@ class TestEncoder(unittest.TestCase): ...@@ -553,7 +553,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling" self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf( @unittest.skipIf(
...@@ -566,7 +566,7 @@ class TestEncoder(unittest.TestCase): ...@@ -566,7 +566,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling" self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80 assert actual[0] < 0.39 and actual[1] > 0.83
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -435,13 +435,13 @@ class TestEncoder(unittest.TestCase): ...@@ -435,13 +435,13 @@ class TestEncoder(unittest.TestCase):
def setUp(self): def setUp(self):
"""Run 5 epochs for testing""" """Run 5 epochs for testing"""
self.args = encoder_parser(["--epochs", "6"]) self.args = encoder_parser(["--epochs", "5"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.75 assert actual[0] < 0.52 and actual[1] > 0.74
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self): def test_te_delayed_scaling_fp8(self):
...@@ -449,7 +449,7 @@ class TestEncoder(unittest.TestCase): ...@@ -449,7 +449,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling" self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.75 assert actual[0] < 0.52 and actual[1] > 0.74
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self): def test_te_current_scaling_fp8(self):
...@@ -457,7 +457,7 @@ class TestEncoder(unittest.TestCase): ...@@ -457,7 +457,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling" self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.75 assert actual[0] < 0.52 and actual[1] > 0.74
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self): def test_te_mxfp8(self):
...@@ -465,14 +465,14 @@ class TestEncoder(unittest.TestCase): ...@@ -465,14 +465,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling" self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.75 assert actual[0] < 0.52 and actual[1] > 0.74
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self): def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
self.args.enable_shardy = True self.args.enable_shardy = True
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.75 assert actual[0] < 0.52 and actual[1] > 0.74
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self): def test_te_delayed_scaling_fp8_shardy(self):
...@@ -481,7 +481,7 @@ class TestEncoder(unittest.TestCase): ...@@ -481,7 +481,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling" self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.75 assert actual[0] < 0.52 and actual[1] > 0.74
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8_shardy(self): def test_te_current_scaling_fp8_shardy(self):
...@@ -490,7 +490,7 @@ class TestEncoder(unittest.TestCase): ...@@ -490,7 +490,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling" self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.75 assert actual[0] < 0.52 and actual[1] > 0.74
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf( @unittest.skipIf(
...@@ -502,7 +502,7 @@ class TestEncoder(unittest.TestCase): ...@@ -502,7 +502,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling" self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.75 assert actual[0] < 0.52 and actual[1] > 0.74
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment