"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f4323dbf8c29952b1ae55b979120969a9aeb730e"
Unverified commit 3b43b018, authored by Thomas Wolf, committed by GitHub

Merge pull request #1482 from huggingface/tf2_integration_tests

Integration of TF 2.0 models with other Keras modules
parents 700331b5 4b8f3e8f
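What this PR enables is that the TF 2.0 model classes behave like ordinary `tf.keras.Model` objects, so they compose with the rest of Keras (compile/fit, callbacks, other layers). The following is a minimal sketch of that workflow, not taken from the diff, assuming the `bert-base-uncased` shortcut name and a toy batch purely for illustration:

```python
import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')

# Toy single-example batch; a real setup would build a tf.data.Dataset.
input_ids = tf.constant([tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)])
labels = tf.constant([1])

# Because the model is a tf.keras.Model, the standard Keras workflow applies.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(input_ids, labels, epochs=1)
```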
@@ -9,7 +9,7 @@ jobs:
 steps:
 - checkout
 - run: sudo pip install torch
-- run: sudo pip install tensorflow==2.0.0-rc0
+- run: sudo pip install tensorflow
 - run: sudo pip install --progress-bar off .
 - run: sudo pip install pytest codecov pytest-cov
 - run: sudo pip install tensorboardX scikit-learn
@@ -38,7 +38,7 @@ jobs:
 parallelism: 1
 steps:
 - checkout
-- run: sudo pip install tensorflow==2.0.0-rc0
+- run: sudo pip install tensorflow
 - run: sudo pip install --progress-bar off .
 - run: sudo pip install pytest codecov pytest-cov
 - run: sudo pip install tensorboardX scikit-learn
@@ -65,7 +65,7 @@ jobs:
 - image: circleci/python:2.7
 steps:
 - checkout
-- run: sudo pip install tensorflow==2.0.0-rc0
+- run: sudo pip install tensorflow
 - run: sudo pip install --progress-bar off .
 - run: sudo pip install pytest codecov pytest-cov
 - run: python -m pytest -sv ./transformers/tests/ --cov
...
@@ -545,4 +545,14 @@ for batch in train_data:
 ## Citation
-At the moment, there is no paper associated with Transformers but we are working on preparing one. In the meantime, please include a mention of the library and a link to the present repository if you use this work in a published or open-source project.
+We now have a paper you can cite for the 🤗 Transformers library:
```
@misc{wolf2019transformers,
title={Transformers: State-of-the-art Natural Language Processing},
author={Thomas Wolf and Lysandre Debut and Victor Sanh and Julien Chaumond and Clement Delangue and Anthony Moi and Pierric Cistac and Tim Rault and Rémi Louf and Morgan Funtowicz and Jamie Brew},
year={2019},
eprint={1910.03771},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
@@ -110,59 +110,49 @@ if is_tf_available():
 TFBertForMaskedLM, TFBertForNextSentencePrediction,
 TFBertForSequenceClassification, TFBertForMultipleChoice,
 TFBertForTokenClassification, TFBertForQuestionAnswering,
-load_bert_pt_weights_in_tf2,
 TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
 TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
-load_gpt2_pt_weights_in_tf2,
 TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer,
 TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel,
-load_openai_gpt_pt_weights_in_tf2,
 TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer,
 TFTransfoXLModel, TFTransfoXLLMHeadModel,
-load_transfo_xl_pt_weights_in_tf2,
 TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
 TFXLNetModel, TFXLNetLMHeadModel,
 TFXLNetForSequenceClassification,
 TFXLNetForQuestionAnsweringSimple,
-load_xlnet_pt_weights_in_tf2,
 TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_xlm import (TFXLMPreTrainedModel, TFXLMMainLayer,
 TFXLMModel, TFXLMWithLMHeadModel,
 TFXLMForSequenceClassification,
 TFXLMForQuestionAnsweringSimple,
-load_xlm_pt_weights_in_tf2,
 TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer,
 TFRobertaModel, TFRobertaForMaskedLM,
 TFRobertaForSequenceClassification,
-load_roberta_pt_weights_in_tf2,
 TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
 TFDistilBertModel, TFDistilBertForMaskedLM,
 TFDistilBertForSequenceClassification,
 TFDistilBertForQuestionAnswering,
-load_distilbert_pt_weights_in_tf2,
 TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
 TFCTRLLMHeadModel,
-load_ctrl_pt_weights_in_tf2,
 TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
 # TF 2.0 <=> PyTorch conversion utilities
-if is_tf_available() and is_torch_available():
 from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
 load_pytorch_checkpoint_in_tf2_model,
 load_pytorch_weights_in_tf2_model,
 load_pytorch_model_in_tf2_model,
...
@@ -53,7 +53,8 @@ class PretrainedConfig(object):
 self.num_labels = kwargs.pop('num_labels', 2)
 self.output_attentions = kwargs.pop('output_attentions', False)
 self.output_hidden_states = kwargs.pop('output_hidden_states', False)
-self.torchscript = kwargs.pop('torchscript', False)
+self.output_past = kwargs.pop('output_past', True)  # Not used by all models
+self.torchscript = kwargs.pop('torchscript', False)  # Only used by PyTorch models
 self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
 self.pruned_heads = kwargs.pop('pruned_heads', {})
@@ -153,7 +154,7 @@ class PretrainedConfig(object):
 config = cls.from_json_file(resolved_config_file)
 if hasattr(config, 'pruned_heads'):
-config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
+config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
 # Update config with kwargs if needed
 to_remove = []
@@ -164,7 +165,7 @@ class PretrainedConfig(object):
 for key in to_remove:
 kwargs.pop(key, None)
-logger.info("Model config %s", config)
+logger.info("Model config %s", str(config))
 if return_unused_kwargs:
 return config, kwargs
 else:
...
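The hunk above introduces an `output_past` flag on `PretrainedConfig`. A hedged sketch of how it would be used, assuming the `gpt2` shortcut name purely for illustration: models that cache key/value states return them only while the flag stays at its default of `True`.

```python
from transformers import GPT2Config, GPT2LMHeadModel

# Config kwargs passed to from_pretrained() update the loaded config (this is the
# "Update config with kwargs" logic shown above), so output_past can be switched off here.
config = GPT2Config.from_pretrained('gpt2', output_past=False)
model = GPT2LMHeadModel.from_pretrained('gpt2', config=config)  # outputs no longer include `presents`
```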
@@ -24,15 +24,16 @@ import tensorflow as tf
 from transformers import is_torch_available, cached_path
-from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
-XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
-XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
-TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
-OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
-DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
+from transformers import (load_pytorch_checkpoint_in_tf2_model,
+BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
+TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
 if is_torch_available():
 import torch
@@ -71,27 +72,27 @@ import logging
 logging.basicConfig(level=logging.INFO)
 MODEL_CLASSES = {
-'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-'ctrl': (CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
+'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
 }
 def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
 if model_type not in MODEL_CLASSES:
 raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
-config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
+config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
 # Initialise TF model
 if config_file in aws_config_map:
@@ -105,7 +106,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
 # Load weights from tf checkpoint
 if pytorch_checkpoint_path in aws_model_maps:
 pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
-tf_model = loading_fct(tf_model, pytorch_checkpoint_path)
+# Load PyTorch checkpoint in tf2 model:
+tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
 if compare_with_pt_model:
 inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
@@ -147,7 +149,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
 if model_type not in MODEL_CLASSES:
 raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
-config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
+config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
 if model_shortcut_names_or_path is None:
 model_shortcut_names_or_path = list(aws_model_maps.keys())
...
@@ -269,16 +269,16 @@ class CTRLModel(CTRLPreTrainedModel):
 def __init__(self, config):
 super(CTRLModel, self).__init__(config)
 self.output_hidden_states = config.output_hidden_states
+self.output_attentions = config.output_attentions
+self.output_past = config.output_past
 self.d_model_size = config.n_embd
 self.num_layers = config.n_layer
 self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
-self.output_attentions = config.output_attentions
 self.w = nn.Embedding(config.vocab_size, config.n_embd)
 self.dropout = nn.Dropout(config.embd_pdrop)
 self.h = nn.ModuleList([EncoderLayer(config.n_embd,
 config.n_head,
@@ -378,6 +378,7 @@ class CTRLModel(CTRLPreTrainedModel):
 attention_mask=attention_mask,
 head_mask=head_mask[i])
 hidden_states, present = outputs[:2]
+if self.output_past:
 presents = presents + (present,)
 if self.output_attentions:
@@ -388,7 +389,9 @@ class CTRLModel(CTRLPreTrainedModel):
 if self.output_hidden_states:
 all_hidden_states = all_hidden_states + (hidden_states,)
-outputs = (hidden_states, presents)
+outputs = (hidden_states,)
+if self.output_past:
+outputs = outputs + (presents,)
 if self.output_hidden_states:
 outputs = outputs + (all_hidden_states,)
 if self.output_attentions:
...
@@ -347,6 +347,7 @@ class GPT2Model(GPT2PreTrainedModel):
 super(GPT2Model, self).__init__(config)
 self.output_hidden_states = config.output_hidden_states
 self.output_attentions = config.output_attentions
+self.output_past = config.output_past
 self.wte = nn.Embedding(config.vocab_size, config.n_embd)
 self.wpe = nn.Embedding(config.n_positions, config.n_embd)
@@ -440,6 +441,7 @@ class GPT2Model(GPT2PreTrainedModel):
 head_mask=head_mask[i])
 hidden_states, present = outputs[:2]
+if self.output_past:
 presents = presents + (present,)
 if self.output_attentions:
@@ -452,7 +454,9 @@ class GPT2Model(GPT2PreTrainedModel):
 if self.output_hidden_states:
 all_hidden_states = all_hidden_states + (hidden_states,)
-outputs = (hidden_states, presents)
+outputs = (hidden_states,)
+if self.output_past:
+outputs = outputs + (presents,)
 if self.output_hidden_states:
 outputs = outputs + (all_hidden_states,)
 if self.output_attentions:
@@ -460,7 +464,7 @@ class GPT2Model(GPT2PreTrainedModel):
 attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
 all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
 outputs = outputs + (all_attentions,)
-return outputs  # last hidden state, presents, (all hidden_states), (attentions)
+return outputs  # last hidden state, (presents), (all hidden_states), (attentions)
 @add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
...
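For context on why `presents` is worth caching at all, here is a hedged sketch (not part of the diff) of incremental decoding with GPT-2, reusing the cached key/values via the `past` argument; the `gpt2` shortcut name and greedy decoding are assumptions made purely for brevity.

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')  # config.output_past defaults to True
model.eval()

generated = torch.tensor([tokenizer.encode("The TF 2.0 integration")])
past = None
with torch.no_grad():
    for _ in range(5):
        # With a cache, only the newest token has to be fed through the model.
        inputs = generated if past is None else generated[:, -1:]
        logits, past = model(inputs, past=past)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)

print(tokenizer.decode(generated[0].tolist()))
```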
@@ -30,7 +30,6 @@ import tensorflow as tf
 from .configuration_bert import BertConfig
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
@@ -52,14 +51,6 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }
-def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-# build the network
-inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-tf_inputs = tf.constant(inputs_list)
-tfo = tf_model(tf_inputs, training=False)
-return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 def gelu(x):
 """ Gaussian Error Linear Unit.
 Original Implementation of the gelu activation function in Google Bert repo when initially created.
@@ -545,7 +536,6 @@ class TFBertPreTrainedModel(TFPreTrainedModel):
 """
 config_class = BertConfig
 pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
-load_pt_weights = load_bert_pt_weights_in_tf2
 base_model_prefix = "bert"
...
@@ -27,20 +27,11 @@ import tensorflow as tf
 from .configuration_ctrl import CTRLConfig
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
 TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"}
-def load_ctrl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-# build the network
-inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-tf_inputs = tf.constant(inputs_list)
-tfo = tf_model(tf_inputs, training=False)
-return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 def angle_defn(pos, i, d_model_size):
 angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model_size))
 return pos * angle_rates
@@ -177,12 +168,14 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
 def __init__(self, config, **kwargs):
 super(TFCTRLMainLayer, self).__init__(**kwargs)
 self.output_hidden_states = config.output_hidden_states
+self.output_attentions = config.output_attentions
+self.output_past = config.output_past
 self.d_model_size = config.n_embd
 self.num_layers = config.n_layer
 self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
-self.output_attentions = config.output_attentions
 self.w = TFSharedEmbeddings(config.vocab_size,
 config.n_embd,
@@ -299,6 +292,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
 outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training)
 hidden_states, present = outputs[:2]
+if self.output_past:
 presents = presents + (present,)
 if self.output_attentions:
@@ -309,7 +304,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
 if self.output_hidden_states:
 all_hidden_states = all_hidden_states + (hidden_states,)
-outputs = (hidden_states, presents)
+outputs = (hidden_states,)
+if self.output_past:
+outputs = outputs + (presents,)
 if self.output_hidden_states:
 outputs = outputs + (all_hidden_states,)
 if self.output_attentions:
@@ -327,7 +324,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
 config_class = CTRLConfig
 pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
 base_model_prefix = "transformer"
-load_pt_weights = load_ctrl_pt_weights_in_tf2
 CTRL_START_DOCSTRING = r""" CTRL model was proposed in
...
@@ -31,7 +31,6 @@ import tensorflow as tf
 from .configuration_distilbert import DistilBertConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
@@ -66,14 +65,6 @@ def gelu_new(x):
 (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
 return x * cdf
-def load_distilbert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-# build the network
-inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
-tf_inputs = [inputs_list, attns_list]
-tfo = tf_model(tf_inputs, training=False)
-return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 class TFEmbeddings(tf.keras.layers.Layer):
 def __init__(self, config, **kwargs):
 super(TFEmbeddings, self).__init__(**kwargs)
@@ -454,7 +445,6 @@ class TFDistilBertPreTrainedModel(TFPreTrainedModel):
 """
 config_class = DistilBertConfig
 pretrained_model_archive_map = TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
-load_pt_weights = load_distilbert_pt_weights_in_tf2
 base_model_prefix = "distilbert"
...
@@ -32,7 +32,6 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
 TFSequenceSummary, shape_list, get_initializer)
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
@@ -42,14 +41,6 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models
 "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5",}
-def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-# build the network
-inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-tf_inputs = tf.constant(inputs_list)
-tfo = tf_model(tf_inputs, training=False)
-return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 def gelu(x):
 """Gaussian Error Linear Unit.
 This is a smoother version of the RELU.
@@ -350,7 +341,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
 """
 config_class = GPT2Config
 pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
-load_pt_weights = load_gpt2_pt_weights_in_tf2
 base_model_prefix = "transformer"
...
@@ -32,21 +32,12 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
 TFSequenceSummary, shape_list, get_initializer)
 from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
 TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-tf_model.h5"}
-def load_openai_gpt_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-# build the network
-inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-tf_inputs = tf.constant(inputs_list)
-tfo = tf_model(tf_inputs, training=False)
-return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 def gelu(x):
 """Gaussian Error Linear Unit.
 This is a smoother version of the RELU.
@@ -335,7 +326,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
 """
 config_class = OpenAIGPTConfig
 pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
-load_pt_weights = load_openai_gpt_pt_weights_in_tf2
 base_model_prefix = "transformer"
...
@@ -25,8 +25,6 @@ import numpy
 logger = logging.getLogger(__name__)
-DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
 def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
 """ Convert a TF 2.0 model variable name in a pytorch model weight name.
@@ -105,7 +103,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
 raise e
 if tf_inputs is None:
-tf_inputs = tf.constant(DUMMY_INPUTS)
+tf_inputs = tf_model.dummy_inputs
 if tf_inputs is not None:
 tfo = tf_model(tf_inputs, training=False)  # Make sure model is built
...
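With `DUMMY_INPUTS` gone from this module, `load_pytorch_weights_in_tf2_model` builds the network from the model's own `dummy_inputs` when the caller passes none. A hedged sketch of the resulting conversion path, with placeholder file paths that are not taken from the diff:

```python
from transformers import BertConfig, TFBertForPreTraining, load_pytorch_checkpoint_in_tf2_model

config = BertConfig.from_json_file('path/to/bert_config.json')      # placeholder path
tf_model = TFBertForPreTraining(config)

# The loader builds the TF graph with tf_model.dummy_inputs before copying weights,
# so no hand-rolled per-model load_*_pt_weights_in_tf2 helper is needed any more.
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, 'path/to/pytorch_model.bin')  # placeholder path
```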
@@ -26,7 +26,6 @@ import tensorflow as tf
 from .configuration_roberta import RobertaConfig
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu, gelu_new
@@ -38,14 +37,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-tf_model.h5",
 }
-def load_roberta_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-# build the network
-inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-tf_inputs = tf.constant(inputs_list)
-tfo = tf_model(tf_inputs, training=False)
-return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 class TFRobertaEmbeddings(TFBertEmbeddings):
 """
 Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
@@ -83,7 +74,7 @@ class TFRobertaMainLayer(TFBertMainLayer):
 input_ids = inputs
 if tf.not_equal(tf.reduce_sum(input_ids[:, 0]), 0):
-logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
+tf.print("A sequence with no special tokens has been passed to the RoBERTa model. "
 "This model requires special tokens in order to work. "
 "Please specify add_special_tokens=True in your encoding.")
@@ -96,7 +87,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
 """
 config_class = RobertaConfig
 pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
-load_pt_weights = load_roberta_pt_weights_in_tf2
 base_model_prefix = "roberta"
...
@@ -33,7 +33,6 @@ from .configuration_transfo_xl import TransfoXLConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list, get_initializer
 from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
@@ -41,14 +40,6 @@ TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5",
 }
-def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-# build the network
-inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-tf_inputs = tf.constant(inputs_list)
-tfo = tf_model(tf_inputs, training=False)
-return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 class TFPositionalEmbedding(tf.keras.layers.Layer):
 def __init__(self, demb, **kwargs):
 super(TFPositionalEmbedding, self).__init__(**kwargs)
@@ -577,7 +568,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
 """
 config_class = TransfoXLConfig
 pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
-load_pt_weights = load_transfo_xl_pt_weights_in_tf2
 base_model_prefix = "transformer"
...
@@ -25,9 +25,11 @@ import tensorflow as tf
 from .configuration_utils import PretrainedConfig
 from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
+DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
 class TFPreTrainedModel(tf.keras.Model):
 r""" Base class for all TF models.
@@ -48,8 +50,8 @@ class TFPreTrainedModel(tf.keras.Model):
 """
 config_class = None
 pretrained_model_archive_map = {}
-load_pt_weights = lambda model, config, path: None
 base_model_prefix = ""
+dummy_inputs = tf.constant(DUMMY_INPUTS)  # dummy inputs to build the network
 def __init__(self, config, *inputs, **kwargs):
 super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
@@ -262,17 +264,16 @@ class TFPreTrainedModel(tf.keras.Model):
 if from_pt:
 # Load from a PyTorch checkpoint
-return cls.load_pt_weights(model, resolved_archive_file)
+return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
-inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-ret = model(inputs, training=False)  # build the network with dummy inputs
+ret = model(model.dummy_inputs, training=False)  # build the network with dummy inputs
 assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
 # 'by_name' allow us to do transfer learning by skipping/adding layers
 # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
 model.load_weights(resolved_archive_file, by_name=True)
-ret = model(inputs, training=False)  # Make sure restore ops are run
+ret = model(model.dummy_inputs, training=False)  # Make sure restore ops are run
 return model
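Since the generic loader now covers every architecture, the usual entry point for PyTorch checkpoints is simply `from_pretrained(..., from_pt=True)`; a hedged one-liner, with a placeholder directory assumed to contain `pytorch_model.bin` and `config.json`:

```python
from transformers import TFBertForSequenceClassification

# 'path/to/pytorch_checkpoint_dir' is a placeholder, not a path from the diff.
tf_model = TFBertForSequenceClassification.from_pretrained('path/to/pytorch_checkpoint_dir', from_pt=True)
```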
@@ -393,8 +394,8 @@ class TFSequenceSummary(tf.keras.layers.Layer):
 # We can probably just use the multi-head attention module of PyTorch >=1.1.0
 raise NotImplementedError
-self.summary = None
-if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
+self.has_summary = hasattr(config, 'summary_use_proj') and config.summary_use_proj
+if self.has_summary:
 if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
 num_classes = config.num_labels
 else:
@@ -403,16 +404,16 @@ class TFSequenceSummary(tf.keras.layers.Layer):
 kernel_initializer=get_initializer(initializer_range),
 name='summary')
-self.activation = None
-if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
+self.has_activation = hasattr(config, 'summary_activation') and config.summary_activation == 'tanh'
+if self.has_activation:
 self.activation = tf.keras.activations.tanh
-self.first_dropout = None
-if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
+self.has_first_dropout = hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0
+if self.has_first_dropout:
 self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)
-self.last_dropout = None
-if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
+self.has_last_dropout = hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0
+if self.has_last_dropout:
 self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
 def call(self, inputs, training=False):
@@ -455,17 +456,17 @@ class TFSequenceSummary(tf.keras.layers.Layer):
 elif self.summary_type == 'attn':
 raise NotImplementedError
-if training and self.first_dropout is not None:
-output = self.first_dropout(output)
+if self.has_first_dropout:
+output = self.first_dropout(output, training=training)
-if self.summary is not None:
+if self.has_summary:
 output = self.summary(output)
-if self.activation is not None:
+if self.has_activation:
 output = self.activation(output)
-if training and self.last_dropout is not None:
-output = self.last_dropout(output)
+if self.has_last_dropout:
+output = self.last_dropout(output, training=training)
 return output
...
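The `TFSequenceSummary` rewrite above follows the usual Keras pattern: rather than gating dropout on a Python-level `training` check, the layer is always called and the `training` flag is forwarded so Keras decides whether to drop units. A tiny illustrative sketch of that behaviour, independent of the diff:

```python
import tensorflow as tf

dropout = tf.keras.layers.Dropout(0.5)
x = tf.ones((2, 4))

print(dropout(x, training=True))   # some units zeroed, the rest rescaled
print(dropout(x, training=False))  # identity at inference time
```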
@@ -25,9 +25,8 @@ import numpy as np
 import tensorflow as tf
 from .configuration_xlm import XLMConfig
-from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer, DUMMY_INPUTS
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
@@ -45,19 +44,6 @@ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }
-def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-# build the network
-inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
-if tf_model.config.use_lang_emb and tf_model.config.n_langs > 1:
-langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
-else:
-langs_list = None
-tf_inputs = [inputs_list, attns_list, langs_list]
-tfo = tf_model(tf_inputs, training=False)
-return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 def create_sinusoidal_embeddings(n_pos, dim, out):
 position_enc = np.array([
 [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
@@ -441,9 +427,19 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
 """
 config_class = XLMConfig
 pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
-load_pt_weights = load_xlm_pt_weights_in_tf2
 base_model_prefix = "transformer"
+@property
+def dummy_inputs(self):
+# Sometimes XLM has language embeddings so don't forget to build them as well if needed
+inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
+attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+if self.config.use_lang_emb and self.config.n_langs > 1:
+langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+else:
+langs_list = None
+return [inputs_list, attns_list, langs_list]
 XLM_START_DOCSTRING = r""" The XLM model was proposed in
 `Cross-lingual Language Model Pretraining`_
...
@@ -30,7 +30,6 @@ import tensorflow as tf
 from .configuration_xlnet import XLNetConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
@@ -41,13 +40,6 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }
-def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-tf_inputs = tf.constant(inputs_list)
-tfo = tf_model(tf_inputs, training=False)  # build the network
-return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 def gelu(x):
 """ Implementation of the gelu activation function.
 XLNet is using OpenAI GPT's gelu
@@ -362,6 +354,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
 super(TFXLNetMainLayer, self).__init__(**kwargs)
 self.output_attentions = config.output_attentions
 self.output_hidden_states = config.output_hidden_states
+self.output_past = config.output_past
 self.mem_len = config.mem_len
 self.reuse_len = config.reuse_len
@@ -421,9 +414,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
 def cache_mem(self, curr_out, prev_mem):
 """cache hidden states into memory."""
-if self.mem_len is None or self.mem_len == 0:
-return None
-else:
 if self.reuse_len is not None and self.reuse_len > 0:
 curr_out = curr_out[:self.reuse_len]
@@ -546,7 +536,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
 raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
 # data mask: input mask & perm mask
-assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) "
+assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " \
 "or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one."
 if input_mask is None and attention_mask is not None:
 input_mask = 1.0 - attention_mask
@@ -632,6 +622,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
 hidden_states = []
 for i, layer_module in enumerate(self.layer):
 # cache new mems
+if self.mem_len is not None and self.mem_len > 0 and self.output_past:
 new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
 if self.output_hidden_states:
 hidden_states.append((output_h, output_g) if output_g is not None else output_h)
@@ -650,7 +641,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
 output = self.dropout(output_g if output_g is not None else output_h, training=training)
 # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
-outputs = (tf.transpose(output, perm=(1, 0, 2)), new_mems)
+outputs = (tf.transpose(output, perm=(1, 0, 2)),)
+if self.mem_len is not None and self.mem_len > 0 and self.output_past:
+outputs = outputs + (new_mems,)
 if self.output_hidden_states:
 if output_g is not None:
 hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
@@ -661,7 +656,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
 attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
 outputs = outputs + (attentions,)
-return outputs  # outputs, new_mems, (hidden_states), (attentions)
+return outputs  # outputs, (new_mems), (hidden_states), (attentions)
 class TFXLNetPreTrainedModel(TFPreTrainedModel):
@@ -670,7 +665,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
 """
 config_class = XLNetConfig
 pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
-load_pt_weights = load_xlnet_pt_weights_in_tf2
 base_model_prefix = "transformer"
@@ -777,7 +771,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
 Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
 **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
 Sequence of hidden-states at the last layer of the model.
-**mems**:
+**mems**: (`optional`, returned when ``config.mem_len > 0``)
 list of ``tf.Tensor`` (one for each layer):
 that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
 if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
@@ -819,7 +813,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
 Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
 **prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
 Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-**mems**:
+**mems**: (`optional`, returned when ``config.mem_len > 0``)
 list of ``tf.Tensor`` (one for each layer):
 that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
 if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
@@ -863,7 +857,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
 outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
-return outputs  # return logits, mems, (hidden states), (attentions)
+return outputs  # return logits, (mems), (hidden states), (attentions)
 @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
@@ -874,7 +868,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**logits**: ``tf.Tensor`` of shape ``(batch_size, config.num_labels)`` **logits**: ``tf.Tensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax). Classification (or regression if config.num_labels==1) scores (before SoftMax).
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer): list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -918,7 +912,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel): ...@@ -918,7 +912,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # return logits, mems, (hidden states), (attentions) return outputs # return logits, (mems), (hidden states), (attentions)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of # @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
...@@ -932,6 +926,11 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel): ...@@ -932,6 +926,11 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
Span-start scores (before SoftMax). Span-start scores (before SoftMax).
**end_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length,)`` **end_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax). Span-end scores (before SoftMax).
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings) list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
...@@ -971,7 +970,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel): ...@@ -971,7 +970,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
outputs = (start_logits, end_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it outputs = (start_logits, end_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # start_logits, end_logits, (hidden_states), (attentions) return outputs # start_logits, end_logits, (mems), (hidden_states), (attentions)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of # @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """, # the hidden-states output to compute `span start logits` and `span end logits`). """,
......
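The TF 2.0 hunks above change `TFXLNetMainLayer` so that `mems` is built and returned only when `config.mem_len > 0` and `output_past` is enabled, instead of always appearing as the second element of the output tuple. A minimal sketch of the resulting calling convention, not part of this diff and using an illustrative tiny configuration (`output_past` is assumed to default to True, as the gating above suggests):
```
# A minimal sketch (not part of this diff) of the new TF 2.0 return behavior.
# The tiny configuration and token ids are illustrative only.
import tensorflow as tf
from transformers import XLNetConfig, TFXLNetModel

config = XLNetConfig(mem_len=64, n_layer=2, d_model=128, n_head=4, d_inner=512)
model = TFXLNetModel(config)          # randomly initialized, fine for a shape check

input_ids = tf.constant([[35, 42, 7, 18]])
outputs = model(input_ids)

last_hidden_state = outputs[0]        # (batch_size, seq_len, hidden_size)
mems = outputs[1]                     # present only because config.mem_len > 0

# With mem_len left at its default (None), the tuple no longer carries a mems
# entry, so downstream code should index the outputs accordingly.
```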
...@@ -570,7 +570,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -570,7 +570,7 @@ class XLNetModel(XLNetPreTrainedModel):
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model. Sequence of hidden-states at the last layer of the model.
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer): list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -596,6 +596,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -596,6 +596,7 @@ class XLNetModel(XLNetPreTrainedModel):
super(XLNetModel, self).__init__(config) super(XLNetModel, self).__init__(config)
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_past = config.output_past
self.mem_len = config.mem_len self.mem_len = config.mem_len
self.reuse_len = config.reuse_len self.reuse_len = config.reuse_len
...@@ -652,9 +653,6 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -652,9 +653,6 @@ class XLNetModel(XLNetPreTrainedModel):
def cache_mem(self, curr_out, prev_mem): def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory.""" """cache hidden states into memory."""
if self.mem_len is None or self.mem_len == 0:
return None
else:
if self.reuse_len is not None and self.reuse_len > 0: if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len] curr_out = curr_out[:self.reuse_len]
...@@ -832,6 +830,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -832,6 +830,7 @@ class XLNetModel(XLNetPreTrainedModel):
attentions = [] attentions = []
hidden_states = [] hidden_states = []
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
# cache new mems # cache new mems
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states: if self.output_hidden_states:
...@@ -851,7 +850,11 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -851,7 +850,11 @@ class XLNetModel(XLNetPreTrainedModel):
output = self.dropout(output_g if output_g is not None else output_h) output = self.dropout(output_g if output_g is not None else output_h)
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs = (output.permute(1, 0, 2).contiguous(), new_mems) outputs = (output.permute(1, 0, 2).contiguous(),)
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
outputs = outputs + (new_mems,)
if self.output_hidden_states: if self.output_hidden_states:
if output_g is not None: if output_g is not None:
hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs) hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
...@@ -862,7 +865,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -862,7 +865,7 @@ class XLNetModel(XLNetPreTrainedModel):
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions) attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
return outputs # outputs, new_mems, (hidden_states), (attentions) return outputs # outputs, (new_mems), (hidden_states), (attentions)
@add_start_docstrings("""XLNet Model with a language modeling head on top @add_start_docstrings("""XLNet Model with a language modeling head on top
...@@ -882,7 +885,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -882,7 +885,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
Language modeling loss. Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer): list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -947,7 +950,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -947,7 +950,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
labels.view(-1)) labels.view(-1))
outputs = (loss,) + outputs outputs = (loss,) + outputs
return outputs # return (loss), logits, mems, (hidden states), (attentions) return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of @add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
...@@ -966,7 +969,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -966,7 +969,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Classification (or regression if config.num_labels==1) loss. Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax). Classification (or regression if config.num_labels==1) scores (before SoftMax).
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer): list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -1026,7 +1029,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -1026,7 +1029,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs outputs = (loss,) + outputs
return outputs # return (loss), logits, mems, (hidden states), (attentions) return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of @add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RACE/SWAG tasks. """, the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
...@@ -1061,6 +1064,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): ...@@ -1061,6 +1064,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above). of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax). Classification scores (before SoftMax).
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
...@@ -1117,7 +1125,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): ...@@ -1117,7 +1125,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
loss = loss_fct(reshaped_logits, labels.view(-1)) loss = loss_fct(reshaped_logits, labels.view(-1))
outputs = (loss,) + outputs outputs = (loss,) + outputs
return outputs # return (loss), logits, mems, (hidden states), (attentions) return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
...@@ -1141,7 +1149,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): ...@@ -1141,7 +1149,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
Span-start scores (before SoftMax). Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax). Span-end scores (before SoftMax).
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer): list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
...@@ -1212,7 +1220,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): ...@@ -1212,7 +1220,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
total_loss = (start_loss + end_loss) / 2 total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) return outputs # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions)
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
...@@ -1254,7 +1262,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): ...@@ -1254,7 +1262,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
**cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
``torch.FloatTensor`` of shape ``(batch_size,)`` ``torch.FloatTensor`` of shape ``(batch_size,)``
Log probabilities for the ``is_impossible`` label of the answers. Log probabilities for the ``is_impossible`` label of the answers.
**mems**: **mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer): list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context. if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
......
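The PyTorch hunks mirror the same change: `XLNetModel` now reads `config.output_past` and only assembles and returns `new_mems` when memory caching is active. A hedged sketch of feeding the returned `mems` back on the next forward pass, as the updated docstrings describe (the tiny configuration and token ids are illustrative, not from this PR):
```
# A minimal sketch (not part of this diff) showing the returned mems being fed
# back in for the next segment; configuration and token ids are illustrative.
import torch
from transformers import XLNetConfig, XLNetModel

config = XLNetConfig(mem_len=64, n_layer=2, d_model=128, n_head=4, d_inner=512)
model = XLNetModel(config)
model.eval()

first_chunk = torch.tensor([[35, 42, 7, 18]])
second_chunk = torch.tensor([[11, 23, 5, 9]])

with torch.no_grad():
    hidden, mems = model(first_chunk)            # mems returned since mem_len > 0
    # The cached mems let the second chunk attend to the first one.
    hidden2, mems2 = model(second_chunk, mems=mems)

# If mem_len were None/0 (or output_past were disabled), the forward pass would
# return only the hidden states and the tuple unpacking above would change.
```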
...@@ -17,8 +17,10 @@ from __future__ import division ...@@ -17,8 +17,10 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import copy import copy
import sys
import os import os
import shutil import shutil
import tempfile
import json import json
import random import random
import uuid import uuid
...@@ -31,6 +33,7 @@ from transformers import is_torch_available ...@@ -31,6 +33,7 @@ from transformers import is_torch_available
if is_torch_available(): if is_torch_available():
import torch import torch
import numpy as np
from transformers import (PretrainedConfig, PreTrainedModel, from transformers import (PretrainedConfig, PreTrainedModel,
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
...@@ -38,6 +41,20 @@ if is_torch_available(): ...@@ -38,6 +41,20 @@ if is_torch_available():
else: else:
pytestmark = pytest.mark.skip("Require Torch") pytestmark = pytest.mark.skip("Require Torch")
if sys.version_info[0] == 2:
import cPickle as pickle
class TemporaryDirectory(object):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name
def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self.name)
else:
import pickle
TemporaryDirectory = tempfile.TemporaryDirectory
unicode = str
def _config_zero_init(config): def _config_zero_init(config):
configs_no_init = copy.deepcopy(config) configs_no_init = copy.deepcopy(config)
...@@ -57,6 +74,29 @@ class CommonTestCases: ...@@ -57,6 +74,29 @@ class CommonTestCases:
test_resize_embeddings = True test_resize_embeddings = True
test_head_masking = True test_head_masking = True
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.eval()
with torch.no_grad():
outputs = model(**inputs_dict)
with TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
with torch.no_grad():
after_outputs = model(**inputs_dict)
# Make sure we don't have nans
out_1 = after_outputs[0].numpy()
out_2 = outputs[0].numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_initialization(self): def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
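The new `test_save_load` common test saves each model with `save_pretrained`, reloads it with `from_pretrained`, and checks that the outputs agree to within 1e-5. A standalone sketch of the same round trip on a small, randomly initialized BERT (not part of this diff; Python 3 assumed, sizes and token ids illustrative):
```
# A standalone sketch (not part of this diff) of the round trip exercised by the
# new test_save_load common test.
import tempfile
import numpy as np
import torch
from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64)
model = BertModel(config)
model.eval()

input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])

with torch.no_grad():
    before = model(input_ids)[0]

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)             # writes config.json + weights
    reloaded = BertModel.from_pretrained(tmpdirname)

with torch.no_grad():
    after = reloaded(input_ids)[0]

# The reloaded model should reproduce the original outputs up to numerical noise.
assert np.amax(np.abs(before.numpy() - after.numpy())) <= 1e-5
```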