Commit e9a103c1 authored by thomwolf
Browse files

bidirectional conversion TF <=> PT - extended tests

parent a7e01a24
version: 2
jobs:
# CI job: runs the test suite in an environment where both PyTorch and TensorFlow are installed.
build_py3_torch_and_tf:
working_directory: ~/pytorch-transformers
docker:
- image: circleci/python:3.5
resource_class: xlarge
parallelism: 1
steps:
- checkout
# Install both deep-learning backends (TensorFlow is pinned to the 2.0.0-rc0 pre-release).
- run: sudo pip install torch
- run: sudo pip install tensorflow==2.0.0-rc0
# Install the package itself from the checked-out repository.
- run: sudo pip install --progress-bar off .
# Test/coverage tooling and additional packages.
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install tensorboardX scikit-learn
# Run the tests with coverage collection enabled, then invoke codecov.
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
- run: codecov
build_py3_torch:
working_directory: ~/pytorch-transformers
docker:
......
......@@ -73,7 +73,8 @@ if _torch_available:
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
......@@ -150,6 +151,15 @@ if _tf_available:
load_distilbert_pt_weights_in_tf2,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
if _tf_available and _torch_available:
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
load_pytorch_checkpoint_in_tf2_model,
load_pytorch_weights_in_tf2_model,
load_pytorch_model_in_tf2_model,
load_tf2_checkpoint_in_pytorch_model,
load_tf2_weights_in_pytorch_model,
load_tf2_model_in_pytorch_model)
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
......
......@@ -20,15 +20,49 @@ from __future__ import (absolute_import, division, print_function,
import logging
import os
import re
import numpy
logger = logging.getLogger(__name__)
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None):
""" Load pytorch checkpoints in a TF 2.0 model
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
    """ Convert a TF 2.0 model variable name in a pytorch model weight name.

        Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
            - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
            - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)

        Args:
            tf_name: name of the TF 2.0 variable, e.g. 'model/bert/encoder/layer_._0/dense/kernel:0'.
                The first scope level is always dropped.
            start_prefix_to_remove: prefix (e.g. 'bert.') stripped from the converted name,
                but only when the converted name actually starts with it.

        return tuple with:
            - pytorch model weight name
            - transpose: boolean indicating whether TF2.0 and PyTorch weights matrices are transposed with regards to each other
    """
    tf_name = tf_name.replace(':0', '')  # device ids
    # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
    tf_name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', tf_name)
    # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
    tf_name = tf_name.replace('_._', '/')
    tf_name = re.sub(r'//+', '/', tf_name)  # Collapse empty levels
    levels = tf_name.split('/')  # Convert from TF2.0 '/' separators to PyTorch '.' separators
    levels = levels[1:]  # Remove level zero

    # When should we transpose the weights
    transpose = bool(levels[-1] == 'kernel' or 'emb_projs' in levels or 'out_projs' in levels)

    # Convert standard TF2.0 names in PyTorch names
    if levels[-1] in ('kernel', 'embeddings', 'gamma'):
        levels[-1] = 'weight'
    if levels[-1] == 'beta':
        levels[-1] = 'bias'

    pt_name = '.'.join(levels)

    # Remove prefix if needed. Only a *leading* prefix is stripped: a plain
    # str.replace() would also eat a match in the middle of the name
    # (e.g. 'classifier.bert.weight' with prefix 'bert.').
    if start_prefix_to_remove and pt_name.startswith(start_prefix_to_remove):
        pt_name = pt_name[len(start_prefix_to_remove):]

    return pt_name, transpose
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None):
""" Load pytorch checkpoints in a TF 2.0 model
"""
try:
import tensorflow as tf
......@@ -43,25 +77,31 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
pt_state_dict = torch.load(pt_path, map_location='cpu')
return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)
def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None):
    """ Load the weights of an instantiated PyTorch model in a TF 2.0 model.

        The PyTorch model's ``state_dict()`` is handed, together with ``tf_inputs``,
        to ``load_pytorch_weights_in_tf2_model``; its return value is returned unchanged.
    """
    return load_pytorch_weights_in_tf2_model(tf_model, pt_model.state_dict(), tf_inputs=tf_inputs)
def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
""" Load pytorch state_dict in a TF 2.0 model.
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
- '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
"""
try:
import re
import torch
import numpy
from tensorflow.python.keras import backend as K
except ImportError as e:
logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
raise e
if tf_inputs is not None:
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
# Adapt state dict - TODO remove this and update the AWS weights files instead
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
......@@ -89,27 +129,8 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
weight_value_tuples = []
all_pytorch_weights = set(list(pt_state_dict.keys()))
for symbolic_weight in symbolic_weights:
name = symbolic_weight.name
name = name.replace(':0', '') # device ids
name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', name) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
name = name.replace('_._', '/') # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
name = re.sub(r'//+', '/', name) # Remove empty levels at the end
name = name.split('/') # Convert from TF2.0 '/' separators to PyTorch '.' separators
name = name[1:] # Remove level zero
# When should we transpose the weights
transpose = bool(name[-1] == 'kernel' or 'emb_projs' in name or 'out_projs' in name)
# Convert standard TF2.0 names in PyTorch names
if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
name[-1] = 'weight'
if name[-1] == 'beta':
name[-1] = 'bias'
# Remove prefix if needed
name = '.'.join(name)
if start_prefix_to_remove:
name = name.replace(start_prefix_to_remove, '', 1)
sw_name = symbolic_weight.name
name, transpose = convert_tf_weight_name_to_pt_weight_name(sw_name, start_prefix_to_remove=start_prefix_to_remove)
# Find associated numpy array in pytorch model state dict
assert name in pt_state_dict, "{} not found in PyTorch model".format(name)
......@@ -144,13 +165,10 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
return tf_model
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path):
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
- '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
"""
try:
import tensorflow as tf
......@@ -161,13 +179,97 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path):
raise e
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Loading TensorFlow weights from {}".format(tf_path))
logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path))
# Instantiate and load the associated TF 2.0 model
tf_model_class_name = "TF" + model_class.__name__ # Add "TF" at the beggining
tf_model_class = getattr(pytorch_transformers, tf_model_class_name)
tf_model = tf_model_class(pt_model.config)
if tf_inputs is not None:
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
tf_model.load_weights(tf_checkpoint_path, by_name=True)
return load_tf2_model_in_pytorch_model(pt_model, tf_model)
def load_tf2_model_in_pytorch_model(pt_model, tf_model):
    """ Load TF 2.0 model in a pytorch model

        Reads the variables exposed by ``tf_model.weights`` and forwards them to
        ``load_tf2_weights_in_pytorch_model``, whose return value is returned as is.
    """
    # The weights come straight from the in-memory TF model: no checkpoint file
    # is read here (the former torch.load(tf_path) call referenced names that
    # do not exist in this function).
    weights = tf_model.weights
    return load_tf2_weights_in_pytorch_model(pt_model, weights)
