Commit 48b438ff authored by thomwolf's avatar thomwolf
Browse files

doc and conversion

parent c19b8e4a
...@@ -119,5 +119,8 @@ Here is the full list of the currently provided pretrained models together with ...@@ -119,5 +119,8 @@ Here is the full list of the currently provided pretrained models together with
| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint, with an additional linear layer. | | | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint, with an additional linear layer. |
| | | (see `details <https://medium.com/huggingface/distilbert-8cf3380435b5>`__) | | | | (see `details <https://medium.com/huggingface/distilbert-8cf3380435b5>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters |
| | | | Salesforce's Large-sized CTRL English model |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
.. <https://huggingface.co/transformers/examples.html>`__ .. <https://huggingface.co/transformers/examples.html>`__
\ No newline at end of file
...@@ -155,6 +155,11 @@ if is_tf_available(): ...@@ -155,6 +155,11 @@ if is_tf_available():
load_distilbert_pt_weights_in_tf2, load_distilbert_pt_weights_in_tf2,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
TFCTRLLMHeadModel,
load_ctrl_pt_weights_in_tf2,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
# TF 2.0 <=> PyTorch conversion utilities # TF 2.0 <=> PyTorch conversion utilities
if is_tf_available() and is_torch_available(): if is_tf_available() and is_torch_available():
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name, from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
......
...@@ -31,7 +31,8 @@ from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAns ...@@ -31,7 +31,8 @@ from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAns
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP) DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -43,7 +44,8 @@ if is_torch_available(): ...@@ -43,7 +44,8 @@ if is_torch_available():
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
else: else:
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
...@@ -52,7 +54,8 @@ else: ...@@ -52,7 +54,8 @@ else:
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = ( DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) = (
None, None, None, None, None, None, None, None,
None, None, None, None,
None, None, None, None,
...@@ -60,7 +63,8 @@ else: ...@@ -60,7 +63,8 @@ else:
None, None, None, None,
None, None, None, None,
None, None, None, None, None, None,
None, None, None,) None, None, None,
None, None)
import logging import logging
...@@ -80,6 +84,7 @@ MODEL_CLASSES = { ...@@ -80,6 +84,7 @@ MODEL_CLASSES = {
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), 'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
} }
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True): def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
...@@ -228,6 +233,7 @@ if __name__ == "__main__": ...@@ -228,6 +233,7 @@ if __name__ == "__main__":
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None, convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
args.tf_dump_path, args.tf_dump_path,
model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None, model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
compare_with_pt_model=args.compare_with_pt_model, compare_with_pt_model=args.compare_with_pt_model,
use_cached_models=args.use_cached_models, use_cached_models=args.use_cached_models,
only_convert_finetuned_models=args.only_convert_finetuned_models) only_convert_finetuned_models=args.only_convert_finetuned_models)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment