Commit 31c23bd5 authored by thomwolf

[BIG] pytorch-transformers => transformers

parent 2f071fcb
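For downstream code, the change is mechanical: the import root `pytorch_transformers` becomes `transformers`, while class and function names stay the same. A minimal before/after sketch (the shortcut name is illustrative, not taken from this diff):

    # before this commit:
    #   from pytorch_transformers import BertTokenizer
    # after this commit:
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    print(tokenizer.tokenize("pytorch-transformers is now transformers"))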
@@ -24,11 +24,11 @@ import pytest
 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
 
-from pytorch_transformers import TransfoXLConfig, is_tf_available
+from transformers import TransfoXLConfig, is_tf_available
 
 if is_tf_available():
     import tensorflow as tf
-    from pytorch_transformers.modeling_tf_transfo_xl import (TFTransfoXLModel,
+    from transformers.modeling_tf_transfo_xl import (TFTransfoXLModel,
                                                      TFTransfoXLLMHeadModel,
                                                      TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:
@@ -206,7 +206,7 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -20,11 +20,11 @@ import unittest
 import shutil
 import pytest
 
-from pytorch_transformers import is_tf_available
+from transformers import is_tf_available
 
 if is_tf_available():
     import tensorflow as tf
-    from pytorch_transformers import (XLMConfig, TFXLMModel,
+    from transformers import (XLMConfig, TFXLMModel,
                               TFXLMWithLMHeadModel,
                               TFXLMForSequenceClassification,
                               TFXLMForQuestionAnsweringSimple,
@@ -253,7 +253,7 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TFXLMModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -23,12 +23,12 @@ import random
 import shutil
 import pytest
 
-from pytorch_transformers import XLNetConfig, is_tf_available
+from transformers import XLNetConfig, is_tf_available
 
 if is_tf_available():
     import tensorflow as tf
-    from pytorch_transformers.modeling_tf_xlnet import (TFXLNetModel, TFXLNetLMHeadModel,
+    from transformers.modeling_tf_xlnet import (TFXLNetModel, TFXLNetLMHeadModel,
                                                 TFXLNetForSequenceClassification,
                                                 TFXLNetForQuestionAnsweringSimple,
                                                 TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
@@ -291,7 +291,7 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -21,12 +21,12 @@ import random
 import shutil
 import pytest
 
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 
 if is_torch_available():
     import torch
-    from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
-    from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
+    from transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
+    from transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
 else:
     pytestmark = pytest.mark.skip("Require Torch")
@@ -206,7 +206,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -20,12 +20,12 @@ import unittest
 import shutil
 import pytest
 
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 
 if is_torch_available():
-    from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
+    from transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
                               XLMForSequenceClassification, XLMForQuestionAnsweringSimple)
-    from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
+    from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
 else:
     pytestmark = pytest.mark.skip("Require Torch")
@@ -314,7 +314,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -23,13 +23,13 @@ import random
 import shutil
 import pytest
 
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 
 if is_torch_available():
     import torch
-    from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
-    from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
+    from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
+    from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
 else:
     pytestmark = pytest.mark.skip("Require Torch")
@@ -317,7 +317,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
+        cache_dir = "/tmp/transformers_test/"
         for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
...
@@ -20,12 +20,12 @@ import unittest
 import os
 import pytest
 
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 
 if is_torch_available():
     import torch
-    from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
+    from transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
                               WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
 else:
     pytestmark = pytest.mark.skip("Require Torch")
...
@@ -21,8 +21,8 @@ import shutil
 import pytest
 import logging
 
-from pytorch_transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
-from pytorch_transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
+from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
+from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
 
 class AutoTokenizerTest(unittest.TestCase):
...
@@ -18,7 +18,7 @@ import os
 import unittest
 from io import open
 
-from pytorch_transformers.tokenization_bert import (BasicTokenizer,
+from transformers.tokenization_bert import (BasicTokenizer,
                                             BertTokenizer,
                                             WordpieceTokenizer,
                                             _is_control, _is_punctuation,
...
@@ -18,7 +18,7 @@ import os
 import unittest
 from io import open
 
-from pytorch_transformers.tokenization_distilbert import (DistilBertTokenizer)
+from transformers.tokenization_distilbert import (DistilBertTokenizer)
 
 from .tokenization_tests_commons import CommonTestCases
 from .tokenization_bert_test import BertTokenizationTest
...
@@ -19,7 +19,7 @@ import unittest
 import json
 from io import open
 
-from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
+from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
 
 from .tokenization_tests_commons import CommonTestCases
...
@@ -18,7 +18,7 @@ import os
 import unittest
 import json
 
-from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
+from transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
 
 from .tokenization_tests_commons import CommonTestCases
...
@@ -19,7 +19,7 @@ import json
 import unittest
 from io import open
 
-from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
+from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
 
 from .tokenization_tests_commons import CommonTestCases
...
@@ -19,11 +19,11 @@ import unittest
 import pytest
 from io import open
 
-from pytorch_transformers import is_torch_available
+from transformers import is_torch_available
 
 if is_torch_available():
     import torch
-    from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
+    from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
 else:
     pytestmark = pytest.mark.skip("Require Torch")  # TODO: untangle Transfo-XL tokenizer from torch.load and torch.save
...
@@ -19,8 +19,8 @@ from __future__ import print_function
 import unittest
 import six
 
-from pytorch_transformers import PreTrainedTokenizer
-from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
+from transformers import PreTrainedTokenizer
+from transformers.tokenization_gpt2 import GPT2Tokenizer
 
 class TokenizerUtilsTest(unittest.TestCase):
     def check_tokenizer_from_pretrained(self, tokenizer_class):
...
@@ -18,7 +18,7 @@ import os
 import unittest
 import json
 
-from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
+from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
 
 from .tokenization_tests_commons import CommonTestCases
...
@@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 
-from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
+from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
 
 from .tokenization_tests_commons import CommonTestCases
...
@@ -30,7 +30,7 @@ from .tokenization_distilbert import DistilBertTokenizer
 logger = logging.getLogger(__name__)
 
 class AutoTokenizer(object):
-    r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class
+    r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
         that will be instantiated as one of the tokenizer classes of the library
         when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
         class method.
@@ -75,7 +75,7 @@ class AutoTokenizer(object):
             pretrained_model_name_or_path: either:
 
                 - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
-                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
+                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                 - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
 
             cache_dir: (`optional`) string:
@@ -90,7 +90,7 @@ class AutoTokenizer(object):
             inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
-            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
+            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.
 
         Examples::
...
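The `Examples::` block itself is collapsed in this diff; a minimal usage sketch consistent with the parameters documented above (both load paths are drawn from the docstring, not from the collapsed examples) might be:

    from transformers import AutoTokenizer

    # load by shortcut name; the vocabulary is downloaded and cached
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # or load from a directory written by PreTrainedTokenizer.save_pretrained()
    tokenizer = AutoTokenizer.from_pretrained("./my_model_directory/")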
@@ -103,7 +103,7 @@ def whitespace_tokenize(text):
 class BertTokenizer(PreTrainedTokenizer):
     r"""
     Constructs a BertTokenizer.
-    :class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
+    :class:`~transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
 
     Args:
         vocab_file: Path to a one-wordpiece-per-line vocabulary file
...
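To make the "punctuation splitting + wordpiece" pipeline concrete, a short sketch; the subword split shown in the comment is illustrative for an uncased WordPiece vocabulary, not asserted by this diff:

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # punctuation becomes standalone tokens and rare words are split into
    # wordpieces, e.g. "tokenization" -> ["token", "##ization"]
    print(tokenizer.tokenize("End-to-end tokenization, with WordPiece!"))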