def load_tf2_weights_in_pytorch_model(pt_model, tf_weights):
    """ Load TF2.0 symbolic weights in a PyTorch model

        Args:
            pt_model: PyTorch model to fill. ``named_parameters()``, ``base_model_prefix``
                and ``load_state_dict()`` are used on it.
            tf_weights: iterable of TF 2.0 variables; each one must expose ``.name`` and ``.numpy()``.

        Returns:
            ``pt_model`` after ``load_state_dict(..., strict=False)`` has been called on it.

        Raises:
            ImportError: if TensorFlow or PyTorch cannot be imported.
            ValueError: if a PyTorch parameter has no counterpart among the TF 2.0 weights.
            AssertionError: if a TF 2.0 array cannot be brought to the shape of its PyTorch parameter
                (both shapes are appended to the exception args).
    """
    try:
        import tensorflow as tf  # imported only to check that TensorFlow is installed
        import torch
    except ImportError as e:
        logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                     "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
        raise e

    new_pt_params_dict = {}
    current_pt_params_dict = dict(pt_model.named_parameters())

    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
    # TF models always have a prefix, some of PyTorch models (base ones) don't
    start_prefix_to_remove = ''
    if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict):
        start_prefix_to_remove = pt_model.base_model_prefix + '.'

    # Build a map from potential PyTorch weight names to TF 2.0 Variables
    tf_weights_map = {}
    for tf_weight in tf_weights:
        pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
            tf_weight.name, start_prefix_to_remove=start_prefix_to_remove)
        tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)

    # Names still in this set at the end were never consumed by a PyTorch parameter
    all_tf_weights = set(tf_weights_map)
    loaded_pt_weights_data_ptr = {}
    for pt_weight_name, pt_weight in current_pt_params_dict.items():
        # Handle PyTorch shared weights (not duplicated in TF 2.0): a parameter whose
        # storage was already filled under another name reuses the same tensor
        if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
            new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
            continue

        # Find associated numpy array in the TF 2.0 weights map
        if pt_weight_name not in tf_weights_map:
            raise ValueError("{} not found in TF 2.0 model".format(pt_weight_name))

        array, transpose = tf_weights_map[pt_weight_name]

        if transpose:
            array = numpy.transpose(array)

        # Reconcile a rank mismatch before comparing shapes
        if len(pt_weight.shape) < len(array.shape):
            array = numpy.squeeze(array)
        elif len(pt_weight.shape) > len(array.shape):
            array = numpy.expand_dims(array, axis=0)

        try:
            assert list(pt_weight.shape) == list(array.shape)
        except AssertionError as e:
            e.args += (pt_weight.shape, array.shape)
            raise e

        logger.info("Initialize PyTorch weight {}".format(pt_weight_name))

        pt_tensor = torch.from_numpy(array)
        new_pt_params_dict[pt_weight_name] = pt_tensor
        loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = pt_tensor
        all_tf_weights.discard(pt_weight_name)

    missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)

    if len(missing_keys) > 0:
        logger.info("Weights of {} not initialized from TF 2.0 model: {}".format(
            pt_model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        logger.info("Weights from TF 2.0 model not used in {}: {}".format(
            pt_model.__class__.__name__, unexpected_keys))

    logger.info("Weights or buffers not loaded from TF 2.0 model: {}".format(all_tf_weights))

    return pt_model
......@@ -718,6 +718,101 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
@add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss is the average of a Cross-Entropy for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
        model = XLMForQuestionAnsweringSimple.from_pretrained('xlm-mlm-en-2048')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        start_positions = torch.tensor([1])
        end_positions = torch.tensor([3])
        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        loss, start_scores, end_scores = outputs[:3]

    """
    def __init__(self, config):
        super(XLMForQuestionAnsweringSimple, self).__init__(config)

        self.transformer = XLMModel(config)
        # One output unit per label; forward() unpacks the result into a start
        # and an end logit, i.e. it expects config.num_labels == 2
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               langs=langs,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
                                               head_mask=head_mask)

        sequence_output = transformer_outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,)
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            # Out-of-place clamp: the in-place variant would silently modify the
            # label tensors owned by the caller
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        # Append whatever the transformer returned after its first output
        outputs = outputs + transformer_outputs[1:]

        return outputs
@add_start_docstrings("""XLM Model with a beam-search span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
class XLMForQuestionAnswering(XLMPreTrainedModel):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
......
......@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function
import copy
import json
import logging
import importlib
import random
import shutil
import unittest
......@@ -25,7 +26,7 @@ import uuid
import pytest
import sys
from pytorch_transformers import is_tf_available
from pytorch_transformers import is_tf_available, is_torch_available
if is_tf_available():
import tensorflow as tf
......@@ -66,6 +67,24 @@ class TFCommonTestCases:
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
def test_pt_tf_model_equivalence(self):
    """Round-trip weights PyTorch -> TF 2.0 -> PyTorch for every TF model class under test.

    For each class in ``self.all_model_classes`` the PyTorch counterpart is looked up
    by dropping the leading "TF" of the class name; both models are built from the
    same config and the two conversion helpers are run in sequence.
    """
    if not is_torch_available():
        # Nothing to convert from/to without PyTorch: leave the test early
        # (a bare `pass` here would fall through and fail on the lookups below)
        return

    import pytorch_transformers

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
        pt_model_class = getattr(pytorch_transformers, pt_model_class_name)

        tf_model = model_class(config)
        pt_model = pt_model_class(config)

        tf_model = pytorch_transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
        pt_model = pytorch_transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
def test_keyword_and_dict_args(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -225,7 +225,7 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_lengths,
sequence_labels, token_labels, is_impossible_labels, input_mask) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'langs': token_type_ids, 'lengths': input_lengths}
return config, inputs_dict
def setUp(self):
......
......@@ -24,7 +24,7 @@ from pytorch_transformers import is_torch_available
if is_torch_available():
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
XLMForSequenceClassification)
XLMForSequenceClassification, XLMForQuestionAnsweringSimple)
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
......@@ -36,7 +36,7 @@ from .configuration_common_test import ConfigTester
class XLMModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
XLMForSequenceClassification) if is_torch_available() else ()
XLMForSequenceClassification, XLMForQuestionAnsweringSimple) if is_torch_available() else ()
class XLMModelTester(object):
......@@ -180,6 +180,30 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
[self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_xlm_simple_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
    """Build an XLMForQuestionAnsweringSimple in eval mode and check its outputs.

    The model is called once without labels and once with ``sequence_labels`` used as
    both start and end positions; the start/end logits of the second call must have
    size ``[batch_size, seq_length]`` and the loss is passed to ``check_loss_output``.
    """
    qa_model = XLMForQuestionAnsweringSimple(config)
    qa_model.eval()

    # Label-free call first; its result is not inspected
    qa_model(input_ids)

    loss, start_logits, end_logits = qa_model(input_ids,
                                              start_positions=sequence_labels,
                                              end_positions=sequence_labels)

    result = {
        "loss": loss,
        "start_logits": start_logits,
        "end_logits": end_logits,
    }

    for logits_key in ("start_logits", "end_logits"):
        self.parent.assertListEqual(
            list(result[logits_key].size()),
            [self.batch_size, self.seq_length])
    self.check_loss_output(result)
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMForQuestionAnswering(config)
model.eval()
......@@ -276,6 +300,10 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)
def test_xlm_simple_qa(self):
    """Run the simple question-answering check with freshly prepared config and inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_xlm_simple_qa(*prepared)
def test_xlm_qa(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_qa(*config_and_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment