Unverified Commit d194d639 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Remove datasets requirement (#14795)

parent bef1e3e4
...@@ -31,7 +31,6 @@ from transformers import logging as transformers_logging ...@@ -31,7 +31,6 @@ from transformers import logging as transformers_logging
from .deepspeed import is_deepspeed_available from .deepspeed import is_deepspeed_available
from .file_utils import ( from .file_utils import (
is_datasets_available,
is_detectron2_available, is_detectron2_available,
is_faiss_available, is_faiss_available,
is_flax_available, is_flax_available,
...@@ -513,15 +512,6 @@ def require_torch_tf32(test_case): ...@@ -513,15 +512,6 @@ def require_torch_tf32(test_case):
return test_case return test_case
def require_datasets(test_case):
"""Decorator marking a test that requires datasets."""
if not is_datasets_available():
return unittest.skip("test requires `datasets`")(test_case)
else:
return test_case
def require_detectron2(test_case): def require_detectron2(test_case):
"""Decorator marking a test that requires detectron2.""" """Decorator marking a test that requires detectron2."""
if not is_detectron2_available(): if not is_detectron2_available():
......
...@@ -23,7 +23,6 @@ from transformers import Wav2Vec2Config, is_flax_available ...@@ -23,7 +23,6 @@ from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import ( from transformers.testing_utils import (
is_librosa_available, is_librosa_available,
is_pyctcdecode_available, is_pyctcdecode_available,
require_datasets,
require_flax, require_flax,
require_librosa, require_librosa,
require_pyctcdecode, require_pyctcdecode,
...@@ -367,7 +366,6 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase): ...@@ -367,7 +366,6 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
@require_flax @require_flax
@require_datasets
@require_soundfile @require_soundfile
@slow @slow
class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase): class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
......
...@@ -22,7 +22,7 @@ import pytest ...@@ -22,7 +22,7 @@ import pytest
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import HubertConfig, is_torch_available from transformers import HubertConfig, is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init from .test_modeling_common import ModelTesterMixin, _config_zero_init
...@@ -606,7 +606,6 @@ class HubertUtilsTest(unittest.TestCase): ...@@ -606,7 +606,6 @@ class HubertUtilsTest(unittest.TestCase):
@require_torch @require_torch
@require_datasets
@require_soundfile @require_soundfile
@slow @slow
class HubertModelIntegrationTest(unittest.TestCase): class HubertModelIntegrationTest(unittest.TestCase):
......
...@@ -22,7 +22,7 @@ import pytest ...@@ -22,7 +22,7 @@ import pytest
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import SEWConfig, is_torch_available from transformers import SEWConfig, is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init from .test_modeling_common import ModelTesterMixin, _config_zero_init
...@@ -462,7 +462,6 @@ class SEWUtilsTest(unittest.TestCase): ...@@ -462,7 +462,6 @@ class SEWUtilsTest(unittest.TestCase):
@require_torch @require_torch
@require_datasets
@require_soundfile @require_soundfile
@slow @slow
class SEWModelIntegrationTest(unittest.TestCase): class SEWModelIntegrationTest(unittest.TestCase):
......
...@@ -22,7 +22,7 @@ import pytest ...@@ -22,7 +22,7 @@ import pytest
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import SEWDConfig, is_torch_available from transformers import SEWDConfig, is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init from .test_modeling_common import ModelTesterMixin, _config_zero_init
...@@ -475,7 +475,6 @@ class SEWDUtilsTest(unittest.TestCase): ...@@ -475,7 +475,6 @@ class SEWDUtilsTest(unittest.TestCase):
@require_torch @require_torch
@require_datasets
@require_soundfile @require_soundfile
@slow @slow
class SEWDModelIntegrationTest(unittest.TestCase): class SEWDModelIntegrationTest(unittest.TestCase):
......
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
import pytest import pytest
from transformers import is_tf_available from transformers import is_tf_available
from transformers.testing_utils import require_datasets, require_soundfile, require_tf, slow from transformers.testing_utils import require_soundfile, require_tf, slow
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
...@@ -473,7 +473,6 @@ class TFHubertUtilsTest(unittest.TestCase): ...@@ -473,7 +473,6 @@ class TFHubertUtilsTest(unittest.TestCase):
@require_tf @require_tf
@slow @slow
@require_datasets
@require_soundfile @require_soundfile
class TFHubertModelIntegrationTest(unittest.TestCase): class TFHubertModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples): def _load_datasamples(self, num_samples):
......
...@@ -25,7 +25,7 @@ from datasets import load_dataset ...@@ -25,7 +25,7 @@ from datasets import load_dataset
from transformers import Wav2Vec2Config, is_tf_available from transformers import Wav2Vec2Config, is_tf_available
from transformers.file_utils import is_librosa_available, is_pyctcdecode_available from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
from transformers.testing_utils import require_datasets, require_librosa, require_pyctcdecode, require_tf, slow from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
...@@ -483,7 +483,6 @@ class TFWav2Vec2UtilsTest(unittest.TestCase): ...@@ -483,7 +483,6 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
@require_tf @require_tf
@slow @slow
@require_datasets
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples): def _load_datasamples(self, num_samples):
from datasets import load_dataset from datasets import load_dataset
......
...@@ -23,7 +23,7 @@ from datasets import load_dataset ...@@ -23,7 +23,7 @@ from datasets import load_dataset
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import UniSpeechConfig, is_torch_available from transformers import UniSpeechConfig, is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init from .test_modeling_common import ModelTesterMixin, _config_zero_init
...@@ -525,7 +525,6 @@ class UniSpeechRobustModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -525,7 +525,6 @@ class UniSpeechRobustModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch @require_torch
@require_datasets
@require_soundfile @require_soundfile
@slow @slow
class UniSpeechModelIntegrationTest(unittest.TestCase): class UniSpeechModelIntegrationTest(unittest.TestCase):
......
...@@ -23,7 +23,7 @@ from datasets import load_dataset ...@@ -23,7 +23,7 @@ from datasets import load_dataset
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from transformers import UniSpeechSatConfig, is_torch_available from transformers import UniSpeechSatConfig, is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init from .test_modeling_common import ModelTesterMixin, _config_zero_init
...@@ -783,7 +783,6 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -783,7 +783,6 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch @require_torch
@require_datasets
@require_soundfile @require_soundfile
@slow @slow
class UniSpeechSatModelIntegrationTest(unittest.TestCase): class UniSpeechSatModelIntegrationTest(unittest.TestCase):
......
...@@ -26,7 +26,6 @@ from transformers.testing_utils import ( ...@@ -26,7 +26,6 @@ from transformers.testing_utils import (
is_pt_flax_cross_test, is_pt_flax_cross_test,
is_pyctcdecode_available, is_pyctcdecode_available,
is_torchaudio_available, is_torchaudio_available,
require_datasets,
require_pyctcdecode, require_pyctcdecode,
require_soundfile, require_soundfile,
require_torch, require_torch,
...@@ -1060,7 +1059,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase): ...@@ -1060,7 +1059,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
@require_torch @require_torch
@require_datasets
@require_soundfile @require_soundfile
@slow @slow
class Wav2Vec2ModelIntegrationTest(unittest.TestCase): class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
......
...@@ -21,7 +21,6 @@ from transformers.pipelines import AudioClassificationPipeline, pipeline ...@@ -21,7 +21,6 @@ from transformers.pipelines import AudioClassificationPipeline, pipeline
from transformers.testing_utils import ( from transformers.testing_utils import (
is_pipeline_test, is_pipeline_test,
nested_simplify, nested_simplify,
require_datasets,
require_tf, require_tf,
require_torch, require_torch,
require_torchaudio, require_torchaudio,
...@@ -65,7 +64,6 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -65,7 +64,6 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
self.run_torchaudio(audio_classifier) self.run_torchaudio(audio_classifier)
@require_datasets
@require_torchaudio @require_torchaudio
def run_torchaudio(self, audio_classifier): def run_torchaudio(self, audio_classifier):
import datasets import datasets
...@@ -101,7 +99,6 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -101,7 +99,6 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
) )
@require_torch @require_torch
@require_datasets
@slow @slow
def test_large_model_pt(self): def test_large_model_pt(self):
import datasets import datasets
......
...@@ -26,14 +26,7 @@ from transformers import ( ...@@ -26,14 +26,7 @@ from transformers import (
Wav2Vec2ForCTC, Wav2Vec2ForCTC,
) )
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.testing_utils import ( from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_torchaudio, slow
is_pipeline_test,
require_datasets,
require_tf,
require_torch,
require_torchaudio,
slow,
)
from .test_pipelines_common import ANY, PipelineTestCaseMeta from .test_pipelines_common import ANY, PipelineTestCaseMeta
...@@ -105,7 +98,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -105,7 +98,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
framework="pt", framework="pt",
) )
@require_datasets
@require_torch @require_torch
@slow @slow
def test_torch_large(self): def test_torch_large(self):
...@@ -128,7 +120,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -128,7 +120,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = speech_recognizer(filename) output = speech_recognizer(filename)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"}) self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
@require_datasets
@require_torch @require_torch
@slow @slow
def test_torch_speech_encoder_decoder(self): def test_torch_speech_encoder_decoder(self):
...@@ -148,7 +139,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -148,7 +139,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
@slow @slow
@require_torch @require_torch
@require_datasets
def test_simple_wav2vec2(self): def test_simple_wav2vec2(self):
import numpy as np import numpy as np
from datasets import load_dataset from datasets import load_dataset
...@@ -177,7 +167,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -177,7 +167,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
@slow @slow
@require_torch @require_torch
@require_torchaudio @require_torchaudio
@require_datasets
def test_simple_s2t(self): def test_simple_s2t(self):
import numpy as np import numpy as np
from datasets import load_dataset from datasets import load_dataset
...@@ -207,7 +196,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -207,7 +196,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
@slow @slow
@require_torch @require_torch
@require_torchaudio @require_torchaudio
@require_datasets
def test_xls_r_to_en(self): def test_xls_r_to_en(self):
speech_recognizer = pipeline( speech_recognizer = pipeline(
task="automatic-speech-recognition", task="automatic-speech-recognition",
...@@ -226,7 +214,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -226,7 +214,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
@slow @slow
@require_torch @require_torch
@require_torchaudio @require_torchaudio
@require_datasets
def test_xls_r_from_en(self): def test_xls_r_from_en(self):
speech_recognizer = pipeline( speech_recognizer = pipeline(
task="automatic-speech-recognition", task="automatic-speech-recognition",
......
...@@ -19,7 +19,6 @@ from transformers.pipelines import ImageClassificationPipeline, pipeline ...@@ -19,7 +19,6 @@ from transformers.pipelines import ImageClassificationPipeline, pipeline
from transformers.testing_utils import ( from transformers.testing_utils import (
is_pipeline_test, is_pipeline_test,
nested_simplify, nested_simplify,
require_datasets,
require_tf, require_tf,
require_torch, require_torch,
require_vision, require_vision,
...@@ -53,7 +52,6 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -53,7 +52,6 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
] ]
return image_classifier, examples return image_classifier, examples
@require_datasets
def run_pipeline_test(self, image_classifier, examples): def run_pipeline_test(self, image_classifier, examples):
outputs = image_classifier("./tests/fixtures/tests_samples/COCO/000000039769.png") outputs = image_classifier("./tests/fixtures/tests_samples/COCO/000000039769.png")
......
...@@ -26,7 +26,6 @@ from transformers import ( ...@@ -26,7 +26,6 @@ from transformers import (
from transformers.testing_utils import ( from transformers.testing_utils import (
is_pipeline_test, is_pipeline_test,
nested_simplify, nested_simplify,
require_datasets,
require_tf, require_tf,
require_timm, require_timm,
require_torch, require_torch,
...@@ -61,7 +60,6 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa ...@@ -61,7 +60,6 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
"./tests/fixtures/tests_samples/COCO/000000039769.png", "./tests/fixtures/tests_samples/COCO/000000039769.png",
] ]
@require_datasets
def run_pipeline_test(self, image_segmenter, examples): def run_pipeline_test(self, image_segmenter, examples):
outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0) outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12) self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12)
......
...@@ -25,7 +25,6 @@ from transformers import ( ...@@ -25,7 +25,6 @@ from transformers import (
from transformers.testing_utils import ( from transformers.testing_utils import (
is_pipeline_test, is_pipeline_test,
nested_simplify, nested_simplify,
require_datasets,
require_tf, require_tf,
require_timm, require_timm,
require_torch, require_torch,
...@@ -57,7 +56,6 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase ...@@ -57,7 +56,6 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor) object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"] return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
@require_datasets
def run_pipeline_test(self, object_detector, examples): def run_pipeline_test(self, object_detector, examples):
outputs = object_detector("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0) outputs = object_detector("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
......
...@@ -32,13 +32,7 @@ from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, ...@@ -32,13 +32,7 @@ from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer,
from transformers.models.rag.configuration_rag import RagConfig from transformers.models.rag.configuration_rag import RagConfig
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import ( from transformers.testing_utils import require_faiss, require_sentencepiece, require_tokenizers, require_torch
require_datasets,
require_faiss,
require_sentencepiece,
require_tokenizers,
require_torch,
)
if is_faiss_available(): if is_faiss_available():
...@@ -46,7 +40,6 @@ if is_faiss_available(): ...@@ -46,7 +40,6 @@ if is_faiss_available():
@require_faiss @require_faiss
@require_datasets
class RagRetrieverTest(TestCase): class RagRetrieverTest(TestCase):
def setUp(self): def setUp(self):
self.tmpdirname = tempfile.mkdtemp() self.tmpdirname = tempfile.mkdtemp()
......
...@@ -24,7 +24,7 @@ from transformers.models.bart.configuration_bart import BartConfig ...@@ -24,7 +24,7 @@ from transformers.models.bart.configuration_bart import BartConfig
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.dpr.configuration_dpr import DPRConfig from transformers.models.dpr.configuration_dpr import DPRConfig
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import require_datasets, require_faiss, require_tokenizers, require_torch, slow from transformers.testing_utils import require_faiss, require_tokenizers, require_torch, slow
if is_torch_available() and is_datasets_available() and is_faiss_available(): if is_torch_available() and is_datasets_available() and is_faiss_available():
...@@ -33,7 +33,6 @@ if is_torch_available() and is_datasets_available() and is_faiss_available(): ...@@ -33,7 +33,6 @@ if is_torch_available() and is_datasets_available() and is_faiss_available():
@require_faiss @require_faiss
@require_datasets
@require_torch @require_torch
class RagTokenizerTest(TestCase): class RagTokenizerTest(TestCase):
def setUp(self): def setUp(self):
......
...@@ -46,7 +46,6 @@ from transformers.testing_utils import ( ...@@ -46,7 +46,6 @@ from transformers.testing_utils import (
get_gpu_count, get_gpu_count,
get_tests_dir, get_tests_dir,
is_staging_test, is_staging_test,
require_datasets,
require_optuna, require_optuna,
require_ray, require_ray,
require_sentencepiece, require_sentencepiece,
...@@ -391,7 +390,6 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -391,7 +390,6 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train() trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True) self.check_trained_model(trainer.model, alternate_seed=True)
@require_datasets
def test_trainer_with_datasets(self): def test_trainer_with_datasets(self):
import datasets import datasets
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.file_utils import is_datasets_available from transformers.file_utils import is_datasets_available
from transformers.testing_utils import TestCasePlus, require_datasets, require_torch, slow from transformers.testing_utils import TestCasePlus, require_torch, slow
if is_datasets_available(): if is_datasets_available():
...@@ -25,7 +25,6 @@ if is_datasets_available(): ...@@ -25,7 +25,6 @@ if is_datasets_available():
class Seq2seqTrainerTester(TestCasePlus): class Seq2seqTrainerTester(TestCasePlus):
@slow @slow
@require_torch @require_torch
@require_datasets
def test_finetune_bert2bert(self): def test_finetune_bert2bert(self):
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny") bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment