Commit f03c0c14 authored by thomwolf

adding models in readme and auto classes

parent 4321c541
...@@ -122,7 +122,8 @@ At some point in the future, you'll be able to seamlessly move from pre-training
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
8. **[DistilBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/master/examples/distillation).
9. **[CTRL](https://github.com/salesforce/ctrl/)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
10. **[T5](https://github.com/google-research/text-to-text-transfer-transformer)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
11. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
...
...@@ -144,5 +144,25 @@ Here is the full list of the currently provided pretrained models together with
| CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters |
| | | | Salesforce's Large-sized CTRL English model |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| T5 | ``t5-small`` | | 6-layer, 512-hidden, 8-heads, ~60M parameters |
| | | | Google's T5 small model, trained on the Colossal Clean Crawled Corpus (C4) |
| | | (see `details <https://github.com/google-research/text-to-text-transfer-transformer>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``t5-base`` | | 12-layer, 768-hidden, 12-heads, ~220M parameters |
| | | | Google's T5 base model, trained on the Colossal Clean Crawled Corpus (C4) |
| | | (see `details <https://github.com/google-research/text-to-text-transfer-transformer>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``t5-large`` | | 24-layer, 1024-hidden, 16-heads, ~770M parameters |
| | | | Google's T5 large model, trained on the Colossal Clean Crawled Corpus (C4) |
| | | (see `details <https://github.com/google-research/text-to-text-transfer-transformer>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``t5-3b`` | | 24-layer, 1024-hidden, 32-heads, ~2.8B parameters |
| | | | Google's T5 3B model, trained on the Colossal Clean Crawled Corpus (C4) |
| | | (see `details <https://github.com/google-research/text-to-text-transfer-transformer>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``t5-11b`` | | 24-layer, 1024-hidden, 128-heads, ~11B parameters |
| | | | Google's T5 11B model, trained on the Colossal Clean Crawled Corpus (C4) |
| | | (see `details <https://github.com/google-research/text-to-text-transfer-transformer>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
.. <https://huggingface.co/transformers/examples.html>`__
\ No newline at end of file
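
For reference, a minimal sketch of how the new checkpoints are meant to be loaded through the auto classes extended later in this commit. The `t5-small` shortcut name comes from the table above; the forward signature of the T5 classes lives in `modeling_t5.py`, outside this diff, so only instantiation is shown.

```python
from transformers import AutoTokenizer, AutoModelWithLMHead

# The shortcut name is matched by substring, so any name containing 't5'
# is routed to the T5 classes registered in this commit.
tokenizer = AutoTokenizer.from_pretrained('t5-small')    # -> T5Tokenizer
model = AutoModelWithLMHead.from_pretrained('t5-small')  # -> T5WithLMHeadModel

# T5 is a text-to-text model, so inputs are plain strings with a task prefix.
input_ids = tokenizer.encode("translate English to German: The house is wonderful.")
```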
...@@ -6,6 +6,7 @@ def main():
"This command line utility let you convert original (author released) model checkpoint to pytorch.\n"
"It should be used as one of: \n"
">> transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
">> transformers t5 TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
">> transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
">> transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
">> transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
...@@ -21,6 +22,23 @@ def main():
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) != 5:
# pylint: disable=line-too-long
print("Should be used as `transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
else:
PYTORCH_DUMP_OUTPUT = sys.argv.pop()
TF_CONFIG = sys.argv.pop()
TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "t5":
try:
from .convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ImportError:
print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) != 5:
# pylint: disable=line-too-long
print("Should be used as `transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
...
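
The new `t5` branch parses its positional arguments the same way as the existing `bert` branch: values are popped from the end of `sys.argv`, so the PyTorch dump path is read first even though it is passed last on the command line. A small standalone illustration of that pattern (argument values are hypothetical):

```python
import sys

# Simulates: transformers t5 TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT
sys.argv = ["transformers", "t5", "./tf_checkpoint", "./t5_config.json", "./pytorch_model.bin"]

if len(sys.argv) != 5:
    print("Should be used as `transformers t5 TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
else:
    PYTORCH_DUMP_OUTPUT = sys.argv.pop()  # "./pytorch_model.bin" (last argument, popped first)
    TF_CONFIG = sys.argv.pop()            # "./t5_config.json"
    TF_CHECKPOINT = sys.argv.pop()        # "./tf_checkpoint"
    print(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
```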
...@@ -33,7 +33,8 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
if is_torch_available():
import torch
...@@ -46,7 +47,8 @@ if is_torch_available():
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
...@@ -56,7 +58,8 @@ else:
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
None, None, None, None,
None, None,
None, None,
...@@ -65,6 +68,7 @@ else:
None, None, None,
None, None, None,
None, None, None,
None, None,
None, None)
...@@ -85,7 +89,8 @@ MODEL_CLASSES = {
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
}
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
...
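
Each `MODEL_CLASSES` value follows the same 5-tuple layout (config class, TF 2.0 model class, PyTorch model class, pretrained weights map, pretrained config map), so the new `'t5'` entry slots into the existing conversion loop without further changes. A rough sketch of how such an entry is consumed; this is illustrative, not a verbatim excerpt of `convert_pt_checkpoint_to_tf`, and the config file name is hypothetical:

```python
# Illustrative unpacking of the new entry.
config_class, tf_model_class, pt_model_class, pt_weights_map, config_map = MODEL_CLASSES['t5']

config = config_class.from_json_file('t5_config.json')  # hypothetical local config file
tf_model = tf_model_class(config)                        # builds a TFT5WithLMHeadModel
# load_pytorch_checkpoint_in_tf2_model (imported above) then copies the PyTorch
# weights into the freshly built TF model before it is saved to tf_dump_path.
```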
...@@ -27,6 +27,7 @@ from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassi
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
from .modeling_t5 import T5Model, T5WithLMHeadModel
from .modeling_utils import PreTrainedModel, SequenceSummary
...@@ -47,6 +48,7 @@ class AutoModel(object):
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: T5Model (T5 model)
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
...@@ -70,6 +72,7 @@ class AutoModel(object):
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: T5Model (T5 model)
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
...@@ -136,7 +139,9 @@ class AutoModel(object):
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 't5' in pretrained_model_name_or_path:
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
...@@ -171,6 +176,7 @@ class AutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: T5WithLMHeadModel (T5 model)
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model)
...@@ -197,6 +203,7 @@ class AutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: T5WithLMHeadModel (T5 model)
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
- contains `bert`: BertForMaskedLM (Bert model)
...@@ -262,7 +269,9 @@ class AutoModelWithLMHead(object):
model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 't5' in pretrained_model_name_or_path:
return T5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
...
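
Since `from_pretrained` dispatches on plain substring matching, the order of the branches matters: `'t5'` is tested before `'distilbert'`, which is tested before `'roberta'` and `'bert'`, otherwise a name like `'distilbert-base-uncased'` (which also contains `'bert'`) would be routed to the BERT classes. A tiny illustration of the same first-match rule (hypothetical helper, not part of the library):

```python
def pick_family(name):
    # First match wins, mirroring the elif chain in AutoModel.from_pretrained.
    for key, family in [('t5', 'T5'), ('distilbert', 'DistilBERT'),
                        ('roberta', 'RoBERTa'), ('bert', 'BERT')]:
        if key in name:
            return family
    raise ValueError("Unrecognized checkpoint name: {}".format(name))

assert pick_family('t5-base') == 'T5'
assert pick_family('distilbert-base-uncased') == 'DistilBERT'  # not BERT
assert pick_family('roberta-large-mnli') == 'RoBERTa'
```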
...@@ -27,6 +27,7 @@ from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceC
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification
from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel
from .modeling_tf_t5 import TFT5Model, TFT5WithLMHeadModel
from .file_utils import add_start_docstrings
...@@ -45,6 +46,7 @@ class TFAutoModel(object):
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: TFT5Model (T5 model)
- contains `distilbert`: TFDistilBertModel (DistilBERT model)
- contains `roberta`: TFRobertaModel (RoBERTa model)
- contains `bert`: TFBertModel (Bert model)
...@@ -68,6 +70,7 @@ class TFAutoModel(object):
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: TFT5Model (T5 model)
- contains `distilbert`: TFDistilBertModel (DistilBERT model)
- contains `roberta`: TFRobertaModel (RoBERTa model)
- contains `bert`: TFBertModel (Bert model)
...@@ -133,7 +136,9 @@ class TFAutoModel(object):
model = TFAutoModel.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
"""
if 't5' in pretrained_model_name_or_path:
return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return TFDistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
...@@ -169,6 +174,7 @@ class TFAutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: TFT5WithLMHeadModel (T5 model)
- contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
- contains `roberta`: TFRobertaForMaskedLM (RoBERTa model)
- contains `bert`: TFBertForMaskedLM (Bert model)
...@@ -195,6 +201,7 @@ class TFAutoModelWithLMHead(object):
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: TFT5WithLMHeadModel (T5 model)
- contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
- contains `roberta`: TFRobertaForMaskedLM (RoBERTa model)
- contains `bert`: TFBertForMaskedLM (Bert model)
...@@ -261,7 +268,9 @@ class TFAutoModelWithLMHead(object):
model = TFAutoModelWithLMHead.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
"""
if 't5' in pretrained_model_name_or_path:
return TFT5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return TFDistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return TFRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
...
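
The TF 2.0 auto classes follow the same dispatch, so the same shortcut names select the TF implementations; a minimal sketch, assuming TensorFlow weights are published under the `t5-*` shortcut names:

```python
from transformers import AutoTokenizer, TFAutoModelWithLMHead

tokenizer = AutoTokenizer.from_pretrained('t5-small')      # -> T5Tokenizer
model = TFAutoModelWithLMHead.from_pretrained('t5-small')  # -> TFT5WithLMHeadModel
```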
...@@ -27,6 +27,7 @@ from .tokenization_xlnet import XLNetTokenizer
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_t5 import T5Tokenizer
logger = logging.getLogger(__name__)
...@@ -41,6 +42,7 @@ class AutoTokenizer(object):
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: T5Tokenizer (T5 model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert`: BertTokenizer (Bert model)
...@@ -64,6 +66,7 @@ class AutoTokenizer(object):
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: T5Tokenizer (T5 model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert`: BertTokenizer (Bert model)
...@@ -101,7 +104,9 @@ class AutoTokenizer(object):
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`
"""
if 't5' in pretrained_model_name_or_path:
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
...
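
On the tokenizer side, `T5Tokenizer` wraps T5's SentencePiece vocabulary, but it is loaded through the same `AutoTokenizer` entry point as the other models. A minimal round-trip sketch (the example sentence is arbitrary):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')  # dispatches to T5Tokenizer
tokens = tokenizer.tokenize("T5 treats every NLP problem as text-to-text.")
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokenizer.decode(ids))  # reconstructs (approximately) the original sentence
```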