Unverified Commit 35f7d262 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] Skip V100 encoder tests (#1262)



* Skip encoder tests on V100

* Fix multiprocessing jax.distributed.init

* Remove XLA xla_gpu_deterministic_ops which causes segfault

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 7b18f235
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache
from transformer_engine.transformer_engine_jax import get_device_compute_capability
@lru_cache
def is_bf16_supported():
    """Check whether the current GPU supports BF16 in hardware.

    BF16 requires compute capability 8.0 (Ampere) or newer. The result is
    cached via ``lru_cache`` since device 0 does not change within a run.
    """
    return get_device_compute_capability(0) >= 80
...@@ -22,6 +22,8 @@ from jax.sharding import PartitionSpec, NamedSharding ...@@ -22,6 +22,8 @@ from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
from common import is_bf16_supported
DEVICE_DP_AXIS = "data" DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model" DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis" NAMED_BROADCAST_AXIS = "my_broadcast_axis"
...@@ -434,6 +436,7 @@ class TestEncoder(unittest.TestCase): ...@@ -434,6 +436,7 @@ class TestEncoder(unittest.TestCase):
"""Run 3 epochs for testing""" """Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"]) cls.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
...@@ -446,6 +449,7 @@ class TestEncoder(unittest.TestCase): ...@@ -446,6 +449,7 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79 assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_sp(self): def test_te_bf16_sp(self):
"""Test Transformer Engine with BF16 + SP""" """Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True self.args.enable_sp = True
......
...@@ -22,6 +22,8 @@ from jax.sharding import PartitionSpec, NamedSharding ...@@ -22,6 +22,8 @@ from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
from common import is_bf16_supported
DEVICE_DP_AXIS = "data" DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params" PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes" PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
...@@ -402,6 +404,7 @@ class TestEncoder(unittest.TestCase): ...@@ -402,6 +404,7 @@ class TestEncoder(unittest.TestCase):
"""Run 3 epochs for testing""" """Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"]) cls.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
......
...@@ -24,6 +24,8 @@ from jax.sharding import PartitionSpec, NamedSharding ...@@ -24,6 +24,8 @@ from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
from common import is_bf16_supported
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data" DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model" DEVICE_TP_AXIS = "model"
...@@ -552,8 +554,9 @@ def encoder_parser(args): ...@@ -552,8 +554,9 @@ def encoder_parser(args):
def query_gpu(q): def query_gpu(q):
"""Query GPU info on the system""" """Query GPU info on the system"""
gpu_has_fp8, reason = te.fp8.is_fp8_available() gpu_has_fp8, reason = te.fp8.is_fp8_available()
gpu_has_bf16 = is_bf16_supported()
num_gpu = len(jax.devices()) num_gpu = len(jax.devices())
q.put([num_gpu, gpu_has_fp8, reason]) q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason])
def unittest_query_gpu(): def unittest_query_gpu():
...@@ -566,15 +569,15 @@ def unittest_query_gpu(): ...@@ -566,15 +569,15 @@ def unittest_query_gpu():
q = mp.Queue() q = mp.Queue()
p = mp.Process(target=query_gpu, args=(q,)) p = mp.Process(target=query_gpu, args=(q,))
p.start() p.start()
num_gpu, gpu_has_fp8, reason = q.get() num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get()
p.join() p.join()
return num_gpu, gpu_has_fp8, reason return num_gpu, gpu_has_fp8, gpu_has_bf16, reason
class TestEncoder(unittest.TestCase): class TestEncoder(unittest.TestCase):
"""Encoder unittests""" """Encoder unittests"""
num_gpu, gpu_has_fp8, reason = unittest_query_gpu() num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu()
def exec(self, use_fp8): def exec(self, use_fp8):
"""Run 3 epochs for testing""" """Run 3 epochs for testing"""
...@@ -598,6 +601,7 @@ class TestEncoder(unittest.TestCase): ...@@ -598,6 +601,7 @@ class TestEncoder(unittest.TestCase):
return results return results
@unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
results = self.exec(False) results = self.exec(False)
......
...@@ -19,6 +19,8 @@ from flax.training import train_state ...@@ -19,6 +19,8 @@ from flax.training import train_state
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
from common import is_bf16_supported
PARAMS_KEY = "params" PARAMS_KEY = "params"
DROPOUT_KEY = "dropout" DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng" INPUT_KEY = "input_rng"
...@@ -321,6 +323,7 @@ class TestEncoder(unittest.TestCase): ...@@ -321,6 +323,7 @@ class TestEncoder(unittest.TestCase):
"""Run 4 epochs for testing""" """Run 4 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"]) cls.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
......
...@@ -18,7 +18,5 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt ...@@ -18,7 +18,5 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist
# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment