"vscode:/vscode.git/clone" did not exist on "3854bd0fa02bf9afb5249d9eaaa6c827442105e0"
Commit 646711e1 authored by thomwolf

standardize scopes names - add conversion methods

parent 4356f791
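The central change is a single naming convention for TF 2.0 variable scopes so they can be mapped mechanically onto PyTorch attribute names. Below is a minimal sketch of that mapping (variable names are assumed for illustration; the substitutions mirror the new load_pytorch_state_dict_in_tf2_model added later in this commit):

import re

def tf_scope_to_pt_name(tf_variable_name):
    name = tf_variable_name.replace(':0', '')            # drop the device id suffix
    name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', name)    # '$1___$2' -> '$2', e.g. 'nsp___cls' -> 'cls'
    name = name.replace('_._', '/')                       # 'layer_._0' -> 'layer/0' (nn.ModuleList index)
    name = re.sub(r'//+', '/', name)
    parts = name.split('/')[1:]                           # drop the level-zero model scope
    if parts[-1] in ('kernel', 'embeddings', 'gamma'):
        parts[-1] = 'weight'                              # a 'kernel' additionally implies a transpose
    elif parts[-1] == 'beta':
        parts[-1] = 'bias'
    return '.'.join(parts)

# Hypothetical BERT variable name, for illustration only:
# tf_scope_to_pt_name('tf_bert_for_pre_training/bert/encoder/layer_._0/attention/self/query/kernel:0')
#     -> 'bert.encoder.layer.0.attention.self.query.weight'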
...@@ -25,12 +25,13 @@ from pytorch_transformers import is_torch_available
from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
                                  GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2,
-                                 XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2)
+                                 XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2,
+                                 XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2,)

if is_torch_available():
    import torch
    import numpy as np
-   from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel, XLNetLMHeadModel
+   from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel, XLNetLMHeadModel, XLMWithLMHeadModel
else:
    BertForPreTraining, GPT2LMHeadModel = None, None
...@@ -42,6 +43,7 @@ MODEL_CLASSES = {
    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining),
    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel),
    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel),
+   'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel),
}
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):

...@@ -58,7 +60,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
    tf_model = model_class(config)

    # Load weights from tf checkpoint
-   tf_model = loading_fct(tf_model, config, pytorch_checkpoint_path)
+   tf_model = loading_fct(tf_model, pytorch_checkpoint_path)

    if compare_with_pt_model:
        inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]

...
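With the new 'xlm' entry registered in MODEL_CLASSES, the existing entry point can convert an XLM checkpoint as well. A hypothetical call (all file paths are placeholders, not taken from the commit):

convert_pt_checkpoint_to_tf(model_type='xlm',
                            pytorch_checkpoint_path='./xlm_pytorch_model.bin',
                            config_file='./xlm_config.json',
                            tf_dump_path='./xlm_tf_model.h5',
                            compare_with_pt_model=True)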
...@@ -33,7 +33,15 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
    # Load checkpoint
    chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')

-   model = chkpt['model']
+   state_dict = chkpt['model']
+   # We have the base model one level deeper than the original XLM repository
+   two_levels_state_dict = {}
+   for k, v in state_dict.items():
+       if 'pred_layer' in k:
+           two_levels_state_dict[k] = v
+       else:
+           two_levels_state_dict['transformer.' + k] = v

    config = chkpt['params']
    config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))

...@@ -47,7 +55,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']

    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
-   torch.save(model, pytorch_weights_dump_path)
+   torch.save(two_levels_state_dict, pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:

...
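The loop above nests everything except the prediction head one level deeper, matching the XLMWithLMHeadModel layout. A toy illustration of the resulting keys (hypothetical entries, not real checkpoint keys):

state_dict = {
    'position_embeddings.weight': '...',
    'attentions.0.q_lin.weight': '...',
    'pred_layer.proj.weight': '...',
}
two_levels_state_dict = {}
for k, v in state_dict.items():
    if 'pred_layer' in k:
        two_levels_state_dict[k] = v                   # head keys stay at the top level
    else:
        two_levels_state_dict['transformer.' + k] = v  # base model moves under 'transformer.'
# -> {'transformer.position_embeddings.weight': '...',
#     'transformer.attentions.0.q_lin.weight': '...',
#     'pred_layer.proj.weight': '...'}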
...@@ -30,6 +30,7 @@ import tensorflow as tf
from .configuration_bert import BertConfig
from .modeling_tf_utils import TFPreTrainedModel
from .file_utils import add_start_docstrings
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

logger = logging.getLogger(__name__)
...@@ -51,71 +52,12 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
}

-def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
-    """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
-        We use HDF5 to easily do transfer learning
-        (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
-    """
-    try:
-        import re
-        import torch
-        import numpy
-        from tensorflow.python.keras import backend as K
-    except ImportError:
-        logger.error("Loading a PyTorch model in TensorFlow, requires PyTorch to be installed. Please see "
-                     "https://pytorch.org/ for installation instructions.")
-        raise
-
-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info("Loading PyTorch weights from {}".format(pt_path))
-    # Load pytorch model
-    state_dict = torch.load(pt_path, map_location='cpu')
-
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)  # build the network
-
-    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
-    weight_value_tuples = []
-    all_pytorch_weights = set(list(state_dict.keys()))
-    for symbolic_weight in symbolic_weights:
-        name = symbolic_weight.name
-        name = name.replace('cls_mlm', 'cls')  # We had to split this layer in two in the TF model to be
-        name = name.replace('cls_nsp', 'cls')  # able to do transfer learning (Keras only allow to remove full layers)
-        name = name.replace(':0', '')
-        name = name.replace('__', '/')
-        name = name.split('/')
-        name = name[1:]
-        transpose = bool(name[-1] == 'kernel')
-        if name[-1] == 'kernel' or name[-1] == 'embeddings':
-            name[-1] = 'weight'
-        name = '.'.join(name)
-        assert name in state_dict, "{} not found in PyTorch model".format(name)
-        array = state_dict[name].numpy()
-        if transpose:
-            array = numpy.transpose(array)
-        try:
-            assert list(symbolic_weight.shape) == list(array.shape)
-        except AssertionError as e:
-            e.args += (symbolic_weight.shape, array.shape)
-            raise e
-        logger.info("Initialize TF weight {}".format(symbolic_weight.name))
-        weight_value_tuples.append((symbolic_weight, array))
-        all_pytorch_weights.discard(name)
-    K.batch_set_value(weight_value_tuples)
-    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
-    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
-    return tf_model
+def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
+    # build the network
+    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
+    tf_inputs = tf.constant(inputs_list)
+    tfo = tf_model(tf_inputs, training=False)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

def gelu(x):
...@@ -391,7 +333,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
        super(TFBertEncoder, self).__init__(**kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
-       self.layer = [TFBertLayer(config, name='layer__{}'.format(i)) for i in range(config.num_hidden_layers)]
+       self.layer = [TFBertLayer(config, name='layer_._{}'.format(i)) for i in range(config.num_hidden_layers)]

    def call(self, inputs, training=False):
        hidden_states, attention_mask, head_mask = inputs
...@@ -730,15 +672,15 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
        super(TFBertForPreTraining, self).__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name='bert')
-       self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
-       self.cls_mlm = TFBertMLMHead(config, self.bert.embeddings, name='cls_mlm')
+       self.nsp = TFBertNSPHead(config, name='nsp___cls')
+       self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

    def call(self, inputs, training=False):
        outputs = self.bert(inputs, training=training)

        sequence_output, pooled_output = outputs[:2]
-       prediction_scores = self.cls_mlm(sequence_output, training=training)
-       seq_relationship_score = self.cls_nsp(pooled_output)
+       prediction_scores = self.mlm(sequence_output, training=training)
+       seq_relationship_score = self.nsp(pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
...@@ -773,13 +715,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
        super(TFBertForMaskedLM, self).__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name='bert')
-       self.cls_mlm = TFBertMLMHead(config, self.bert.embeddings, name='cls_mlm')
+       self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

    def call(self, inputs, training=False):
        outputs = self.bert(inputs, training=training)

        sequence_output = outputs[0]
-       prediction_scores = self.cls_mlm(sequence_output, training=training)
+       prediction_scores = self.mlm(sequence_output, training=training)

        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
...@@ -816,13 +758,13 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
        super(TFBertForNextSentencePrediction, self).__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name='bert')
-       self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
+       self.nsp = TFBertNSPHead(config, name='nsp___cls')

    def call(self, inputs, training=False):
        outputs = self.bert(inputs, training=training)

        pooled_output = outputs[1]
-       seq_relationship_score = self.cls_nsp(pooled_output)
+       seq_relationship_score = self.nsp(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

...
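The '___' marker is what lets the single PyTorch 'cls' head be split into two Keras layers (Keras can only drop whole layers for transfer learning, as the removed comments noted) while both still resolve to the same PyTorch module. A quick sketch of the collapse (variable names assumed for illustration):

import re

def collapse(name):
    # same '$1___$2' -> '$2' and '_._' -> level rules as the new utility module
    name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', name.replace(':0', '').replace('_._', '/'))
    return '.'.join(name.split('/')[1:])

assert collapse('tf_bert_for_pre_training/nsp___cls/seq_relationship/bias:0') == 'cls.seq_relationship.bias'
assert collapse('tf_bert_for_pre_training/mlm___cls/predictions/bias:0') == 'cls.predictions.bias'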
...@@ -32,6 +32,7 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
                                TFSequenceSummary, shape_list)
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

logger = logging.getLogger(__name__)
...@@ -40,77 +41,12 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models
                                        "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5"}

-def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
-    """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
-        We use HDF5 to easily do transfer learning
-        (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
-    """
-    try:
-        import re
-        import torch
-        import numpy
-        from tensorflow.python.keras import backend as K
-    except ImportError:
-        logger.error("Loading a PyTorch model in TensorFlow, requires PyTorch to be installed. Please see "
-                     "https://pytorch.org/ for installation instructions.")
-        raise
-
-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info("Loading PyTorch weights from {}".format(pt_path))
-    # Load pytorch model
-    state_dict = torch.load(pt_path, map_location='cpu')
-
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)  # build the network
-
-    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
-    weight_value_tuples = []
-    all_pytorch_weights = set(list(state_dict.keys()))
-    for symbolic_weight in symbolic_weights:
-        name = symbolic_weight.name
-        name = name.replace(':0', '')
-        name = name.replace('__', '/')
-        name = name.split('/')
-        name = name[2:]
-        transpose = bool(name[-1] == 'kernel')
-        if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
-            name[-1] = 'weight'
-        if name[-1] == 'beta':
-            name[-1] = 'bias'
-        name = '.'.join(name)
-        assert name in state_dict, "Weight {} not in PyTorch model".format(name)
-        array = state_dict[name].numpy()
-        if transpose:
-            array = numpy.transpose(array)
-        if len(symbolic_weight.shape) > len(array.shape):
-            array = array[None, ...]
-        if len(symbolic_weight.shape) < len(array.shape):
-            array = np.squeeze(array)
-        try:
-            assert list(symbolic_weight.shape) == list(array.shape)
-        except AssertionError as e:
-            e.args += (symbolic_weight.shape, array.shape)
-            raise e
-        logger.info("Initialize TF weight {}".format(symbolic_weight.name))
-        weight_value_tuples.append((symbolic_weight, array))
-        all_pytorch_weights.discard(name)
-    K.batch_set_value(weight_value_tuples)
-    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
-    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
-    return tf_model
+def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
+    # build the network
+    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
+    tf_inputs = tf.constant(inputs_list)
+    tfo = tf_model(tf_inputs, training=False)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

def gelu(x):
...@@ -282,7 +218,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
        self.h = [TFBlock(config.n_ctx,
                          config,
                          scale=True,
-                         name='h__{}'.format(i)) for i in range(config.n_layer)]
+                         name='h_._{}'.format(i)) for i in range(config.n_layer)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')

    def _resize_token_embeddings(self, new_num_tokens):

...
new file: modeling_tf_pytorch_utils.py
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - TF 2.0 general utilities."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import logging
import os

from pytorch_transformers import is_tf_available, is_torch_available

logger = logging.getLogger(__name__)
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path):
    """ Load pytorch checkpoints in a TF 2.0 model
        Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
            - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
            - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
    """
    if not is_tf_available() or not is_torch_available():
        logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
                     "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
        raise ImportError
    import torch

    pt_path = os.path.abspath(pytorch_checkpoint_path)
    logger.info("Loading PyTorch weights from {}".format(pt_path))
    pt_state_dict = torch.load(pt_path, map_location='cpu')

    return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict)
def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict):
    """ Load pytorch state_dict in a TF 2.0 model.
        Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
            - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
            - '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
    """
    try:
        import re
        import torch
        import numpy
        from tensorflow.python.keras import backend as K
    except ImportError as e:
        logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
                     "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
        raise e

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    weight_value_tuples = []
    all_pytorch_weights = set(list(pt_state_dict.keys()))
    for symbolic_weight in symbolic_weights:
        name = symbolic_weight.name
        name = name.replace(':0', '')                      # device ids
        name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', name)  # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
        name = name.replace('_._', '/')                     # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
        name = re.sub(r'//+', '/', name)                    # Remove empty levels at the end
        name = name.split('/')                              # Convert from TF2.0 '/' separators to PyTorch '.' separators
        name = name[1:]                                     # Remove level zero

        # Convert standard TF2.0 names in PyTorch names
        transpose = bool(name[-1] == 'kernel')
        if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
            name[-1] = 'weight'
        if name[-1] == 'beta':
            name[-1] = 'bias'
        name = '.'.join(name)

        assert name in pt_state_dict, "{} not found in PyTorch model".format(name)
        array = pt_state_dict[name].numpy()
        if transpose:
            array = numpy.transpose(array)
        try:
            assert list(symbolic_weight.shape) == list(array.shape)
        except AssertionError as e:
            e.args += (symbolic_weight.shape, array.shape)
            raise e
        logger.info("Initialize TF weight {}".format(symbolic_weight.name))
        weight_value_tuples.append((symbolic_weight, array))
        all_pytorch_weights.discard(name)

    K.batch_set_value(weight_value_tuples)
    # (the weights are assigned in place by K.batch_set_value above; the callers build the
    # network with dummy inputs before invoking this function)
logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
return tf_model
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
"""
raise NotImplementedError
def load_tf2_weights_in_pytorch_model(pt_model, tf_model):
""" Load TF2.0 symbolic weights in a PyTorch model
"""
raise NotImplementedError
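Taken together, the per-model wrappers now reduce to: build the Keras network once with dummy inputs, then hand the PyTorch checkpoint to the shared utility. A minimal end-to-end sketch for BERT (file paths are hypothetical, and the final save is just the standard Keras HDF5 save, not something mandated by the commit):

from pytorch_transformers import BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2

config = BertConfig.from_json_file('./bert_config.json')                # hypothetical config path
tf_model = TFBertForPreTraining(config)
# The wrapper builds the network with dummy inputs, then delegates to
# load_pytorch_checkpoint_in_tf2_model to copy the PyTorch weights.
tf_model = load_bert_pt_weights_in_tf2(tf_model, './pytorch_model.bin')  # hypothetical checkpoint path
tf_model.save_weights('./tf_model.h5')                                   # Keras HDF5 save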
...@@ -18,6 +18,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import logging
import math
+import os
import itertools

import numpy as np

...@@ -26,6 +27,7 @@ import tensorflow as tf
from .configuration_xlm import XLMConfig
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
from .file_utils import add_start_docstrings
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

logger = logging.getLogger(__name__)
...@@ -43,71 +45,16 @@ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
}

-def load_xlm_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
-    """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
-        We use HDF5 to easily do transfer learning
-        (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
-    """
-    try:
-        import re
-        import torch
-        import numpy
-        from tensorflow.python.keras import backend as K
-    except ImportError:
-        logger.error("Loading a PyTorch model in TensorFlow, requires PyTorch to be installed. Please see "
-                     "https://pytorch.org/ for installation instructions.")
-        raise
-
-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info("Loading PyTorch weights from {}".format(pt_path))
-    # Load pytorch model
-    state_dict = torch.load(pt_path, map_location='cpu')
-
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)  # build the network
-
-    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
-    weight_value_tuples = []
-    all_pytorch_weights = set(list(state_dict.keys()))
-    for symbolic_weight in symbolic_weights:
-        name = symbolic_weight.name
-        name = name.replace(':0', '')
-        name = name.replace('__', '/')
-        name = name.split('/')
-        name = name[1:]
-        transpose = bool(name[-1] == 'kernel')
-        if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
-            name[-1] = 'weight'
-        if name[-1] == 'beta':
-            name[-1] = 'bias'
-        name = '.'.join(name)
-        assert name in state_dict, "{} not found in PyTorch model".format(name)
-        array = state_dict[name].numpy()
-        if transpose:
-            array = numpy.transpose(array)
-        try:
-            assert list(symbolic_weight.shape) == list(array.shape)
-        except AssertionError as e:
-            e.args += (symbolic_weight.shape, array.shape)
-            raise e
-        logger.info("Initialize TF weight {}".format(symbolic_weight.name))
-        weight_value_tuples.append((symbolic_weight, array))
-        all_pytorch_weights.discard(name)
-    K.batch_set_value(weight_value_tuples)
-    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
-    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
-    return tf_model
+def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
+    # build the network
+    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
+    attns_list = [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]
+    langs_list = [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]
+    tf_inputs = tf.constant(inputs_list)
+    tf_attns = tf.constant(attns_list)
+    tf_langs = tf.constant(langs_list)
+    tfo = tf_model([tf_inputs, tf_attns, tf_langs], training=False)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

def create_sinusoidal_embeddings(n_pos, dim, out):
...@@ -320,13 +267,13 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        # self.encoder_attn = tf.keras.layers.LayerList()

        for i in range(self.n_layers):
-           self.attentions.append(TFMultiHeadAttention(self.n_heads, self.dim, config=config, name='attentions__{}'.format(i)))
-           self.layer_norm1.append(tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm1__{}'.format(i)))
+           self.attentions.append(TFMultiHeadAttention(self.n_heads, self.dim, config=config, name='attentions_._{}'.format(i)))
+           self.layer_norm1.append(tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm1_._{}'.format(i)))
            # if self.is_decoder:
            #     self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            #     self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
-           self.ffns.append(TFTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name='ffns__{}'.format(i)))
-           self.layer_norm2.append(tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm2__{}'.format(i)))
+           self.ffns.append(TFTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name='ffns_._{}'.format(i)))
+           self.layer_norm2.append(tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm2_._{}'.format(i)))

        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
...@@ -667,8 +614,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFXLMWithLMHeadModel, self).__init__(config, *inputs, **kwargs)
-       self.transformer = TFXLMMainLayer(config, name='transformer')
-       self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer')
+       self.transformer = TFXLMMainLayer(config, name='transformer___')
+       self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')

    def call(self, inputs, training=False):

...
...@@ -30,6 +30,7 @@ import tensorflow as tf
from .configuration_xlnet import XLNetConfig
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
from .file_utils import add_start_docstrings
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

logger = logging.getLogger(__name__)
...@@ -40,71 +41,11 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
}

-def load_xlnet_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
-    """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
-        We use HDF5 to easily do transfer learning
-        (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
-    """
-    try:
-        import re
-        import torch
-        import numpy
-        from tensorflow.python.keras import backend as K
-    except ImportError:
-        logger.error("Loading a PyTorch model in TensorFlow, requires PyTorch to be installed. Please see "
-                     "https://pytorch.org/ for installation instructions.")
-        raise
-
-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info("Loading PyTorch weights from {}".format(pt_path))
-    # Load pytorch model
-    state_dict = torch.load(pt_path, map_location='cpu')
-
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)  # build the network
-
-    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
-    weight_value_tuples = []
-    all_pytorch_weights = set(list(state_dict.keys()))
-    for symbolic_weight in symbolic_weights:
-        name = symbolic_weight.name
-        name = name.replace(':0', '')
-        name = name.replace('__', '/')
-        name = name.split('/')
-        name = name[1:]
-        transpose = bool(name[-1] == 'kernel')
-        if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
-            name[-1] = 'weight'
-        if name[-1] == 'beta':
-            name[-1] = 'bias'
-        name = '.'.join(name)
-        assert name in state_dict, "{} not found in PyTorch model".format(name)
-        array = state_dict[name].numpy()
-        if transpose:
-            array = numpy.transpose(array)
-        try:
-            assert list(symbolic_weight.shape) == list(array.shape)
-        except AssertionError as e:
-            e.args += (symbolic_weight.shape, array.shape)
-            raise e
-        logger.info("Initialize TF weight {}".format(symbolic_weight.name))
-        weight_value_tuples.append((symbolic_weight, array))
-        all_pytorch_weights.discard(name)
-    K.batch_set_value(weight_value_tuples)
-    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
-    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
-    return tf_model
+def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
+    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
+    tf_inputs = tf.constant(inputs_list)
+    tfo = tf_model(tf_inputs, training=False)  # build the network
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

def gelu(x):
...@@ -430,7 +371,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
        self.initializer_range = config.initializer_range

        self.word_embedding = TFSharedEmbeddings(config.n_token, config.d_model, initializer_range=config.initializer_range, name='word_embedding')
-       self.layer = [TFXLNetLayer(config, name='layer__{}'.format(i)) for i in range(config.n_layer)]
+       self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)]
        self.dropout = tf.keras.layers.Dropout(config.dropout)

    def build(self, input_shape):

...