Commit 7d03c537 authored by thomwolf

conversion working

parent 3a9c8837
.gitignore
@@ -119,4 +119,7 @@ dmypy.json
 .vscode

 # TF code
-tensorflow_code
\ No newline at end of file
+tensorflow_code
+
+# models
+models
\ No newline at end of file

pytorch_pretrained_bert/__main__.py
@@ -3,9 +3,14 @@ def main():
     import sys
     if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
         "convert_tf_checkpoint_to_pytorch",
-        "convert_openai_checkpoint"
+        "convert_openai_checkpoint",
+        "convert_transfo_xl_checkpoint"
     ]:
-        print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT` \n or `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
+        print(
+            "Should be used as "
+            "`pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
+            "`pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]` or \n"
+            "`pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
     else:
         if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
             try:
...@@ -24,7 +29,7 @@ def main():
             TF_CONFIG = sys.argv.pop()
             TF_CHECKPOINT = sys.argv.pop()
             convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
-        else:
+        elif sys.argv[1] == "convert_openai_checkpoint":
             from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
             OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
             PYTORCH_DUMP_OUTPUT = sys.argv[3]
...@@ -35,6 +40,22 @@ def main():
             convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
                                                  OPENAI_GPT_CONFIG,
                                                  PYTORCH_DUMP_OUTPUT)
+        else:
+            try:
+                from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
+            except ModuleNotFoundError:
+                print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models to PyTorch. "
+                      "In that case, it requires TensorFlow to be installed. Please see "
+                      "https://www.tensorflow.org/install/ for installation instructions.")
+                raise
+            TF_CHECKPOINT = sys.argv[2]
+            PYTORCH_DUMP_OUTPUT = sys.argv[3]
+            if len(sys.argv) == 5:
+                TF_CONFIG = sys.argv[4]
+            else:
+                TF_CONFIG = ""
+            convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
+

 if __name__ == '__main__':
     main()
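
Note: for reference, a hypothetical invocation of the new conversion mode, following the usage string above (the checkpoint and output paths are placeholders; the trailing TF_CONFIG argument is optional):

    pytorch_pretrained_bert convert_transfo_xl_checkpoint \
        /path/to/transfo_xl_checkpoint \
        /path/to/pytorch_dump_output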

...@@ -18,11 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
 import re
 import json
 import argparse

-import tensorflow as tf
 import torch
 import numpy as np
...

pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
@@ -25,7 +25,72 @@ import tensorflow as tf
 import torch
 import numpy as np

-from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
+from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
+
+
+def build_tf_to_pytorch_map(model, config):
+    """ A map of modules from TF to PyTorch """
+    tf_to_pt_map = {}
+    # Embedding cutoffs
+    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
+        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
+        tf_to_pt_map.update({
+            layer_str + 'lookup_table': embed_l.weight,
+            layer_str + 'proj_W': proj_l})
+
+    # Transformer blocks
+    for i, b in enumerate(model.layers):
+        layer_str = "transformer/layer_%d/" % i
+        tf_to_pt_map.update({
+            layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
+            layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
+            layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
+            layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
+            layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
+            layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
+            layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
+            layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
+            layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
+            layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
+            layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
+        })
+
+    # Softmax cutoffs
+    for i, (out_l, proj_l, tie_proj) in enumerate(zip(
+            model.crit.out_layers,
+            model.crit.out_projs,
+            config.tie_projs)):
+        layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
+        if config.tie_weight:
+            tf_to_pt_map.update({
+                layer_str + 'b': out_l.bias})
+        else:
+            raise NotImplementedError
+            # I don't think this is implemented in the TF code
+            tf_to_pt_map.update({
+                layer_str + 'lookup_table': out_l.weight,
+                layer_str + 'b': out_l.bias})
+        if not tie_proj:
+            tf_to_pt_map.update({
+                layer_str + 'proj': proj_l})
+
+    # Relative positioning biases
+    if config.untie_r:
+        r_r_list = []
+        r_w_list = []
+        for b in model.layers:
+            r_r_list.append(b.dec_attn.r_r_bias)
+            r_w_list.append(b.dec_attn.r_w_bias)
+    else:
+        r_r_list = [model.r_r_bias]
+        r_w_list = [model.r_w_bias]
+    tf_to_pt_map.update({
+        'transformer/r_r_bias': r_r_list,
+        'transformer/r_w_bias': r_w_list})
+    return tf_to_pt_map
+
+
 def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
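
Note: the keys of this map reproduce the variable names found in the original TF Transformer-XL checkpoint; the values are the PyTorch parameters (or, for the untied relative-position biases, per-layer lists of parameters) those variables should fill. Illustrative only: for a hypothetical model with a single embedding cutoff and a single layer, the map contains keys like these:

    # Illustrative only: key names build_tf_to_pytorch_map produces for a
    # hypothetical one-cutoff, one-layer model; each maps to a PyTorch tensor.
    example_keys = [
        "transformer/adaptive_embed/cutoff_0/lookup_table",  # embedding table
        "transformer/adaptive_embed/cutoff_0/proj_W",        # embedding projection
        "transformer/layer_0/rel_attn/qkv/kernel",           # fused q/k/v weight
        "transformer/layer_0/ff/layer_1/kernel",             # first feed-forward weight
        "transformer/adaptive_softmax/cutoff_0/b",           # adaptive softmax bias
        "transformer/r_r_bias",                              # list of per-layer biases when untie_r
    ]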

...@@ -35,16 +100,6 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
     tf_path = os.path.abspath(tf_checkpoint_path)
     print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
-    # Load weights from TF model
-    init_vars = tf.train.list_variables(tf_path)
-    names = []
-    arrays = []
-    for name, shape in init_vars:
-        print("Loading TF weight {} with shape {}".format(name, shape))
-        array = tf.train.load_variable(tf_path, name)
-        names.append(name)
-        arrays.append(array)
-
-    # Initialise PyTorch model
+    # Construct model
     if transfo_xl_config_file == "":

...@@ -54,34 +109,37 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
     print("Building PyTorch model from configuration: {}".format(str(config)))
     model = TransfoXLModel(config)

-    for name, array in zip(names, arrays):
-        name = name.split('/')
+    # Build TF to PyTorch weights loading map
+    tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
+
+    # Load weights from TF model
+    init_vars = tf.train.list_variables(tf_path)
+    tf_weights = {}
+    for name, shape in init_vars:
+        print("Loading TF weight {} with shape {}".format(name, shape))
+        array = tf.train.load_variable(tf_path, name)
+        tf_weights[name] = array
+
+    for name, pointer in tf_to_pt_map.items():
+        assert name in tf_weights
+        array = tf_weights[name]
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m"] for n in name):
-            print("Skipping {}".format("/".join(name)))
-            continue
-        pointer = model
-        for m_name in name:
-            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
-                l = re.split(r'_(\d+)', m_name)
-            else:
-                l = [m_name]
-            if l[0] == 'kernel' or l[0] == 'gamma':
-                pointer = getattr(pointer, 'weight')
-            elif l[0] == 'output_bias' or l[0] == 'beta':
-                pointer = getattr(pointer, 'bias')
-            elif l[0] == 'output_weights':
-                pointer = getattr(pointer, 'weight')
-            else:
-                pointer = getattr(pointer, l[0])
-            if len(l) >= 2:
-                num = int(l[1])
-                pointer = pointer[num]
-        if m_name[-11:] == '_embeddings':
-            pointer = getattr(pointer, 'weight')
-        elif m_name == 'kernel':
-            array = np.transpose(array)
+        if 'kernel' in name or 'proj_W' in name:
+            array = np.transpose(array)
+        if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
+            # Here we will split the TF weights across the per-layer parameters
+            assert len(pointer) == array.shape[0]
+            for i, p_i in enumerate(pointer):
+                arr_i = array[i, ...]
+                try:
+                    assert p_i.shape == arr_i.shape
+                except AssertionError as e:
+                    e.args += (p_i.shape, arr_i.shape)
+                    raise
+                print("Initialize PyTorch weight {} for layer {}".format(name, i))
+                p_i.data = torch.from_numpy(arr_i)
+            continue
         try:
             assert pointer.shape == array.shape
         except AssertionError as e:
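
Note: two conventions drive this loop. TF stores dense kernels as [in_features, out_features] while torch.nn.Linear keeps its weight as [out_features, in_features], hence the transpose for 'kernel'/'proj_W' names; and when the relative-position biases are untied, the TF checkpoint stacks them in one variable with a leading layer axis, which the loop splits across the per-layer PyTorch parameters. A self-contained sketch of the transpose convention (shapes are hypothetical):

    import numpy as np
    import torch

    # Hypothetical shapes: a TF kernel stored as [in, out] must be
    # transposed to fill a torch.nn.Linear weight, stored as [out, in].
    tf_kernel = np.zeros((1024, 4096), dtype=np.float32)
    linear = torch.nn.Linear(1024, 4096)
    linear.weight.data = torch.from_numpy(np.transpose(tf_kernel))
    assert linear.weight.shape == (4096, 1024)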

...@@ -108,17 +166,16 @@ if __name__ == "__main__":
                         type = str,
                         required = True,
                         help = "Path to the TensorFlow checkpoint.")
-    parser.add_argument("--transfo_xl_config_file",
-                        default = None,
-                        type = str,
-                        required = True,
-                        help = "The config json file corresponding to the pre-trained BERT model. \n"
-                               "This specifies the model architecture.")
     parser.add_argument("--pytorch_dump_folder_path",
                         default = None,
                         type = str,
                         required = True,
                         help = "Path to the output PyTorch model.")
+    parser.add_argument("--transfo_xl_config_file",
+                        default = "",
+                        type = str,
+                        help = "The config json file corresponding to the pre-trained Transformer-XL model. \n"
+                               "This specifies the model architecture.")
     args = parser.parse_args()
     convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                              args.transfo_xl_config_file,
...
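
Note: run standalone, the script would be invoked along these lines (hypothetical paths; --transfo_xl_config_file may be omitted now that it defaults to the empty string):

    python -m pytorch_pretrained_bert.convert_transfo_xl_checkpoint_to_pytorch \
        --tf_checkpoint_path /path/to/transfo_xl_checkpoint \
        --pytorch_dump_folder_path /path/to/pytorch_dump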