Unverified Commit 16c22401 authored by Martin Müller's avatar Martin Müller Committed by GitHub
Browse files

Add script to convert tf2.x checkpoint to PyTorch (#5791)

* Add script to convert tf2.x checkpoint to pytorch

The script converts the newer TF2.x checkpoints (as published on the official TensorFlow Models GitHub: https://github.com/tensorflow/models/tree/master/official/nlp/bert) to PyTorch.

* rename file in order to stay consistent with naming convention
parent 82a0e2b6
"""
This script can be used to convert a head-less TF2.x Bert model to PyTorch,
as published on the official GitHub: https://github.com/tensorflow/models/tree/master/official/nlp/bert
TF2.x uses different variable names from the original BERT (TF 1.4) implementation.
The script re-maps the TF2.x Bert weight names to the original names, so the model can be imported with Hugging Face Transformers.
You may adapt this script to include classification/MLM/NSP/etc. heads.
"""
import argparse
import logging
import os
import re
import tensorflow as tf
import torch
from transformers import BertConfig, BertModel
# Configure the root logger at INFO level so the progress messages logged below are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
    """Load the weights of a head-less TF2.x BERT checkpoint into ``model``.

    Every variable of the checkpoint is read, its TF2.x name is walked
    component by component to locate the matching parameter inside ``model``,
    and the parameter's ``data`` is replaced with the checkpoint array.

    The top-level ``layer_with_weights-<N>`` component selects the sub-module:

    * ``0``: word embeddings, ``1``: position embeddings,
      ``2``: token type embeddings
    * ``3``: embedding LayerNorm
    * ``4 .. config.num_hidden_layers + 3``: encoder layers
    * ``config.num_hidden_layers + 4``: pooler dense layer

    Args:
        model: The PyTorch BERT model whose parameters are overwritten in place.
        tf_checkpoint_path: Path to the TF2.x checkpoint.
        config: Model configuration; only ``num_hidden_layers`` is read here.

    Returns:
        The same ``model`` object, with its weights replaced.

    Raises:
        ValueError: If the variable names do not all share a nesting depth of
            exactly one ``layer_with_weights`` level (e.g. the checkpoint
            contains MLM/NSP heads), if an embedding variable has an
            unexpected layer number, or if a checkpoint array does not match
            the shape of the target parameter.
    """
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    layer_depth = []
    for full_name, _shape in init_vars:
        name = full_name.split("/")
        if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
            logger.info(f"Skipping non-model layer {full_name}")
            continue
        if "optimizer" in full_name:
            logger.info(f"Skipping optimization layer {full_name}")
            continue
        if name[0] == "model":
            # ignore initial 'model'
            name = name[1:]
        # figure out how many levels deep the name is: count the leading
        # components that start with "layer_with_weights"
        depth = 0
        for _name in name:
            if _name.startswith("layer_with_weights"):
                depth += 1
            else:
                break
        layer_depth.append(depth)
        # read data
        array = tf.train.load_variable(tf_path, full_name)
        names.append("/".join(name))
        arrays.append(array)
    logger.info(f"Read a total of {len(arrays):,} layers")

    # Sanity check: all variables must sit at the same depth, and that depth must be 1
    if len(set(layer_depth)) != 1:
        raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
    layer_depth = list(set(layer_depth))[0]
    if layer_depth != 1:
        raise ValueError(
            "The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP heads."
        )

    # convert layers
    logger.info("Converting weights...")
    for full_name, array in zip(names, arrays):
        name = full_name.split("/")
        # `pointer` walks down the PyTorch module tree; `trace` records the
        # dotted PyTorch path for the reshape check and for logging
        pointer = model
        trace = []
        for m_name in name:
            if m_name == ".ATTRIBUTES":
                # variable names end with .ATTRIBUTES/VARIABLE_VALUE
                break
            if m_name.startswith("layer_with_weights"):
                layer_num = int(m_name.split("-")[-1])
                if layer_num <= 2:
                    # embedding layers
                    # layer_num 0: word_embeddings
                    # layer_num 1: position_embeddings
                    # layer_num 2: token_type_embeddings
                    # (resolved by the "embeddings" component that follows)
                    continue
                elif layer_num == 3:
                    # embedding LayerNorm
                    trace.extend(["embeddings", "LayerNorm"])
                    pointer = getattr(pointer, "embeddings")
                    pointer = getattr(pointer, "LayerNorm")
                elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
                    # encoder layers
                    trace.extend(["encoder", "layer", str(layer_num - 4)])
                    pointer = getattr(pointer, "encoder")
                    pointer = getattr(pointer, "layer")
                    pointer = pointer[layer_num - 4]
                elif layer_num == config.num_hidden_layers + 4:
                    # pooler layer
                    trace.extend(["pooler", "dense"])
                    pointer = getattr(pointer, "pooler")
                    pointer = getattr(pointer, "dense")
            elif m_name == "embeddings":
                # `layer_num` was set by the preceding "layer_with_weights-N"
                # component of this same variable name
                trace.append("embeddings")
                pointer = getattr(pointer, "embeddings")
                if layer_num == 0:
                    trace.append("word_embeddings")
                    pointer = getattr(pointer, "word_embeddings")
                elif layer_num == 1:
                    trace.append("position_embeddings")
                    pointer = getattr(pointer, "position_embeddings")
                elif layer_num == 2:
                    trace.append("token_type_embeddings")
                    pointer = getattr(pointer, "token_type_embeddings")
                else:
                    raise ValueError(f"Unknown embedding layer with name {full_name}")
                trace.append("weight")
                pointer = getattr(pointer, "weight")
            elif m_name == "_attention_layer":
                # self-attention layer
                trace.extend(["attention", "self"])
                pointer = getattr(pointer, "attention")
                pointer = getattr(pointer, "self")
            elif m_name == "_attention_layer_norm":
                # output attention norm
                trace.extend(["attention", "output", "LayerNorm"])
                pointer = getattr(pointer, "attention")
                pointer = getattr(pointer, "output")
                pointer = getattr(pointer, "LayerNorm")
            elif m_name == "_attention_output_dense":
                # output attention dense
                trace.extend(["attention", "output", "dense"])
                pointer = getattr(pointer, "attention")
                pointer = getattr(pointer, "output")
                pointer = getattr(pointer, "dense")
            elif m_name == "_output_dense":
                # output dense
                trace.extend(["output", "dense"])
                pointer = getattr(pointer, "output")
                pointer = getattr(pointer, "dense")
            elif m_name == "_output_layer_norm":
                # output layer norm
                trace.extend(["output", "LayerNorm"])
                pointer = getattr(pointer, "output")
                pointer = getattr(pointer, "LayerNorm")
            elif m_name == "_key_dense":
                # attention key
                trace.append("key")
                pointer = getattr(pointer, "key")
            elif m_name == "_query_dense":
                # attention query
                trace.append("query")
                pointer = getattr(pointer, "query")
            elif m_name == "_value_dense":
                # attention value
                trace.append("value")
                pointer = getattr(pointer, "value")
            elif m_name == "_intermediate_dense":
                # intermediate dense
                trace.extend(["intermediate", "dense"])
                pointer = getattr(pointer, "intermediate")
                pointer = getattr(pointer, "dense")
            # weights & biases
            elif m_name in ["bias", "beta"]:
                trace.append("bias")
                pointer = getattr(pointer, "bias")
            elif m_name in ["kernel", "gamma"]:
                trace.append("weight")
                pointer = getattr(pointer, "weight")
            else:
                logger.warning(f"Ignored {m_name}")

        # for certain layers reshape is necessary: the attention query/key/value
        # weights and biases and the attention output dense weight are reshaped
        # to the shape of the PyTorch parameter
        # NOTE(review): presumably TF2.x stores these with separate head
        # dimensions - confirm against the checkpoint
        trace = ".".join(trace)
        if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
            r"(\S+)\.attention\.output\.dense\.weight", trace
        ):
            array = array.reshape(pointer.data.shape)
        # kernels are transposed before being copied into the PyTorch weight
        if "kernel" in full_name:
            array = array.transpose()
        if pointer.shape == array.shape:
            pointer.data = torch.from_numpy(array)
        else:
            raise ValueError(
                f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape: {array.shape}"
            )
        logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
    return model
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
    """Convert a TF2.x BERT checkpoint and save it as a PyTorch state dict.

    Args:
        tf_checkpoint_path: Path to the TF2.x checkpoint to read.
        config_path: Path to the JSON file describing the BERT configuration.
        pytorch_dump_path: Destination file for the PyTorch state dict.
    """
    # Step 1: instantiate the model described by the JSON config.
    logger.info(f"Loading model based on config from {config_path}...")
    bert_config = BertConfig.from_json_file(config_path)
    bert_model = BertModel(bert_config)

    # Step 2: overwrite the model's parameters with the checkpoint weights
    # (the loader modifies the model in place).
    logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
    load_tf2_weights_in_bert(bert_model, tf_checkpoint_path, bert_config)

    # Step 3: write only the state dict to disk.
    logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
    weights = bert_model.state_dict()
    torch.save(weights, pytorch_dump_path)
if __name__ == "__main__":
    # Command-line entry point. All three options are required string paths,
    # so they are registered from a single (flag, help text) table.
    cli = argparse.ArgumentParser()
    required_paths = (
        ("--tf_checkpoint_path", "Path to the TensorFlow 2.x checkpoint path."),
        (
            "--bert_config_file",
            "The config json file corresponding to the BERT model. This specifies the model architecture.",
        ),
        ("--pytorch_dump_path", "Path to the output PyTorch model (must include filename)."),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, required=True, help=description)

    options = cli.parse_args()
    convert_tf2_checkpoint_to_pytorch(
        options.tf_checkpoint_path,
        options.bert_config_file,
        options.pytorch_dump_path,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment