Commit f56c495c authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Support variable reshape to make TF1 checkpoint compatible with a Bert without...

Support variable reshape to make TF1 checkpoint compatible with a Bert without reshape in einsum layers.

PiperOrigin-RevId: 271613961
parent 2708db70
...@@ -29,6 +29,7 @@ from __future__ import division ...@@ -29,6 +29,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import app from absl import app
import numpy as np
import tensorflow as tf # TF 1.x import tensorflow as tf # TF 1.x
flags = tf.flags flags = tf.flags
...@@ -44,6 +45,11 @@ flags.DEFINE_string( ...@@ -44,6 +45,11 @@ flags.DEFINE_string(
"exclude_patterns", None, "exclude_patterns", None,
"Comma-delimited string of a list of patterns to exclude" "Comma-delimited string of a list of patterns to exclude"
" variables from source checkpoint.") " variables from source checkpoint.")
# Number of attention heads in the target model. Reshape of attention
# variables (see _get_new_shape) is only attempted when this is positive;
# the default -1 disables reshaping and preserves the old behavior.
flags.DEFINE_integer(
"num_heads", -1,
"The number of attention heads, used to reshape variables. If it is -1, "
"we do not reshape variables."
)
# Mapping between old <=> new names. The source pattern in original variable # Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern. # name will be replaced by destination pattern.
...@@ -82,6 +88,25 @@ def _has_exclude_patterns(name, exclude_patterns): ...@@ -82,6 +88,25 @@ def _has_exclude_patterns(name, exclude_patterns):
return False return False
def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching."""
if "attention/output/dense/kernel" in name:
return tuple([num_heads, shape[0] // num_heads, shape[1]])
if "attention/output/dense/bias" in name:
return shape
patterns = [
"attention/self/query", "attention/self/value", "attention/self/key"
]
for pattern in patterns:
if pattern in name:
if "kernel" in name:
return tuple([shape[0], num_heads, shape[1] // num_heads])
if "bias" in name:
return tuple([num_heads, shape[0] // num_heads])
return None
def convert_names(checkpoint_from_path, def convert_names(checkpoint_from_path,
checkpoint_to_path, checkpoint_to_path,
exclude_patterns=None): exclude_patterns=None):
...@@ -108,6 +133,14 @@ def convert_names(checkpoint_from_path, ...@@ -108,6 +133,14 @@ def convert_names(checkpoint_from_path,
continue continue
new_var_name = _bert_name_replacement(var_name) new_var_name = _bert_name_replacement(var_name)
tensor = reader.get_tensor(var_name) tensor = reader.get_tensor(var_name)
new_shape = None
if FLAGS.num_heads > 0:
new_shape = _get_new_shape(var_name, tensor.shape, FLAGS.num_heads)
if new_shape:
tf.logging.info("Variable %s has a shape change from %s to %s",
var_name, tensor.shape, new_shape)
tensor = np.reshape(tensor, new_shape)
var = tf.Variable(tensor, name=var_name) var = tf.Variable(tensor, name=var_name)
new_variable_map[new_var_name] = var new_variable_map[new_var_name] = var
if new_var_name != var_name: if new_var_name != var_name:
......
...@@ -73,9 +73,6 @@ def create_bert_model(bert_config): ...@@ -73,9 +73,6 @@ def create_bert_model(bert_config):
def convert_checkpoint(): def convert_checkpoint():
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint.""" """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
# Sets backward_compatible to true to convert TF1 BERT checkpoints.
bert_config.backward_compatible = True
core_model = create_bert_model(bert_config) core_model = create_bert_model(bert_config)
# Uses streaming-restore in eager model to read V1 name-based checkpoints. # Uses streaming-restore in eager model to read V1 name-based checkpoints.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment