Commit 98c96fb1 authored by thomwolf's avatar thomwolf
Browse files

splitting position and tokens embeddings in OpenAI GPT - updating tf imports - tests

parent 5456d823
......@@ -14,7 +14,7 @@ def main():
else:
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
try:
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
import tensorflow as tf
except ModuleNotFoundError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
......@@ -42,7 +42,7 @@ def main():
PYTORCH_DUMP_OUTPUT)
else:
try:
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
import tensorflow as tf
except ModuleNotFoundError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
......
......@@ -18,13 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import json
import argparse
import torch
import numpy as np
from .modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
from pytorch_pretrained_bert.modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
# Construct model
......@@ -67,5 +64,5 @@ if __name__ == "__main__":
"This specifies the model architecture.")
args = parser.parse_args()
convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
args.pytorch_dump_folder_path,
args.openai_config_file)
args.openai_config_file,
args.pytorch_dump_folder_path)
......@@ -25,7 +25,7 @@ import tensorflow as tf
import torch
import numpy as np
from .modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
# Initialise PyTorch model
......
......@@ -52,6 +52,14 @@ TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
try:
import re
import numpy as np
import tensorflow as tf
except ModuleNotFoundError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
......
......@@ -15,23 +15,23 @@
# limitations under the License.
"""PyTorch OpenAI GPT model."""
import os
import collections
import copy
import json
import math
import logging
import math
import os
import shutil
import tarfile
import tempfile
import shutil
import collections
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm
from .file_utils import cached_path
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
......@@ -42,6 +42,8 @@ WEIGHTS_NAME = "pytorch_model.bin"
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
"""
import re
import numpy as np
print("Loading weights...")
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
......@@ -50,18 +52,24 @@ def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
del init_params[1]
# Thsi as used when we had a single embedding matrix for positions and tokens
# init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
# del init_params[1]
init_params = [arr.squeeze() for arr in init_params]
try:
assert model.embed.weight.shape == init_params[0].shape
assert model.tokens_embed.weight.shape == init_params[1].shape
assert model.positions_embed.weight.shape == init_params[0].shape
except AssertionError as e:
e.args += (model.embed.weight.shape, init_params[0].shape)
e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
e.args += (model.positions_embed.weight.shape, init_params[0].shape)
raise
model.embed.weight.data = torch.from_numpy(init_params[0])
model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
model.positions_embed.weight.data = torch.from_numpy(init_params[0])
names.pop(0)
# Pop position and token embedding arrays
init_params.pop(0)
init_params.pop(0)
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
......@@ -584,8 +592,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
def __init__(self, config):
super(OpenAIGPTModel, self).__init__(config)
total_embeddings_size = config.vocab_size + config.n_special + config.n_positions
self.embed = nn.Embedding(total_embeddings_size, config.n_embd)
num_tokens = config.vocab_size + config.n_special
self.tokens_embed = nn.Embedding(num_tokens, config.n_embd)
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
......@@ -598,30 +607,32 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Update config
self.config.n_special = num_special_tokens
# # Build new embeddings and initialize
old_embed = self.embed
self.embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd)
old_embed = self.tokens_embed
self.tokens_embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd)
# Initialize all new embeddings (in particular the special tokens)
self.init_weights(self.embed)
self.init_weights(self.tokens_embed)
# Copy word and positional embeddings from the previous weights
self.embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
self.embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
self.tokens_embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
self.tokens_embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
def forward(self, input_ids, position_ids=None, token_type_ids=None):
if position_ids is None:
start = self.config.vocab_size + self.config.n_special
end = start + input_ids.size(-1)
position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
# This was used when we had a single embedding matrice from position and token embeddings
# start = self.config.vocab_size + self.config.n_special
# end = start + input_ids.size(-1)
# position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
position_ids = position_ids.view(-1, position_ids.size(-1))
inputs_embeds = self.embed(input_ids)
position_embeds = self.embed(position_ids)
inputs_embeds = self.tokens_embed(input_ids)
position_embeds = self.positions_embed(position_ids)
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
token_type_embeds = self.embed(token_type_ids)
token_type_embeds = self.tokens_embed(token_type_ids)
else:
token_type_embeds = 0
# Add the position information to the input embeddings
......@@ -694,13 +705,13 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
def __init__(self, config):
super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config)
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens):
" Update input and output embeddings with new embedding matrice "
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
......@@ -780,14 +791,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def __init__(self, config):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config)
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens):
" Update input and output embeddings with new embedding matrice "
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
......
......@@ -121,6 +121,13 @@ def build_tf_to_pytorch_map(model, config):
def load_tf_weights_in_transfo_xl(model, config, tf_path):
""" Load tf checkpoints in a pytorch model
"""
try:
import numpy as np
import tensorflow as tf
except ModuleNotFoundError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
# Build TF to PyTorch weights loading map
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
......
......@@ -39,7 +39,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
use_labels=True,
vocab_size=99,
n_special=1,
n_ctx=33,
n_positions=33,
n_embd=32,
n_layer=5,
n_head=4,
......@@ -61,7 +61,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
self.use_labels = use_labels
self.vocab_size = vocab_size
self.n_special = n_special
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
......@@ -80,12 +80,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
position_ids = None
if self.use_position_ids:
position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_ctx)
position_ids = position_ids + self.n_special + self.vocab_size
position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)
token_type_ids = None
if self.use_token_type_ids:
total_voc = self.n_ctx + self.n_special + self.vocab_size
total_voc = self.vocab_size + self.n_special
token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
mc_labels = None
......@@ -98,7 +97,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
config = OpenAIGPTConfig(
vocab_size_or_config_json_file=self.vocab_size,
n_ctx=self.n_ctx,
n_positions=self.n_positions,
n_special=self.n_special,
n_embd=self.n_embd,
n_layer=self.n_layer,
......@@ -139,7 +138,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
return outputs
def check_openai_lm_head_output(self, result):
total_voc = self.n_ctx + self.n_special + self.vocab_size
total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
......@@ -164,7 +163,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
return outputs
def check_openai_double_heads_output(self, result):
total_voc = self.n_ctx + self.n_special + self.vocab_size
total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment