Commit 01ffc65e authored by thomwolf's avatar thomwolf
Browse files

update tests to remove unittest.patch

parent 825697ca
...@@ -48,16 +48,19 @@ if is_torch_available(): ...@@ -48,16 +48,19 @@ if is_torch_available():
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_framework(model=None):
    """Select the deep-learning framework ('tf' or 'pt') to use.

    If both frameworks are installed and no model instance is supplied,
    TensorFlow is preferred. When a model class instance is passed (not a
    string name), the framework is guessed from its class name: a 'TF'
    prefix selects TensorFlow, anything else selects PyTorch.

    Raises:
        ImportError: if neither TensorFlow 2.0 nor PyTorch is installed.
    """
    tf_ok = is_tf_available()
    pt_ok = is_torch_available()
    if not tf_ok and not pt_ok:
        raise ImportError("At least one of TensorFlow 2.0 or PyTorch should be installed. "
                          "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
                          "To install PyTorch, read the instructions at https://pytorch.org/.")
    if tf_ok and pt_ok and model is not None and not isinstance(model, str):
        # Both frameworks are available but the user supplied a model class
        # instance: guess the framework from the class-name prefix.
        return 'tf' if model.__class__.__name__.startswith('TF') else 'pt'
    # Default preference: TensorFlow when available, otherwise PyTorch.
    return 'tf' if tf_ok else 'pt'
class ArgumentHandler(ABC): class ArgumentHandler(ABC):
......
import unittest import unittest
from unittest.mock import patch
from typing import Iterable from typing import Iterable
...@@ -35,16 +34,6 @@ TEXT_CLASSIF_FINETUNED_MODELS = { ...@@ -35,16 +34,6 @@ TEXT_CLASSIF_FINETUNED_MODELS = {
} }
@require_tf
def tf_pipeline(*args, **kwargs):
    """Build a `pipeline` in a TensorFlow-gated test context.

    Positional arguments are accepted for call-site compatibility but are
    not forwarded; only keyword arguments reach `pipeline`.
    """
    del args  # explicitly unused
    return pipeline(**kwargs)
@require_torch
def torch_pipeline(*args, **kwargs):
    """Build a `pipeline` in a PyTorch-gated test context.

    Positional arguments are accepted for call-site compatibility but are
    not forwarded; only keyword arguments reach `pipeline`.
    """
    del args  # explicitly unused
    return pipeline(**kwargs)
