Unverified commit 9cf7b23b, authored by Julien Plu, committed by GitHub

Custom TF weights loading (#7422)



* First try

* Fix TF utils

* Handle authorized unexpected keys when loading weights

* Add several more authorized unexpected keys

* Apply style

* Fix test

* Address Patrick's comments.

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply style

* Make return_dict the default behavior and display a warning message

* Revert

* Replace wrong keyword

* Revert code

* Add forgotten key

* Fix bug when loading PT models from a TF checkpoint.

* Fix sort

* Add a test for custom load weights in BERT

* Apply style

* Remove unused import
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d3adb985
src/transformers/modeling_tf_bert.py

@@ -854,6 +854,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
@@ -939,6 +940,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
 class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
@@ -1286,6 +1288,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
 )
 class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
@@ -1369,6 +1372,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
 )
 class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
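With both patterns declared, loading a full pretrained checkpoint into one of these pooler-less heads stays silent about the pooler. A minimal sketch (the hub checkpoint name is illustrative):

```python
from transformers import TFBertForMaskedLM

# bert-base-uncased ships with a pooler, but TFBertForMaskedLM has none;
# authorized_unexpected_keys = [r"pooler"] filters it out of the
# "unexpected layers" warning instead of alarming the user.
model = TFBertForMaskedLM.from_pretrained("bert-base-uncased")
```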
src/transformers/modeling_tf_pytorch_utils.py
@@ -177,6 +177,13 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
         elif len(symbolic_weight.shape) > len(array.shape):
             array = numpy.expand_dims(array, axis=0)
 
+        if list(symbolic_weight.shape) != list(array.shape):
+            try:
+                array = numpy.reshape(array, symbolic_weight.shape)
+            except ValueError as e:
+                e.args += (symbolic_weight.shape, array.shape)
+                raise e
+
         try:
             assert list(symbolic_weight.shape) == list(array.shape)
         except AssertionError as e:
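The new fallback lets a tensor whose rank differs but whose element count matches cross the PT/TF boundary. A numpy-only sketch of the same logic:

```python
import numpy

# A PyTorch weight saved as (768,) loading into a TF variable of shape
# (1, 768): the assertion used to fail here; now the array is reshaped.
symbolic_shape = (1, 768)
array = numpy.zeros((768,))

if list(symbolic_shape) != list(array.shape):
    array = numpy.reshape(array, symbolic_shape)

assert list(symbolic_shape) == list(array.shape)
```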
@@ -251,6 +258,8 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
     import transformers
 
+    from .modeling_tf_utils import load_tf_weights
+
     logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path))
 
     # Instantiate and load the associated TF 2.0 model
@@ -264,7 +273,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
     if tf_inputs is not None:
         tf_model(tf_inputs, training=False)  # Make sure model is built
 
-    tf_model.load_weights(tf_checkpoint_path, by_name=True)
+    load_tf_weights(tf_model, tf_checkpoint_path)
 
     return load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys)
@@ -332,6 +341,13 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False):
         elif len(pt_weight.shape) > len(array.shape):
             array = numpy.expand_dims(array, axis=0)
 
+        if list(pt_weight.shape) != list(array.shape):
+            try:
+                array = numpy.reshape(array, pt_weight.shape)
+            except ValueError as e:
+                e.args += (pt_weight.shape, array.shape)
+                raise e
+
         try:
             assert list(pt_weight.shape) == list(array.shape)
         except AssertionError as e:
src/transformers/modeling_tf_utils.py
@@ -23,12 +23,12 @@ from typing import Dict, List, Optional, Union
 
 import h5py
 import numpy as np
 import tensorflow as tf
+from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.saving import hdf5_format
 
 from .configuration_utils import PretrainedConfig
 from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
 from .generation_tf_utils import TFGenerationMixin
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 from .utils import logging
@@ -216,6 +216,91 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
     """
 
 
+def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
+    """
+    Detect missing and unexpected layers.
+
+    Args:
+        model (:obj:`tf.keras.models.Model`):
+            The model to load the weights into.
+        resolved_archive_file (:obj:`str`):
+            The location of the H5 file.
+
+    Returns:
+        Two lists, one for the missing layers, and another one for the unexpected layers.
+    """
+    missing_layers = []
+    unexpected_layers = []
+
+    with h5py.File(resolved_archive_file, "r") as f:
+        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
+        model_layer_names = set(layer.name for layer in model.layers)
+        missing_layers = list(model_layer_names - saved_layer_names)
+        unexpected_layers = list(saved_layer_names - model_layer_names)
+
+        for layer in model.layers:
+            if layer.name in saved_layer_names:
+                g = f[layer.name]
+                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
+                saved_weight_names_set = set(
+                    "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
+                )
+                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
+                symbolic_weights_names = set(
+                    "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
+                )
+                missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
+                unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
+
+    return missing_layers, unexpected_layers
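A hedged usage sketch of the new helper (`model` and the checkpoint path are placeholders):

```python
# model: any built tf.keras.Model; "tf_model.h5": weights saved in Keras H5
# format. Names present only on the model side come back as "missing",
# names present only in the file as "unexpected".
missing, unexpected = detect_tf_missing_unexpected_layers(model, "tf_model.h5")
print("missing:", missing)
print("unexpected:", unexpected)
```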
+
+
+def load_tf_weights(model, resolved_archive_file):
+    """
+    Load the TF weights from an H5 file.
+
+    Args:
+        model (:obj:`tf.keras.models.Model`):
+            The model to load the weights into.
+        resolved_archive_file (:obj:`str`):
+            The location of the H5 file.
+    """
+    with h5py.File(resolved_archive_file, "r") as f:
+        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
+        weight_value_tuples = []
+
+        for layer in model.layers:
+            if layer.name in saved_layer_names:
+                g = f[layer.name]
+                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
+                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
+                saved_weight_names_values = {}
+
+                for weight_name in saved_weight_names:
+                    # Strip the model-level prefix so names can be matched per layer.
+                    name = "/".join(weight_name.split("/")[1:])
+                    saved_weight_names_values[name] = np.asarray(g[weight_name])
+
+                for symbolic_weight in symbolic_weights:
+                    split_layers = symbolic_weight.name.split("/")[1:]
+                    symbolic_weight_name = "/".join(split_layers)
+
+                    if symbolic_weight_name in saved_weight_names_values:
+                        saved_weight_value = saved_weight_names_values[symbolic_weight_name]
+
+                        # Reshape on mismatch; numpy raises ValueError if the
+                        # element counts are incompatible.
+                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
+                            try:
+                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
+                            except ValueError as e:
+                                e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape)
+                                raise e
+                        else:
+                            array = saved_weight_value
+
+                        weight_value_tuples.append((symbolic_weight, array))
+
+    K.batch_set_value(weight_value_tuples)
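Unlike `tf.keras.Model.load_weights(..., by_name=True)`, which this commit replaces, the helper matches individual weights inside each layer and applies them in a single `K.batch_set_value` call. A usage sketch (`model` and the path are placeholders):

```python
# The variables must exist before they can be assigned, so build the model
# with its dummy inputs first, as from_pretrained does further down.
model(model.dummy_inputs, training=False)
load_tf_weights(model, "tf_model.h5")
```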
 
 
 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     r"""
     Base class for all TF models.
@@ -231,10 +316,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
     - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
       derived classes of the same architecture adding modules on top of the base model.
+    - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of :obj:`re` patterns of tensor names
+      to ignore from the model when loading the model weights (and avoid unnecessary warnings).
+    - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of :obj:`re` patterns of tensor names
+      to ignore from the weights when loading the model weights (and avoid unnecessary warnings).
     """
     config_class = None
     base_model_prefix = ""
     authorized_missing_keys = None
+    authorized_unexpected_keys = None
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
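Both attributes hold regular-expression patterns, applied with `re.search` against each reported name. A standalone sketch of the filtering that `from_pretrained` performs further down:

```python
import re

authorized_unexpected_keys = [r"pooler"]
unexpected_keys = ["bert/pooler/dense", "mlm___cls/predictions"]

# Any name matching an authorized pattern is dropped before the warning.
for pat in authorized_unexpected_keys:
    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

assert unexpected_keys == ["mlm___cls/predictions"]
```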
@@ -604,6 +694,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             model = cls(config, *model_args, **model_kwargs)
 
         if from_pt:
+            from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
+
             # Load from a PyTorch checkpoint
             return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)
@@ -613,7 +705,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             # 'by_name' allows us to do transfer learning by skipping/adding layers
             # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
             try:
-                model.load_weights(resolved_archive_file, by_name=True)
+                load_tf_weights(model, resolved_archive_file)
             except OSError:
                 raise OSError(
                     "Unable to load weights from h5 file. "
@@ -622,23 +714,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             model(model.dummy_inputs, training=False)  # Make sure restore ops are run
 
         # Check if the models are the same to output loading information
-        with h5py.File(resolved_archive_file, "r") as f:
-            if "layer_names" not in f.attrs and "model_weights" in f:
-                f = f["model_weights"]
-            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
-        model_layer_names = set(layer.name for layer in model.layers)
-        missing_keys = list(model_layer_names - hdf5_layer_names)
-        unexpected_keys = list(hdf5_layer_names - model_layer_names)
-        error_msgs = []
+        missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)
 
         if cls.authorized_missing_keys is not None:
             for pat in cls.authorized_missing_keys:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
 
+        if cls.authorized_unexpected_keys is not None:
+            for pat in cls.authorized_unexpected_keys:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
         if len(unexpected_keys) > 0:
             logger.warning(
-                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when "
                 f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                 f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                 f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
@@ -646,25 +734,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
             )
         else:
-            logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+            logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")
 
         if len(missing_keys) > 0:
             logger.warning(
-                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                 f"and are newly initialized: {missing_keys}\n"
                 f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
             )
         else:
             logger.warning(
-                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+                f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                 f"If your task is similar to the task the model of the checkpoint was trained on, "
                 f"you can already use {model.__class__.__name__} for predictions without further training."
            )
 
-        if len(error_msgs) > 0:
-            raise RuntimeError(
-                "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
-            )
         if output_loading_info:
-            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
+            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
 
             return model, loading_info
 
         return model
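Since nothing on the TF side ever populated `error_msgs`, the dead `RuntimeError` path is dropped and the loading report shrinks to two lists. A sketch of consuming it (the checkpoint name is illustrative):

```python
from transformers import TFBertForMaskedLM

model, info = TFBertForMaskedLM.from_pretrained(
    "bert-base-uncased", output_loading_info=True
)
# "error_msgs" is gone; only the filtered layer-name lists remain.
print(info["missing_keys"], info["unexpected_keys"])
```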
tests/test_modeling_tf_bert.py
@@ -17,7 +17,7 @@
 import unittest
 
 from transformers import BertConfig, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import require_tf
 
 from .test_configuration_common import ConfigTester
 from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@@ -317,9 +317,14 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
 
-    @slow
     def test_model_from_pretrained(self):
-        # for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-        for model_name in ["bert-base-uncased"]:
-            model = TFBertModel.from_pretrained(model_name)
-            self.assertIsNotNone(model)
+        model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random")
+        self.assertIsNotNone(model)
+
+    def test_custom_load_tf_weights(self):
+        model, output_loading_info = TFBertForTokenClassification.from_pretrained(
+            "jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
+        )
+
+        self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
+
+        for layer in output_loading_info["missing_keys"]:
+            self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])