Commit 31c23bd5 authored by thomwolf

[BIG] pytorch-transformers => transformers

parent 2f071fcb
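
The change itself is mechanical: every pytorch_transformers import becomes transformers, and the test cache paths move from /tmp/pytorch_transformers_test/ to /tmp/transformers_test/. A minimal sketch of migrated user code (assuming the renamed package is installed; bert-base-uncased is one of the checkpoints these tests exercise):

    # before the rename: from pytorch_transformers import BertModel, BertConfig
    from transformers import BertModel, BertConfig

    model = BertModel.from_pretrained('bert-base-uncased')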
@@ -21,15 +21,15 @@ import shutil
 import pytest
 import logging
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 if is_torch_available():
-    from pytorch_transformers import (AutoConfig, BertConfig,
+    from transformers import (AutoConfig, BertConfig,
                                AutoModel, BertModel,
                                AutoModelWithLMHead, BertForMaskedLM,
                                AutoModelForSequenceClassification, BertForSequenceClassification,
                                AutoModelForQuestionAnswering, BertForQuestionAnswering)
-    from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
+    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
...
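
The Auto classes imported above pair a generic entry point with a concrete architecture (AutoModel with BertModel, AutoModelForQuestionAnswering with BertForQuestionAnswering, and so on), dispatching on the checkpoint name. A hedged sketch of that dispatch, assuming network access for the download:

    from transformers import AutoConfig, AutoModel

    # A 'bert-*' shortcut name resolves to BertConfig / BertModel,
    # matching the class pairs in the imports above.
    config = AutoConfig.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased')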
@@ -20,17 +20,17 @@ import unittest
 import shutil
 import pytest
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
 if is_torch_available():
-    from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
+    from transformers import (BertConfig, BertModel, BertForMaskedLM,
                                BertForNextSentencePrediction, BertForPreTraining,
                                BertForQuestionAnswering, BertForSequenceClassification,
                                BertForTokenClassification, BertForMultipleChoice)
-    from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
+    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 else:
     pytestmark = pytest.mark.skip("Require Torch")

@@ -310,7 +310,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
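
Every test module touched by this commit gates its imports on backend availability and skips the whole module otherwise; the pattern in isolation:

    import pytest
    from transformers import is_torch_available

    if is_torch_available():
        from transformers import BertModel
    else:
        # pytest honours a module-level pytestmark, so every test
        # collected from this file is skipped when torch is missing.
        pytestmark = pytest.mark.skip("Require Torch")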
@@ -27,12 +27,12 @@ import unittest
 import logging
 import pytest
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 if is_torch_available():
     import torch
-    from pytorch_transformers import (PretrainedConfig, PreTrainedModel,
+    from transformers import (PretrainedConfig, PreTrainedModel,
                                BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
                                GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:

@@ -621,7 +621,7 @@ class CommonTestCases:
             [[], []])
         def create_and_check_model_from_pretrained(self):
-            cache_dir = "/tmp/pytorch_transformers_test/"
+            cache_dir = "/tmp/transformers_test/"
             for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
                 model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
                 shutil.rmtree(cache_dir)
...
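
The from_pretrained tests above all follow the same shape: take the first shortcut name from the model's pretrained archive map (a plain dict of name to weights URL), download it into a throwaway cache directory, and delete the cache afterwards. A standalone sketch of that shape:

    import shutil
    from transformers import BertModel
    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP

    cache_dir = "/tmp/transformers_test/"
    # Only the first checkpoint is exercised, to keep the slow test short.
    first_name = list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[0]
    model = BertModel.from_pretrained(first_name, cache_dir=cache_dir)
    shutil.rmtree(cache_dir)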
@@ -19,10 +19,10 @@ from __future__ import print_function
 import unittest
 import pytest
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 if is_torch_available():
-    from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
+    from transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
                                DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
 else:
     pytestmark = pytest.mark.skip("Require Torch")

@@ -211,7 +211,7 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
     # @pytest.mark.slow
     # def test_model_from_pretrained(self):
-    #     cache_dir = "/tmp/pytorch_transformers_test/"
+    #     cache_dir = "/tmp/transformers_test/"
     #     for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
     #         model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
     #         shutil.rmtree(cache_dir)
...
@@ -20,10 +20,10 @@ import unittest
 import pytest
 import shutil
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 if is_torch_available():
-    from pytorch_transformers import (GPT2Config, GPT2Model, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
+    from transformers import (GPT2Config, GPT2Model, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
                                GPT2LMHeadModel, GPT2DoubleHeadsModel)
 else:
     pytestmark = pytest.mark.skip("Require Torch")

@@ -237,7 +237,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -20,10 +20,10 @@ import unittest
 import pytest
 import shutil
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 if is_torch_available():
-    from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
+    from transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
                                OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
 else:
     pytestmark = pytest.mark.skip("Require Torch")

@@ -205,7 +205,7 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -20,12 +20,12 @@ import unittest
 import shutil
 import pytest
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 if is_torch_available():
     import torch
-    from pytorch_transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
-    from pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+    from transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
+    from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 else:
     pytestmark = pytest.mark.skip("Require Torch")

@@ -180,7 +180,7 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -21,15 +21,15 @@ import shutil
 import pytest
 import logging
-from pytorch_transformers import is_tf_available
+from transformers import is_tf_available
 if is_tf_available():
-    from pytorch_transformers import (AutoConfig, BertConfig,
+    from transformers import (AutoConfig, BertConfig,
                                TFAutoModel, TFBertModel,
                                TFAutoModelWithLMHead, TFBertForMaskedLM,
                                TFAutoModelForSequenceClassification, TFBertForSequenceClassification,
                                TFAutoModelForQuestionAnswering, TFBertForQuestionAnswering)
-    from pytorch_transformers.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
+    from transformers.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
...
@@ -24,11 +24,11 @@ import sys
 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from pytorch_transformers import BertConfig, is_tf_available
+from transformers import BertConfig, is_tf_available
 if is_tf_available():
     import tensorflow as tf
-    from pytorch_transformers.modeling_tf_bert import (TFBertModel, TFBertForMaskedLM,
+    from transformers.modeling_tf_bert import (TFBertModel, TFBertForMaskedLM,
                                                TFBertForNextSentencePrediction,
                                                TFBertForPreTraining,
                                                TFBertForSequenceClassification,

@@ -315,7 +315,7 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
         for model_name in ['bert-base-uncased']:
             model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
...
@@ -26,13 +26,13 @@ import uuid
 import pytest
 import sys
-from pytorch_transformers import is_tf_available, is_torch_available
+from transformers import is_tf_available, is_torch_available
 if is_tf_available():
     import tensorflow as tf
     import numpy as np
-    from pytorch_transformers import TFPreTrainedModel
-    # from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
+    from transformers import TFPreTrainedModel
+    # from transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 else:
     pytestmark = pytest.mark.skip("Require TensorFlow")

@@ -71,19 +71,19 @@ class TFCommonTestCases:
         if not is_torch_available():
             return
-        import pytorch_transformers
+        import transformers
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
             pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
-            pt_model_class = getattr(pytorch_transformers, pt_model_class_name)
+            pt_model_class = getattr(transformers, pt_model_class_name)
             tf_model = model_class(config)
             pt_model = pt_model_class(config)
-            tf_model = pytorch_transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
-            pt_model = pytorch_transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
+            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
+            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

     def test_keyword_and_dict_args(self):
...
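
The common TF test above round-trips weights between frameworks: it derives the PyTorch class by stripping the "TF" prefix, then converts in both directions with the two loader helpers. The same round trip outside the test harness (config and inputs_dict stand in for whatever the model tester prepares):

    import transformers

    tf_model = transformers.TFBertModel(config)  # config assumed built elsewhere
    pt_model = transformers.BertModel(config)

    # PyTorch -> TF 2.0 (tf_inputs builds the TF variables first),
    # then TF 2.0 -> PyTorch to close the loop.
    tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
    pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)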
@@ -22,11 +22,11 @@ import pytest
 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from pytorch_transformers import DistilBertConfig, is_tf_available
+from transformers import DistilBertConfig, is_tf_available
 if is_tf_available():
     import tensorflow as tf
-    from pytorch_transformers.modeling_tf_distilbert import (TFDistilBertModel,
+    from transformers.modeling_tf_distilbert import (TFDistilBertModel,
                                                      TFDistilBertForMaskedLM,
                                                      TFDistilBertForQuestionAnswering,
                                                      TFDistilBertForSequenceClassification)

@@ -212,7 +212,7 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
     # @pytest.mark.slow
     # def test_model_from_pretrained(self):
-    #     cache_dir = "/tmp/pytorch_transformers_test/"
+    #     cache_dir = "/tmp/transformers_test/"
     #     for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
     #         model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
     #         shutil.rmtree(cache_dir)
...
@@ -24,11 +24,11 @@ import sys
 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from pytorch_transformers import GPT2Config, is_tf_available
+from transformers import GPT2Config, is_tf_available
 if is_tf_available():
     import tensorflow as tf
-    from pytorch_transformers.modeling_tf_gpt2 import (TFGPT2Model, TFGPT2LMHeadModel,
+    from transformers.modeling_tf_gpt2 import (TFGPT2Model, TFGPT2LMHeadModel,
                                                TFGPT2DoubleHeadsModel,
                                                TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:

@@ -221,7 +221,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -24,11 +24,11 @@ import sys
 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from pytorch_transformers import OpenAIGPTConfig, is_tf_available
+from transformers import OpenAIGPTConfig, is_tf_available
 if is_tf_available():
     import tensorflow as tf
-    from pytorch_transformers.modeling_tf_openai import (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel,
+    from transformers.modeling_tf_openai import (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel,
                                                  TFOpenAIGPTDoubleHeadsModel,
                                                  TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:

@@ -220,7 +220,7 @@ class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -23,12 +23,12 @@ import pytest
 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from pytorch_transformers import RobertaConfig, is_tf_available
+from transformers import RobertaConfig, is_tf_available
 if is_tf_available():
     import tensorflow as tf
     import numpy
-    from pytorch_transformers.modeling_tf_roberta import (TFRobertaModel, TFRobertaForMaskedLM,
+    from transformers.modeling_tf_roberta import (TFRobertaModel, TFRobertaForMaskedLM,
                                                   TFRobertaForSequenceClassification,
                                                   TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:

@@ -178,7 +178,7 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TFRobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...