Commit f56c495c authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Support variable reshape to make TF1 checkpoint compatible with a Bert without...

Support variable reshape to make TF1 checkpoint compatible with a Bert without reshape in einsum layers.

PiperOrigin-RevId: 271613961
parent 2708db70
......@@ -29,6 +29,7 @@ from __future__ import division
from __future__ import print_function
from absl import app
import numpy as np
import tensorflow as tf # TF 1.x
flags = tf.flags
......@@ -44,6 +45,11 @@ flags.DEFINE_string(
"exclude_patterns", None,
"Comma-delimited string of a list of patterns to exclude"
" variables from source checkpoint.")
# Number of attention heads used when reshaping attention variables during
# conversion. Reshaping only happens when this flag is positive; the default
# of -1 leaves every variable's shape untouched.
flags.DEFINE_integer(
"num_heads", -1,
"The number of attention heads, used to reshape variables. If it is -1, "
"we do not reshape variables."
)
# Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern.
......@@ -82,6 +88,25 @@ def _has_exclude_patterns(name, exclude_patterns):
return False
def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching."""
if "attention/output/dense/kernel" in name:
return tuple([num_heads, shape[0] // num_heads, shape[1]])
if "attention/output/dense/bias" in name:
return shape
patterns = [
"attention/self/query", "attention/self/value", "attention/self/key"
]
for pattern in patterns:
if pattern in name:
if "kernel" in name:
return tuple([shape[0], num_heads, shape[1] // num_heads])
if "bias" in name:
return tuple([num_heads, shape[0] // num_heads])
return None
def convert_names(checkpoint_from_path,
checkpoint_to_path,
exclude_patterns=None):
......@@ -108,6 +133,14 @@ def convert_names(checkpoint_from_path,
continue
new_var_name = _bert_name_replacement(var_name)
tensor = reader.get_tensor(var_name)
new_shape = None
if FLAGS.num_heads > 0:
new_shape = _get_new_shape(var_name, tensor.shape, FLAGS.num_heads)
if new_shape:
tf.logging.info("Veriable %s has a shape change from %s to %s",
var_name, tensor.shape, new_shape)
tensor = np.reshape(tensor, new_shape)
var = tf.Variable(tensor, name=var_name)
new_variable_map[new_var_name] = var
if new_var_name != var_name:
......
......@@ -73,9 +73,6 @@ def create_bert_model(bert_config):
def convert_checkpoint():
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
# Sets backward_compatible to true to convert TF1 BERT checkpoints.
bert_config.backward_compatible = True
core_model = create_bert_model(bert_config)
# Uses streaming-restore in eager model to read V1 name-based checkpoints.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment