Unverified Commit 3bc15400 authored by Julien Plu, committed by GitHub

New TF loading weights (#8490)

* New TF loading weights

* apply style

* Better naming

* Largely comment the loading method

* Apply style

* Address Patrick's comments

* Remove useless line of code

* Update Docstring

* Address Sylvain's and Lysandre's comments

* Simplify the names computation

* Typos
parent 0df91ee4
@@ -236,9 +236,9 @@ class TFNextSentencePredictionLoss:
return loss_fn(next_sentence_label, next_sentence_reduced_logits)
def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
def load_tf_weights(model, resolved_archive_file):
"""
Detect missing and unexpected layers.
Detect missing and unexpected layers and load the TF weights according to their names and shapes.
Args:
model (:obj:`tf.keras.models.Model`):
@@ -252,62 +252,60 @@ def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
missing_layers = []
unexpected_layers = []
# Read the H5 file
with h5py.File(resolved_archive_file, "r") as f:
saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
model_layer_names = set(layer.name for layer in model.layers)
missing_layers = list(model_layer_names - saved_layer_names)
unexpected_layers = list(saved_layer_names - model_layer_names)
for layer in model.layers:
if layer.name in saved_layer_names:
g = f[layer.name]
saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
saved_weight_names_set = set(
"/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
)
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
symbolic_weights_names = set(
"/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
)
missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
# Retrieve the name of each layer from the H5 file
saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
return missing_layers, unexpected_layers
def load_tf_weights(model, resolved_archive_file):
"""
Load the TF weights from a H5 file.
# Find the missing layers from the high level list of layers
missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)
Args:
model (:obj:`tf.keras.models.Model`):
The model to load the weights into.
resolved_archive_file (:obj:`str`):
The location of the H5 file.
"""
with h5py.File(resolved_archive_file, "r") as f:
saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
# Find the unexpected layers from the high level list of layers
unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers]))
saved_weight_names_set = set()
symbolic_weights_names = set()
weight_value_tuples = []
# Compute missing and unexpected sub-layers
# Store the weights in a list of tuples that looks like [(weight_object, value_of_weight), ...]
for layer in model.layers:
if layer.name in saved_layer_names:
g = f[layer.name]
saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
# If the layer from the instantiated model is also present in the H5 file
if layer.name in saved_h5_model_layers_name:
# Get the H5 layer object from its name
h5_layer_object = f[layer.name]
# Get all the weights as a list from the layer object
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
saved_weight_names_values = {}
saved_weights = {}
for weight_name in saved_weight_names:
# Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
# And a set with only the names
for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
# TF names always start with the model name so we ignore it
name = "/".join(weight_name.split("/")[1:])
saved_weight_names_values[name] = np.asarray(g[weight_name])
saved_weights[name] = np.asarray(h5_layer_object[weight_name])
# Add the updated name to the final list for computing missing/unexpected values
saved_weight_names_set.add(name)
# Loop over each weight of the instantiated model and compare it with the weights from the H5 file
for symbolic_weight in symbolic_weights:
splited_layers = symbolic_weight.name.split("/")[1:]
symbolic_weight_name = "/".join(splited_layers)
# TF names always start with the model name so we ignore it
symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
# Here we check whether the current weight is among the weights from the H5 file
# If yes, get the value of the corresponding weight from the H5 file
# If not, set the value to None
saved_weight_value = saved_weights.get(symbolic_weight_name, None)
if symbolic_weight_name in saved_weight_names_values:
saved_weight_value = saved_weight_names_values[symbolic_weight_name]
# Add the updated name to the final list for computing missing/unexpected values
symbolic_weights_names.add(symbolic_weight_name)
# If the current weight is found
if saved_weight_value is not None:
# Check if the shape of the current weight and the one from the H5 file are different
if K.int_shape(symbolic_weight) != saved_weight_value.shape:
# If yes, we reshape the weight from the H5 file to match the current weight
# If the two shapes are not compatible, we raise an error
try:
array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
except AssertionError as e:
@@ -316,10 +314,18 @@ def load_tf_weights(model, resolved_archive_file):
else:
array = saved_weight_value
# We create the tuple that will be loaded and add it to the final list
weight_value_tuples.append((symbolic_weight, array))
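As a side note on the reshape above: np.reshape only succeeds when the saved array holds the same total number of elements as the target shape, otherwise NumPy raises an error. A tiny self-contained illustration with made-up shapes (not taken from this commit):

import numpy as np

saved_value = np.zeros((1, 512, 768))     # shape as stored in the H5 file (hypothetical)
target_shape = (512, 768)                 # shape expected by the symbolic weight (hypothetical)

array = np.reshape(saved_value, target_shape)   # works: both hold 512 * 768 elements
print(array.shape)                              # (512, 768)
# np.reshape(saved_value, (768, 1024)) would fail because the element counts differ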
# Load all the weights
K.batch_set_value(weight_value_tuples)
# Compute the missing and unexpected layers
missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
return missing_layers, unexpected_layers
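The missing/unexpected bookkeeping in load_tf_weights boils down to stripping the leading model-name component from every weight name and taking two set differences. A minimal sketch of that idea, with hypothetical layer and weight names invented for illustration (not taken from this commit):

def strip_model_name(name):
    # TF weight names start with the model name, e.g. "tf_bert_model/...", so drop the first component
    return "/".join(name.split("/")[1:])

saved_weight_names = {
    "tf_bert_model/bert/embeddings/word_embeddings/weight:0",
    "tf_bert_model/bert/pooler/dense/kernel:0",
}
symbolic_weight_names = {
    "tf_bert_model/bert/embeddings/word_embeddings/weight:0",
    "tf_bert_model/classifier/kernel:0",
}

saved = {strip_model_name(n) for n in saved_weight_names}
symbolic = {strip_model_name(n) for n in symbolic_weight_names}

print(sorted(symbolic - saved))    # missing: ['classifier/kernel:0']
print(sorted(saved - symbolic))    # unexpected: ['bert/pooler/dense/kernel:0']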
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
r"""
@@ -727,7 +733,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
# 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
try:
load_tf_weights(model, resolved_archive_file)
missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file)
except OSError:
raise OSError(
"Unable to load weights from h5 file. "
@@ -736,8 +742,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
model(model.dummy_inputs, training=False) # Make sure restore ops are run
missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)
if cls.authorized_missing_keys is not None:
for pat in cls.authorized_missing_keys:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
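The authorized_missing_keys handling above simply drops every missing key that matches one of the configured regular expressions, so heads that some checkpoints legitimately do not ship stop triggering warnings. A small sketch with invented patterns and key names (hypothetical, for illustration only):

import re

authorized_missing_keys = [r"pooler"]                                   # hypothetical pattern
missing_keys = ["bert/pooler/dense/kernel:0", "classifier/kernel:0"]    # hypothetical keys

for pat in authorized_missing_keys:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

print(missing_keys)   # ['classifier/kernel:0']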
@@ -1033,18 +1037,18 @@ class TFSequenceSummary(tf.keras.layers.Layer):
return output
def shape_list(x: tf.Tensor) -> List[int]:
def shape_list(tensor: tf.Tensor) -> List[int]:
"""
Deal with dynamic shapes in TensorFlow cleanly.
Args:
x (:obj:`tf.Tensor`): The tensor we want the shape of.
tensor (:obj:`tf.Tensor`): The tensor we want the shape of.
Returns:
:obj:`List[int]`: The shape of the tensor as a list.
"""
static = x.shape.as_list()
dynamic = tf.shape(x)
static = tensor.shape.as_list()
dynamic = tf.shape(tensor)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
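For context, a short usage sketch of shape_list, assuming TensorFlow is installed; the function is repeated so the snippet is self-contained. Inside a tf.function traced with an unknown batch dimension, the static shape contains None and shape_list substitutes the dynamic value:

from typing import List

import tensorflow as tf

def shape_list(tensor: tf.Tensor) -> List[int]:
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 8], dtype=tf.float32)])
def flatten(x):
    batch_size, hidden = shape_list(x)   # batch_size is a dynamic tensor here, hidden is the int 8
    return tf.reshape(x, (batch_size * hidden,))

print(flatten(tf.zeros((4, 8))).shape)   # (32,)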