Commit 7d03c537 authored by thomwolf's avatar thomwolf
Browse files

conversion working

parent 3a9c8837
......@@ -119,4 +119,7 @@ dmypy.json
.vscode
# TF code
tensorflow_code
\ No newline at end of file
tensorflow_code
# models
models
\ No newline at end of file
......@@ -3,9 +3,14 @@ def main():
import sys
if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
"convert_tf_checkpoint_to_pytorch",
"convert_openai_checkpoint"
"convert_openai_checkpoint",
"convert_transfo_xl_checkpoint"
]:
print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT` \n or `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
print(
"Should be used as"
"`pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
"`pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]` or \n"
"`pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
else:
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
try:
......@@ -24,7 +29,7 @@ def main():
TF_CONFIG = sys.argv.pop()
TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
else:
elif sys.argv[1] == "convert_openai_checkpoint":
from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3]
......@@ -35,6 +40,22 @@ def main():
convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
OPENAI_GPT_CONFIG,
PYTORCH_DUMP_OUTPUT)
else:
try:
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
except ModuleNotFoundError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
TF_CHECKPOINT = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3]
if len(sys.argv) == 5:
TF_CONFIG = sys.argv[4]
else:
TF_CONFIG = ""
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
if __name__ == '__main__':
main()
......@@ -18,11 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import json
import argparse
import tensorflow as tf
import torch
import numpy as np
......
......@@ -25,7 +25,72 @@ import tensorflow as tf
import torch
import numpy as np
from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
def build_tf_to_pytorch_map(model, config):
""" A map of modules from TF to PyTorch """
tf_to_pt_map = {}
# Embeddings cutoffs
for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
tf_to_pt_map.update({
layer_str + 'lookup_table': embed_l.weight,
layer_str + 'proj_W': proj_l
})
# Transformer blocks
for i, b in enumerate(model.layers):
layer_str = "transformer/layer_%d/" % i
tf_to_pt_map.update({
layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
})
# Softmax cutoffs
for i, (out_l, proj_l, tie_proj) in enumerate(zip(
model.crit.out_layers,
model.crit.out_projs,
config.tie_projs)):
layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
if config.tie_weight:
tf_to_pt_map.update({
layer_str + 'b': out_l.bias})
else:
raise NotImplementedError
# I don't think this is implemented in the TF code
tf_to_pt_map.update({
layer_str + 'lookup_table': out_l.weight,
layer_str + 'b': out_l.bias})
if not tie_proj:
tf_to_pt_map.update({
layer_str + 'proj': proj_l
})
# Relative positioning biases
if config.untie_r:
layer_str = "transformer/r_r_bias"
layer_str_2 = "transformer/r_w_bias"
r_r_list = []
r_w_list = []
for b in model.layers:
r_r_list.append(b.dec_attn.r_r_bias)
r_w_list.append(b.dec_attn.r_w_bias)
else:
r_r_list = [model.r_r_bias]
r_w_list = [model.r_w_bias]
tf_to_pt_map.update({
'transformer/r_r_bias': r_r_list,
'transformer/r_w_bias': r_w_list})
return tf_to_pt_map
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
......@@ -35,16 +100,6 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
# Initialise PyTorch model
# Construct model
if transfo_xl_config_file == "":
......@@ -54,34 +109,37 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
print("Building PyTorch model from configuration: {}".format(str(config)))
model = TransfoXLModel(config)
for name, array in zip(names, arrays):
name = name.split('/')
# Build TF to PyTorch weights loading map
tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
tf_weights = {}
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
tf_weights[name] = array
for name, pointer in tf_to_pt_map.items():
assert name in tf_weights
array = tf_weights[name]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m"] for n in name):
print("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
else:
l = [m_name]
if l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
if 'kernel' in name or 'proj_W' in name:
array = np.transpose(array)
if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
# Here we will split the TF weigths
assert len(pointer) == array.shape[0]
for i, p_i in enumerate(pointer):
arr_i = array[i, ...]
try:
assert p_i.shape == arr_i.shape
except AssertionError as e:
e.args += (p_i.shape, arr_i.shape)
raise
print("Initialize PyTorch weight {} for layer {}".format(name, i))
p_i.data = torch.from_numpy(arr_i)
continue
try:
assert pointer.shape == array.shape
except AssertionError as e:
......@@ -108,17 +166,16 @@ if __name__ == "__main__":
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--transfo_xl_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
parser.add_argument("--transfo_xl_config_file",
default = "",
type = str,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
args = parser.parse_args()
convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.transfo_xl_config_file,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment