Commit da26bae6 authored by thomwolf

adding more tests on TF and PyTorch serialization - updating configuration for better serialization

parent bb04edb4
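The change in a nutshell: every TF 2.0 model class used to ship its own load_*_pt_weights_in_tf2 helper and a load_pt_weights class attribute; both are removed in favor of the shared load_pytorch_checkpoint_in_tf2_model utility plus a dummy_inputs attribute on TFPreTrainedModel that builds the network before weights are restored. A minimal sketch of the resulting flow, assuming local bert_config.json and pytorch_model.bin files (placeholder paths, not part of this commit):

from transformers import BertConfig, TFBertForPreTraining
from transformers import load_pytorch_checkpoint_in_tf2_model

config = BertConfig.from_json_file("bert_config.json")   # placeholder path
tf_model = TFBertForPreTraining(config)

# The shared loader builds the network from tf_model.dummy_inputs when no
# explicit tf_inputs are passed, then copies the PyTorch weights over.
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, "pytorch_model.bin")
tf_model.save_weights("tf_model.h5")                     # standard Keras save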
@@ -110,59 +110,49 @@ if is_tf_available():
                                  TFBertForMaskedLM, TFBertForNextSentencePrediction,
                                  TFBertForSequenceClassification, TFBertForMultipleChoice,
                                  TFBertForTokenClassification, TFBertForQuestionAnswering,
-                                 load_bert_pt_weights_in_tf2,
                                  TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
                                TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
-                               load_gpt2_pt_weights_in_tf2,
                                TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer,
                                  TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel,
-                                 load_openai_gpt_pt_weights_in_tf2,
                                  TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer,
                                      TFTransfoXLModel, TFTransfoXLLMHeadModel,
-                                     load_transfo_xl_pt_weights_in_tf2,
                                      TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
                                 TFXLNetModel, TFXLNetLMHeadModel,
                                 TFXLNetForSequenceClassification,
                                 TFXLNetForQuestionAnsweringSimple,
-                                load_xlnet_pt_weights_in_tf2,
                                 TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_xlm import (TFXLMPreTrainedModel, TFXLMMainLayer,
                               TFXLMModel, TFXLMWithLMHeadModel,
                               TFXLMForSequenceClassification,
                               TFXLMForQuestionAnsweringSimple,
-                              load_xlm_pt_weights_in_tf2,
                               TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer,
                                   TFRobertaModel, TFRobertaForMaskedLM,
                                   TFRobertaForSequenceClassification,
-                                  load_roberta_pt_weights_in_tf2,
                                   TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
                                      TFDistilBertModel, TFDistilBertForMaskedLM,
                                      TFDistilBertForSequenceClassification,
                                      TFDistilBertForQuestionAnswering,
-                                     load_distilbert_pt_weights_in_tf2,
                                      TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
                                TFCTRLLMHeadModel,
-                               load_ctrl_pt_weights_in_tf2,
                                TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)

 # TF 2.0 <=> PyTorch conversion utilities
-if is_tf_available() and is_torch_available():
-    from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
+from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
                                         load_pytorch_checkpoint_in_tf2_model,
                                         load_pytorch_weights_in_tf2_model,
                                         load_pytorch_model_in_tf2_model,
...
@@ -153,7 +153,7 @@ class PretrainedConfig(object):
         config = cls.from_json_file(resolved_config_file)

         if hasattr(config, 'pruned_heads'):
-            config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
+            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

         # Update config with kwargs if needed
         to_remove = []
@@ -164,7 +164,7 @@ class PretrainedConfig(object):
         for key in to_remove:
             kwargs.pop(key, None)

-        logger.info("Model config %s", config)
+        logger.info("Model config %s", str(config))
         if return_unused_kwargs:
             return config, kwargs
         else:
...
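The pruned_heads tweak is what "better serialization" refers to: JSON has no set type, so keeping the values as plain lists lets a pruned config survive a save/load round trip, and only the keys need coercing back to int after json.load. A small illustration (values invented):

import json

raw = json.loads('{"0": [0, 2], "5": [1]}')  # as read back from config.json
pruned_heads = dict((int(key), value) for key, value in raw.items())
assert pruned_heads == {0: [0, 2], 5: [1]}

# The old set(value) conversion broke the reverse direction:
# json.dumps({0: {0, 2}}) raises TypeError, since sets are not JSON-serializable.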
@@ -24,15 +24,16 @@ import tensorflow as tf

 from transformers import is_torch_available, cached_path

-from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
+from transformers import (load_pytorch_checkpoint_in_tf2_model,
+                          BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)

 if is_torch_available():
     import torch
@@ -71,27 +72,27 @@ import logging

 logging.basicConfig(level=logging.INFO)

 MODEL_CLASSES = {
-    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'ctrl': (CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
+    'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
 }

 def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
     if model_type not in MODEL_CLASSES:
         raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))

-    config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
+    config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

     # Initialise TF model
     if config_file in aws_config_map:
@@ -105,7 +106,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     # Load weights from tf checkpoint
     if pytorch_checkpoint_path in aws_model_maps:
         pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
-    tf_model = loading_fct(tf_model, pytorch_checkpoint_path)
+    # Load PyTorch checkpoint in tf2 model:
+    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

     if compare_with_pt_model:
         inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
@@ -147,7 +149,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
     if model_type not in MODEL_CLASSES:
         raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))

-    config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
+    config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

     if model_shortcut_names_or_path is None:
         model_shortcut_names_or_path = list(aws_model_maps.keys())
...
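With the loading function dropped from every MODEL_CLASSES entry, the tuples unpack uniformly and the conversion body no longer branches per model. A hedged sketch of driving the converter defined above (shortcut names come from the archive maps; the call signature is the one shown in this file):

for model_type in ('bert', 'gpt2', 'ctrl'):
    config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
    shortcut = list(aws_model_maps.keys())[0]
    convert_pt_checkpoint_to_tf(model_type=model_type,
                                pytorch_checkpoint_path=shortcut,
                                config_file=shortcut,
                                tf_dump_path='./{}-tf_model.h5'.format(shortcut),
                                compare_with_pt_model=True)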
@@ -30,7 +30,6 @@ import tensorflow as tf
 from .configuration_bert import BertConfig
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -52,14 +51,6 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }

-def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def gelu(x):
     """ Gaussian Error Linear Unit.
     Original Implementation of the gelu activation function in Google Bert repo when initially created.
@@ -545,7 +536,6 @@ class TFBertPreTrainedModel(TFPreTrainedModel):
     """
     config_class = BertConfig
     pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_bert_pt_weights_in_tf2
     base_model_prefix = "bert"
...
@@ -27,20 +27,11 @@ import tensorflow as tf
 from .configuration_ctrl import CTRLConfig
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)

 TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"}

-def load_ctrl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def angle_defn(pos, i, d_model_size):
     angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model_size))
     return pos * angle_rates
@@ -327,7 +318,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
     config_class = CTRLConfig
     pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"
-    load_pt_weights = load_ctrl_pt_weights_in_tf2

 CTRL_START_DOCSTRING = r""" CTRL model was proposed in
...
@@ -31,7 +31,6 @@ import tensorflow as tf
 from .configuration_distilbert import DistilBertConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -66,14 +65,6 @@ def gelu_new(x):
                        (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
     return x * cdf

-def load_distilbert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-    attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
-    tf_inputs = [inputs_list, attns_list]
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 class TFEmbeddings(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFEmbeddings, self).__init__(**kwargs)
@@ -454,7 +445,6 @@ class TFDistilBertPreTrainedModel(TFPreTrainedModel):
     """
     config_class = DistilBertConfig
     pretrained_model_archive_map = TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_distilbert_pt_weights_in_tf2
     base_model_prefix = "distilbert"
...
@@ -32,7 +32,6 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
                                 TFSequenceSummary, shape_list, get_initializer)
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -42,14 +41,6 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models
                                         "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5",}

-def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def gelu(x):
     """Gaussian Error Linear Unit.
     This is a smoother version of the RELU.
@@ -350,7 +341,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
     """
     config_class = GPT2Config
     pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_gpt2_pt_weights_in_tf2
     base_model_prefix = "transformer"
...
@@ -32,21 +32,12 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
                                 TFSequenceSummary, shape_list, get_initializer)
 from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)

 TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-tf_model.h5"}

-def load_openai_gpt_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def gelu(x):
     """Gaussian Error Linear Unit.
     This is a smoother version of the RELU.
@@ -335,7 +326,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
     """
     config_class = OpenAIGPTConfig
     pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_openai_gpt_pt_weights_in_tf2
     base_model_prefix = "transformer"
...
@@ -25,8 +25,6 @@ import numpy
 logger = logging.getLogger(__name__)

-DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]

 def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
     """ Convert a TF 2.0 model variable name in a pytorch model weight name.
@@ -105,7 +103,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
         raise e

     if tf_inputs is None:
-        tf_inputs = tf.constant(DUMMY_INPUTS)
+        tf_inputs = tf_model.dummy_inputs

     if tf_inputs is not None:
         tfo = tf_model(tf_inputs, training=False)  # Make sure model is built
...
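Since load_pytorch_checkpoint_in_tf2_model still accepts an explicit tf_inputs (as the removed per-model helpers demonstrate), a caller can override the dummy build pass, for example to build the graph with a custom shape; tf_model.dummy_inputs is only the fallback. A sketch, assuming a tf_model and checkpoint path as in the earlier example:

import tensorflow as tf

# Build with a custom batch/sequence shape instead of the default dummies.
custom_inputs = tf.constant([[3, 1, 4, 1, 5, 9, 2, 6]])
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, "pytorch_model.bin",
                                                tf_inputs=custom_inputs)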
@@ -26,7 +26,6 @@ import tensorflow as tf
 from .configuration_roberta import RobertaConfig
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu, gelu_new
@@ -38,14 +37,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-tf_model.h5",
 }

-def load_roberta_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 class TFRobertaEmbeddings(TFBertEmbeddings):
     """
     Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
@@ -96,7 +87,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
     """
     config_class = RobertaConfig
     pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_roberta_pt_weights_in_tf2
     base_model_prefix = "roberta"
...
@@ -33,7 +33,6 @@ from .configuration_transfo_xl import TransfoXLConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list, get_initializer
 from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -41,14 +40,6 @@ TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5",
 }

-def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 class TFPositionalEmbedding(tf.keras.layers.Layer):
     def __init__(self, demb, **kwargs):
         super(TFPositionalEmbedding, self).__init__(**kwargs)
@@ -577,7 +568,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
     """
     config_class = TransfoXLConfig
     pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_transfo_xl_pt_weights_in_tf2
     base_model_prefix = "transformer"
...
@@ -25,9 +25,11 @@ import tensorflow as tf
 from .configuration_utils import PretrainedConfig
 from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)

+DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]

 class TFPreTrainedModel(tf.keras.Model):
     r""" Base class for all TF models.
@@ -48,8 +50,8 @@ class TFPreTrainedModel(tf.keras.Model):
     """
     config_class = None
     pretrained_model_archive_map = {}
-    load_pt_weights = lambda model, config, path: None
     base_model_prefix = ""
+    dummy_inputs = tf.constant(DUMMY_INPUTS)  # dummy inputs to build the network

     def __init__(self, config, *inputs, **kwargs):
         super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
@@ -262,17 +264,16 @@ class TFPreTrainedModel(tf.keras.Model):
         if from_pt:
             # Load from a PyTorch checkpoint
-            return cls.load_pt_weights(model, resolved_archive_file)
+            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)

-        inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-        ret = model(inputs, training=False)  # build the network with dummy inputs
+        ret = model(model.dummy_inputs, training=False)  # build the network with dummy inputs

         assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
         # 'by_name' allow us to do transfer learning by skipping/adding layers
         # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
         model.load_weights(resolved_archive_file, by_name=True)

-        ret = model(inputs, training=False)  # Make sure restore ops are run
+        ret = model(model.dummy_inputs, training=False)  # Make sure restore ops are run

         return model
...
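End to end, the two from_pretrained entry points now look like this (model shortcut illustrative; from_pt=True is the flag handled in the hunk above):

from transformers import TFBertModel

# TF 2.0 checkpoint: dummy build pass, then load_weights(..., by_name=True).
tf_model = TFBertModel.from_pretrained('bert-base-uncased')

# PyTorch checkpoint: routed through load_pytorch_checkpoint_in_tf2_model.
tf_model = TFBertModel.from_pretrained('bert-base-uncased', from_pt=True)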
@@ -25,9 +25,8 @@ import numpy as np
 import tensorflow as tf

 from .configuration_xlm import XLMConfig
-from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer, DUMMY_INPUTS
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -45,19 +44,6 @@ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }

-def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-    attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
-    if tf_model.config.use_lang_emb and tf_model.config.n_langs > 1:
-        langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
-    else:
-        langs_list = None
-    tf_inputs = [inputs_list, attns_list, langs_list]
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def create_sinusoidal_embeddings(n_pos, dim, out):
     position_enc = np.array([
         [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
@@ -441,9 +427,19 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
     """
     config_class = XLMConfig
     pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_xlm_pt_weights_in_tf2
     base_model_prefix = "transformer"

+    @property
+    def dummy_inputs(self):
+        # Sometimes XLM has language embeddings so don't forget to build them as well if needed
+        inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
+        attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+        if self.config.use_lang_emb and self.config.n_langs > 1:
+            langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+        else:
+            langs_list = None
+        return [inputs_list, attns_list, langs_list]

 XLM_START_DOCSTRING = r""" The XLM model was proposed in
     `Cross-lingual Language Model Pretraining`_
...
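XLM is the one model whose build inputs depend on its config (language embeddings exist only when use_lang_emb is set and n_langs > 1), hence a dummy_inputs property rather than the constant tensor on the base class. Any model with config-dependent inputs can follow the same pattern; a hypothetical subclass for illustration (the class and its needs are invented, the import follows the in-repo style):

import tensorflow as tf
from .modeling_tf_utils import TFPreTrainedModel  # in-repo relative import

class TFMyPreTrainedModel(TFPreTrainedModel):
    # Hypothetical model that also needs an attention mask at build time;
    # a property lets the dummy inputs depend on the instance's config.
    @property
    def dummy_inputs(self):
        input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0]])
        attention_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0]])
        return [input_ids, attention_mask]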
@@ -30,7 +30,6 @@ import tensorflow as tf
 from .configuration_xlnet import XLNetConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -41,13 +40,6 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }

-def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)  # build the network
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def gelu(x):
     """ Implementation of the gelu activation function.
     XLNet is using OpenAI GPT's gelu
@@ -670,7 +662,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
     """
     config_class = XLNetConfig
     pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_xlnet_pt_weights_in_tf2
     base_model_prefix = "transformer"
...
@@ -17,8 +17,10 @@ from __future__ import division
 from __future__ import print_function

 import copy
+import sys
 import os
 import shutil
+import tempfile
 import json
 import random
 import uuid
@@ -31,6 +33,7 @@ from transformers import is_torch_available

 if is_torch_available():
     import torch
+    import numpy as np

     from transformers import (PretrainedConfig, PreTrainedModel,
                               BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
@@ -38,6 +41,20 @@ if is_torch_available():
 else:
     pytestmark = pytest.mark.skip("Require Torch")

+if sys.version_info[0] == 2:
+    import cPickle as pickle
+
+    class TemporaryDirectory(object):
+        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
+        def __enter__(self):
+            self.name = tempfile.mkdtemp()
+            return self.name
+        def __exit__(self, exc_type, exc_value, traceback):
+            shutil.rmtree(self.name)
+else:
+    import pickle
+    TemporaryDirectory = tempfile.TemporaryDirectory
+    unicode = str

 def _config_zero_init(config):
     configs_no_init = copy.deepcopy(config)
@@ -57,6 +74,23 @@ class CommonTestCases:
     test_resize_embeddings = True
     test_head_masking = True

+    def test_save_load(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**inputs_dict)
+
+            with TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model = model_class.from_pretrained(tmpdirname)
+                with torch.no_grad():
+                    after_outputs = model(**inputs_dict)
+
+                max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
+                self.assertLessEqual(max_diff, 1e-5)

     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
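The new test_save_load is the PyTorch-side round trip promised in the commit message: save with save_pretrained, reload with from_pretrained, and require the first output tensor to agree within 1e-5. The same check as a standalone sketch (model shortcut illustrative; on Python 3, tempfile provides TemporaryDirectory directly, which is what the shim above aliases):

import numpy as np
import torch
from tempfile import TemporaryDirectory
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
model.eval()
input_ids = torch.tensor([[7, 6, 0, 0, 1]])
with torch.no_grad():
    before = model(input_ids)[0]

with TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    model = BertModel.from_pretrained(tmpdirname)

with torch.no_grad():
    after = model(input_ids)[0]
assert np.amax(np.abs(after.numpy() - before.numpy())) <= 1e-5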