"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "207594be81b8e5a8589c8b11c3b236924555d806"
Unverified Commit 9f8619c6 authored by Patrick von Platen, committed by GitHub

Flax testing should not run the full torch test suite (#10725)

* make flax tests pytorch independent

* fix typo

* finish

* improve circle ci

* fix return tensors

* correct flax test

* re-add sentencepiece

* last tokenizer fixes

* finish maybe now
parent 87d685b8
@@ -91,6 +91,34 @@ jobs:
         - store_artifacts:
               path: ~/transformers/reports
 
+    run_tests_torch_and_flax:
+        working_directory: ~/transformers
+        docker:
+            - image: circleci/python:3.6
+        environment:
+            OMP_NUM_THREADS: 1
+        resource_class: xlarge
+        parallelism: 1
+        steps:
+            - checkout
+            - restore_cache:
+                  keys:
+                      - v0.4-torch_and_flax-{{ checksum "setup.py" }}
+                      - v0.4-{{ checksum "setup.py" }}
+            - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
+            - run: pip install --upgrade pip
+            - run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech]
+            - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
+            - save_cache:
+                  key: v0.4-{{ checksum "setup.py" }}
+                  paths:
+                      - '~/.cache/pip'
+            - run: RUN_PT_FLAX_CROSS_TESTS=1 python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_torch_and_flax ./tests/ -m is_pt_flax_cross_test --durations=0 | tee tests_output.txt
+            - store_artifacts:
+                  path: ~/transformers/tests_output.txt
+            - store_artifacts:
+                  path: ~/transformers/reports
     run_tests_torch:
         working_directory: ~/transformers
         docker:
@@ -159,9 +187,8 @@ jobs:
                   keys:
                       - v0.4-flax-{{ checksum "setup.py" }}
                       - v0.4-{{ checksum "setup.py" }}
-            - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
             - run: pip install --upgrade pip
-            - run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece,speech]
+            - run: sudo pip install .[flax,testing,sentencepiece]
             - save_cache:
                   key: v0.4-flax-{{ checksum "setup.py" }}
                   paths:
@@ -418,6 +445,7 @@ workflows:
             - run_examples_torch
             - run_tests_custom_tokenizers
             - run_tests_torch_and_tf
+            - run_tests_torch_and_flax
             - run_tests_torch
             - run_tests_tf
             - run_tests_flax
...
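For context: the new run_tests_torch_and_flax CI job gates the PT+Flax cross tests behind two things at once, the RUN_PT_FLAX_CROSS_TESTS environment flag and the is_pt_flax_cross_test pytest marker. A rough local equivalent of the CI test step (a sketch; the CI-specific report and parallelism flags are omitted):

import os

# The flag is read when transformers.testing_utils is imported,
# so set it before anything imports transformers.
os.environ["RUN_PT_FLAX_CROSS_TESTS"] = "1"

import pytest

# Run only the tests carrying the is_pt_flax_cross_test marker.
pytest.main(["-m", "is_pt_flax_cross_test", "tests/"])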
@@ -97,7 +97,7 @@ _deps = [
     "fastapi",
     "filelock",
     "flake8>=3.8.3",
-    "flax>=0.2.2",
+    "flax>=0.3.2",
     "fugashi>=1.0",
     "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
...
@@ -10,7 +10,7 @@ deps = {
     "fastapi": "fastapi",
     "filelock": "filelock",
     "flake8": "flake8>=3.8.3",
-    "flax": "flax>=0.2.2",
+    "flax": "flax>=0.3.2",
     "fugashi": "fugashi>=1.0",
     "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
...
@@ -80,6 +80,7 @@ def parse_int_from_env(key, default=None):
 
 _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
 _run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=False)
+_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=False)
 _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
 _run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False)
 _run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)
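parse_flag_from_env, used above, is defined earlier in this file; for readers without the full source at hand, a minimal sketch of the pattern (an approximation, not the exact implementation) looks like:

import os

def parse_flag_from_env(key, default=False):
    # Unset -> default; otherwise treat values such as "yes"/"true"/"1" as truthy.
    value = os.environ.get(key)
    if value is None:
        return default
    return value.lower() in ("yes", "true", "t", "1")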
@@ -105,6 +106,25 @@ def is_pt_tf_cross_test(test_case):
         return pytest.mark.is_pt_tf_cross_test()(test_case)
 
 
+def is_pt_flax_cross_test(test_case):
+    """
+    Decorator marking a test as a test that checks the interaction between PyTorch and Flax.
+
+    PT+FLAX tests are skipped by default, and can be run on their own by setting the RUN_PT_FLAX_CROSS_TESTS
+    environment variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark.
+    """
+    if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
+        return unittest.skip("test is PT+FLAX test")(test_case)
+    else:
+        try:
+            import pytest  # We don't need a hard dependency on pytest in the main library
+        except ImportError:
+            return test_case
+        else:
+            return pytest.mark.is_pt_flax_cross_test()(test_case)
 
 
 def is_pipeline_test(test_case):
     """
     Decorator marking a test as a pipeline test.
...
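The new decorator is used exactly like the existing is_pt_tf_cross_test one: apply it to any test that needs both backends. An illustrative (hypothetical) test:

import unittest

from transformers.testing_utils import is_pt_flax_cross_test

class MyModelTest(unittest.TestCase):
    @is_pt_flax_cross_test
    def test_pt_flax_equivalence(self):
        # Skipped unless RUN_PT_FLAX_CROSS_TESTS is truthy and both torch and flax are installed.
        ...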
@@ -35,6 +35,9 @@ def pytest_configure(config):
     config.addinivalue_line(
         "markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested"
     )
+    config.addinivalue_line(
+        "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
+    )
 
 
 def pytest_addoption(parser):
...
@@ -19,7 +19,7 @@ import numpy as np
 
 import transformers
 from transformers import is_flax_available, is_torch_available
-from transformers.testing_utils import require_flax, require_torch
+from transformers.testing_utils import is_pt_flax_cross_test, require_flax
 
 if is_flax_available():
@@ -60,7 +60,6 @@ def random_attention_mask(shape, rng=None):
     return attn_mask
 
 
-@require_flax
 class FlaxModelTesterMixin:
 
     model_tester = None
     all_model_classes = ()
@@ -69,7 +68,7 @@ class FlaxModelTesterMixin:
         diff = np.abs((a - b)).max()
         self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
 
-    @require_torch
+    @is_pt_flax_cross_test
     def test_equivalence_flax_pytorch(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -104,6 +103,7 @@ class FlaxModelTesterMixin:
         for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
             self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
 
+    @require_flax
     def test_from_pretrained_save_pretrained(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -121,6 +121,7 @@ class FlaxModelTesterMixin:
         for output_loaded, output in zip(outputs_loaded, outputs):
             self.assert_almost_equals(output_loaded, output, 5e-3)
 
+    @require_flax
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -143,6 +144,7 @@ class FlaxModelTesterMixin:
         for jitted_output, output in zip(jitted_outputs, outputs):
             self.assertEqual(jitted_output.shape, output.shape)
 
+    @require_flax
     def test_naming_convention(self):
         for model_class in self.all_model_classes:
             model_class_name = model_class.__name__
...
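The net effect of the hunks above: the class-level @require_flax is dropped and each Flax-only test is gated individually, which lets the PyTorch-equivalence test be gated by the cross-test marker instead of @require_torch. Schematically (bodies elided):

from transformers.testing_utils import is_pt_flax_cross_test, require_flax

class FlaxModelTesterMixin:
    @is_pt_flax_cross_test  # runs only in the dedicated PT+Flax cross-test job
    def test_equivalence_flax_pytorch(self): ...

    @require_flax  # needs only a Flax install
    def test_jit_compilation(self): ...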
@@ -24,7 +24,13 @@ from collections import OrderedDict
 from itertools import takewhile
 from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, is_torch_available
+from transformers import (
+    PreTrainedTokenizer,
+    PreTrainedTokenizerBase,
+    PreTrainedTokenizerFast,
+    is_tf_available,
+    is_torch_available,
+)
 from transformers.testing_utils import (
     get_tests_dir,
     is_pt_tf_cross_test,
@@ -2283,7 +2289,12 @@ class TokenizerTesterMixin:
                 "{} ({}, {})".format(tokenizer.__class__.__name__, pretrained_name, tokenizer.__class__.__name__)
             ):
-                returned_tensor = "pt" if is_torch_available() else "tf"
+                if is_torch_available():
+                    returned_tensor = "pt"
+                elif is_tf_available():
+                    returned_tensor = "tf"
+                else:
+                    returned_tensor = "jax"
 
                 if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
                     return
...
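The pt -> tf -> jax fallback introduced here reappears verbatim in the tokenizer test files below; if it were factored out, a helper might look like this (hypothetical, not part of the diff):

from transformers import is_tf_available, is_torch_available

def default_return_tensors():
    # Prefer torch, then TensorFlow, and fall back to JAX when neither is installed.
    if is_torch_available():
        return "pt"
    if is_tf_available():
        return "tf"
    return "jax"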
@@ -21,7 +21,7 @@ from pathlib import Path
 from shutil import copyfile
 
 from transformers import BatchEncoding, MarianTokenizer
-from transformers.file_utils import is_sentencepiece_available, is_torch_available
+from transformers.file_utils import is_sentencepiece_available, is_tf_available, is_torch_available
 from transformers.testing_utils import require_sentencepiece
@@ -36,7 +36,13 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
 mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
 zh_code = ">>zh<<"
 ORG_NAME = "Helsinki-NLP/"
-FRAMEWORK = "pt" if is_torch_available() else "tf"
+
+if is_torch_available():
+    FRAMEWORK = "pt"
+elif is_tf_available():
+    FRAMEWORK = "tf"
+else:
+    FRAMEWORK = "jax"
 
 
 @require_sentencepiece
...
@@ -17,7 +17,7 @@
 import unittest
 
 from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
-from transformers.file_utils import cached_property, is_torch_available
+from transformers.file_utils import cached_property, is_tf_available, is_torch_available
 from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
 
 from .test_tokenization_common import TokenizerTesterMixin
@@ -25,7 +25,12 @@ from .test_tokenization_common import TokenizerTesterMixin
 
 SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 
-FRAMEWORK = "pt" if is_torch_available() else "tf"
+if is_torch_available():
+    FRAMEWORK = "pt"
+elif is_tf_available():
+    FRAMEWORK = "tf"
+else:
+    FRAMEWORK = "jax"
 
 
 @require_sentencepiece
@@ -157,7 +162,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
         batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
         self.assertIsInstance(batch, BatchEncoding)
 
-        result = list(batch.input_ids.numpy()[0])
+        if FRAMEWORK != "jax":
+            result = list(batch.input_ids.numpy()[0])
+        else:
+            result = list(batch.input_ids.tolist()[0])
 
         self.assertListEqual(expected_src_tokens, result)
         self.assertEqual((2, 9), batch.input_ids.shape)
...
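The FRAMEWORK != "jax" branch above exists because the test extracts ids with .numpy(), which torch and TF tensors provide but JAX arrays (at the time) did not, while .tolist() works on JAX arrays. A framework-agnostic alternative (an assumption, not what the test does) would route everything through numpy:

import numpy as np

def first_row_ids(tensor):
    # np.asarray handles CPU torch tensors, TF EagerTensors, and JAX arrays alike.
    return np.asarray(tensor)[0].tolist()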