class MonoColumnInputTestCase(unittest.TestCase): class MonoColumnInputTestCase(unittest.TestCase):
def _test_mono_column_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]): def _test_mono_column_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
self.assertIsNotNone(nlp) self.assertIsNotNone(nlp)
...@@ -72,43 +61,57 @@ class MonoColumnInputTestCase(unittest.TestCase): ...@@ -72,43 +61,57 @@ class MonoColumnInputTestCase(unittest.TestCase):
self.assertRaises(Exception, nlp, invalid_inputs) self.assertRaises(Exception, nlp, invalid_inputs)
@require_torch
def test_ner(self): def test_ner(self):
mandatory_keys = {'entity', 'word', 'score'} mandatory_keys = {'entity', 'word', 'score'}
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris'] valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
invalid_inputs = [None] invalid_inputs = [None]
for tokenizer, model, config in NER_FINETUNED_MODELS: for tokenizer, model, config in NER_FINETUNED_MODELS:
with patch('transformers.pipelines.is_torch_available', return_value=False): nlp = pipeline(task='ner', model=model, config=config, tokenizer=tokenizer)
nlp = tf_pipeline(task='ner', model=model, config=config, tokenizer=tokenizer) self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
with patch('transformers.pipelines.is_tf_available', return_value=False): @require_tf
nlp = torch_pipeline(task='ner', model=model, config=config, tokenizer=tokenizer) def test_tf_ner(self):
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys) mandatory_keys = {'entity', 'word', 'score'}
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
invalid_inputs = [None]
for tokenizer, model, config in NER_FINETUNED_MODELS:
nlp = pipeline(task='ner', model=model, config=config, tokenizer=tokenizer)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
@require_torch
def test_sentiment_analysis(self): def test_sentiment_analysis(self):
mandatory_keys = {'label'} mandatory_keys = {'label'}
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris'] valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
invalid_inputs = [None] invalid_inputs = [None]
for tokenizer, model, config in TEXT_CLASSIF_FINETUNED_MODELS: for tokenizer, model, config in TEXT_CLASSIF_FINETUNED_MODELS:
with patch('transformers.pipelines.is_torch_available', return_value=False): nlp = pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
nlp = tf_pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer) self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
with patch('transformers.pipelines.is_tf_available', return_value=False): @require_tf
nlp = torch_pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer) def test_tf_sentiment_analysis(self):
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys) mandatory_keys = {'label'}
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
invalid_inputs = [None]
for tokenizer, model, config in TEXT_CLASSIF_FINETUNED_MODELS:
nlp = pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
@require_torch
def test_features_extraction(self): def test_features_extraction(self):
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris'] valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
invalid_inputs = [None] invalid_inputs = [None]
for tokenizer, model, config in FEATURE_EXTRACT_FINETUNED_MODELS: for tokenizer, model, config in FEATURE_EXTRACT_FINETUNED_MODELS:
with patch('transformers.pipelines.is_torch_available', return_value=False): nlp = pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
nlp = tf_pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer) self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
with patch('transformers.pipelines.is_tf_available', return_value=False): @require_tf
nlp = torch_pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer) def test_tf_features_extraction(self):
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {}) valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
invalid_inputs = [None]
for tokenizer, model, config in FEATURE_EXTRACT_FINETUNED_MODELS:
nlp = pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
class MultiColumnInputTestCase(unittest.TestCase): class MultiColumnInputTestCase(unittest.TestCase):
...@@ -132,6 +135,7 @@ class MultiColumnInputTestCase(unittest.TestCase): ...@@ -132,6 +135,7 @@ class MultiColumnInputTestCase(unittest.TestCase):
self.assertRaises(Exception, nlp, invalid_inputs[0]) self.assertRaises(Exception, nlp, invalid_inputs[0])
self.assertRaises(Exception, nlp, invalid_inputs) self.assertRaises(Exception, nlp, invalid_inputs)
@require_torch
def test_question_answering(self): def test_question_answering(self):
mandatory_output_keys = {'score', 'answer', 'start', 'end'} mandatory_output_keys = {'score', 'answer', 'start', 'end'}
valid_samples = [ valid_samples = [
...@@ -149,16 +153,29 @@ class MultiColumnInputTestCase(unittest.TestCase): ...@@ -149,16 +153,29 @@ class MultiColumnInputTestCase(unittest.TestCase):
] ]
for tokenizer, model, config in QA_FINETUNED_MODELS: for tokenizer, model, config in QA_FINETUNED_MODELS:
nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer)
self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
# Test for Tensorflow @require_tf
with patch('transformers.pipelines.is_torch_available', return_value=False): def test_tf_question_answering(self):
nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer) mandatory_output_keys = {'score', 'answer', 'start', 'end'}
self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys) valid_samples = [
{'question': 'Where was HuggingFace founded ?', 'context': 'HuggingFace was founded in Paris.'},
{
'question': 'In what field is HuggingFace working ?',
'context': 'HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.'
}
]
invalid_samples = [
{'question': '', 'context': 'This is a test to try empty question edge case'},
{'question': None, 'context': 'This is a test to try empty question edge case'},
{'question': 'What is does with empty context ?', 'context': ''},
{'question': 'What is does with empty context ?', 'context': None},
]
# Test for PyTorch for tokenizer, model, config in QA_FINETUNED_MODELS:
with patch('transformers.pipelines.is_tf_available', return_value=False): nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer)
nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer) self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment