Commit dc4e7c4f authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 274460885
parent 9a833e2c
...@@ -86,6 +86,10 @@ BERT_V2_NAME_REPLACEMENTS = [ ...@@ -86,6 +86,10 @@ BERT_V2_NAME_REPLACEMENTS = [
("output/dense", "output"), ("output/dense", "output"),
("output/LayerNorm", "output_layer_norm"), ("output/LayerNorm", "output_layer_norm"),
("pooler/dense", "pooler_transform"), ("pooler/dense", "pooler_transform"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
("cls/seq_relationship/output_weights",
"predictions/transform/logits/kernel"),
] ]
...@@ -111,6 +115,17 @@ def _has_exclude_patterns(name, exclude_patterns): ...@@ -111,6 +115,17 @@ def _has_exclude_patterns(name, exclude_patterns):
return False return False
def _get_permutation(name):
"""Checks whether a variable requires transposition by pattern matching."""
if not FLAGS.use_v2_names:
return None
if "cls/seq_relationship/output_weights" in name:
return (1, 0)
return None
def _get_new_shape(name, shape, num_heads): def _get_new_shape(name, shape, num_heads):
"""Checks whether a variable requires reshape by pattern matching.""" """Checks whether a variable requires reshape by pattern matching."""
if "attention/output/dense/kernel" in name: if "attention/output/dense/kernel" in name:
...@@ -154,8 +169,13 @@ def convert_names(checkpoint_from_path, ...@@ -154,8 +169,13 @@ def convert_names(checkpoint_from_path,
for var_name in name_shape_map: for var_name in name_shape_map:
if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns): if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
continue continue
new_var_name = _bert_name_replacement(var_name) # Get the original tensor data.
tensor = reader.get_tensor(var_name) tensor = reader.get_tensor(var_name)
# Look up the new variable name, if any.
new_var_name = _bert_name_replacement(var_name)
# See if we need to reshape the underlying tensor.
new_shape = None new_shape = None
if FLAGS.num_heads > 0: if FLAGS.num_heads > 0:
new_shape = _get_new_shape(var_name, tensor.shape, FLAGS.num_heads) new_shape = _get_new_shape(var_name, tensor.shape, FLAGS.num_heads)
...@@ -164,8 +184,19 @@ def convert_names(checkpoint_from_path, ...@@ -164,8 +184,19 @@ def convert_names(checkpoint_from_path,
var_name, tensor.shape, new_shape) var_name, tensor.shape, new_shape)
tensor = np.reshape(tensor, new_shape) tensor = np.reshape(tensor, new_shape)
# See if we need to permute the underlying tensor.
permutation = _get_permutation(var_name)
if permutation:
tensor = np.transpose(tensor, permutation)
# Create a new variable with the possibly-reshaped or transposed tensor.
var = tf.Variable(tensor, name=var_name) var = tf.Variable(tensor, name=var_name)
# Save the variable into the new variable map.
new_variable_map[new_var_name] = var new_variable_map[new_var_name] = var
# Keep a list of converter variables for sanity checking.
if new_var_name != var_name: if new_var_name != var_name:
conversion_map[var_name] = new_var_name conversion_map[var_name] = new_var_name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment