Unverified Commit 9cf7b23b authored by Julien Plu, committed by GitHub

Custom TF weights loading (#7422)



* First try

* Fix TF utils

* Handle authorized unexpected keys when loading weights

* Add several more authorized unexpected keys

* Apply style

* Fix test

* Address Patrick's comments.

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply style

* Make return_dict the default behavior and display a warning message

* Revert

* Replace wrong keyword

* Revert code

* Add forgotten key

* Fix a bug in loading PT models from a TF one.

* Fix sort

* Add a test for custom load weights in BERT

* Apply style

* Remove unused import
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d3adb985
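In practice, this commit makes TF weight loading report layer-level loading info and filter it through per-class authorized key patterns. A minimal sketch of the resulting behavior (the checkpoint name and reported keys are illustrative, mirroring the test added at the bottom of this diff):

```python
from transformers import TFBertForTokenClassification

# Illustrative sketch: any TF BERT checkpoint that still carries its
# pretraining heads behaves this way after this commit.
model, loading_info = TFBertForTokenClassification.from_pretrained(
    "bert-base-uncased", output_loading_info=True
)

# The pretraining-only heads are reported as unexpected layers, while the
# "pooler" entries are silenced by `authorized_unexpected_keys = [r"pooler"]`.
print(loading_info["unexpected_keys"])  # e.g. ['mlm___cls', 'nsp___cls']
print(loading_info["missing_keys"])     # e.g. ['dropout_37', 'classifier']
```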
src/transformers/modeling_tf_bert.py

@@ -854,6 +854,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
 
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
@@ -939,6 +940,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
 class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
@@ -1286,6 +1288,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
 )
 class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
@@ -1369,6 +1372,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
 )
 class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
+    authorized_unexpected_keys = [r"pooler"]
     authorized_missing_keys = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
...
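The new `authorized_unexpected_keys` class attribute mirrors the existing `authorized_missing_keys`: each entry is a regular-expression pattern, and any reported key matching one of the patterns is dropped before the warning is emitted. A minimal sketch of that filtering step (the key names are made up for illustration):

```python
import re

# Hypothetical loading report, mimicking the filtering done in from_pretrained.
unexpected_keys = ["bert/pooler/dense/kernel:0", "mlm___cls"]
authorized_unexpected_keys = [r"pooler"]

for pat in authorized_unexpected_keys:
    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

print(unexpected_keys)  # ['mlm___cls'] -- the pooler weight no longer warns
```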
src/transformers/modeling_tf_pytorch_utils.py

@@ -177,6 +177,13 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
         elif len(symbolic_weight.shape) > len(array.shape):
             array = numpy.expand_dims(array, axis=0)
 
+        if list(symbolic_weight.shape) != list(array.shape):
+            try:
+                array = numpy.reshape(array, symbolic_weight.shape)
+            except ValueError as e:
+                e.args += (symbolic_weight.shape, array.shape)
+                raise e
+
         try:
             assert list(symbolic_weight.shape) == list(array.shape)
         except AssertionError as e:
@@ -251,6 +258,8 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
     import transformers
 
+    from .modeling_tf_utils import load_tf_weights
+
     logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path))
 
     # Instantiate and load the associated TF 2.0 model
@@ -264,7 +273,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
     if tf_inputs is not None:
         tf_model(tf_inputs, training=False)  # Make sure model is built
 
-    tf_model.load_weights(tf_checkpoint_path, by_name=True)
+    load_tf_weights(tf_model, tf_checkpoint_path)
 
     return load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys)
@@ -332,6 +341,13 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False):
         elif len(pt_weight.shape) > len(array.shape):
             array = numpy.expand_dims(array, axis=0)
 
+        if list(pt_weight.shape) != list(array.shape):
+            try:
+                array = numpy.reshape(array, pt_weight.shape)
+            except ValueError as e:
+                e.args += (pt_weight.shape, array.shape)
+                raise e
+
         try:
             assert list(pt_weight.shape) == list(array.shape)
         except AssertionError as e:
...
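Both conversion directions above gain the same fallback: when a saved array and its target weight disagree in shape but agree in element count, the array is reshaped instead of rejected. A small sketch of the rule (the shapes are illustrative):

```python
import numpy as np

array = np.zeros((1, 768))  # e.g. a bias saved with a spurious leading axis
target_shape = (768,)

if list(target_shape) != list(array.shape):
    # np.reshape succeeds whenever the element counts match, and raises a
    # ValueError (annotated with both shapes above) when they do not.
    array = np.reshape(array, target_shape)

assert list(array.shape) == list(target_shape)
```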
src/transformers/modeling_tf_utils.py

@@ -23,12 +23,12 @@ from typing import Dict, List, Optional, Union
 
 import h5py
 import numpy as np
 import tensorflow as tf
+from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.saving import hdf5_format
 
 from .configuration_utils import PretrainedConfig
 from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
 from .generation_tf_utils import TFGenerationMixin
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 from .utils import logging
@@ -216,6 +216,91 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
     """
 
 
+def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
+    """
+    Detect missing and unexpected layers.
+
+    Args:
+        model (:obj:`tf.keras.models.Model`):
+            The model to load the weights into.
+        resolved_archive_file (:obj:`str`):
+            The location of the H5 file.
+
+    Returns:
+        Two lists, one for the missing layers, and another one for the unexpected layers.
+    """
+    missing_layers = []
+    unexpected_layers = []
+
+    with h5py.File(resolved_archive_file, "r") as f:
+        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
+        model_layer_names = set(layer.name for layer in model.layers)
+        missing_layers = list(model_layer_names - saved_layer_names)
+        unexpected_layers = list(saved_layer_names - model_layer_names)
+
+        for layer in model.layers:
+            if layer.name in saved_layer_names:
+                g = f[layer.name]
+                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
+                saved_weight_names_set = set(
+                    "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
+                )
+                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
+                symbolic_weights_names = set(
+                    "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
+                )
+                missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
+                unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
+
+    return missing_layers, unexpected_layers
+
+
+def load_tf_weights(model, resolved_archive_file):
+    """
+    Load the TF weights from an H5 file.
+
+    Args:
+        model (:obj:`tf.keras.models.Model`):
+            The model to load the weights into.
+        resolved_archive_file (:obj:`str`):
+            The location of the H5 file.
+    """
+    with h5py.File(resolved_archive_file, "r") as f:
+        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
+        weight_value_tuples = []
+
+        for layer in model.layers:
+            if layer.name in saved_layer_names:
+                g = f[layer.name]
+                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
+                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
+                saved_weight_names_values = {}
+
+                for weight_name in saved_weight_names:
+                    # Strip the leading model scope so names are comparable.
+                    name = "/".join(weight_name.split("/")[1:])
+                    saved_weight_names_values[name] = np.asarray(g[weight_name])
+
+                for symbolic_weight in symbolic_weights:
+                    split_layers = symbolic_weight.name.split("/")[1:]
+                    symbolic_weight_name = "/".join(split_layers)
+
+                    if symbolic_weight_name in saved_weight_names_values:
+                        saved_weight_value = saved_weight_names_values[symbolic_weight_name]
+
+                        # Reshape the saved value if it does not match the symbolic weight.
+                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
+                            try:
+                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
+                            except ValueError as e:
+                                e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape)
+                                raise e
+                        else:
+                            array = saved_weight_value
+
+                        weight_value_tuples.append((symbolic_weight, array))
+
+    K.batch_set_value(weight_value_tuples)
+
+
 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     r"""
     Base class for all TF models.
@@ -231,10 +316,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re patterns of tensor names to
+          ignore from the model when loading the model weights (and avoid unnecessary warnings).
+        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re patterns of tensor names to
+          ignore from the weights when loading the model weights (and avoid unnecessary warnings).
     """
     config_class = None
     base_model_prefix = ""
     authorized_missing_keys = None
+    authorized_unexpected_keys = None
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -604,6 +694,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         model = cls(config, *model_args, **model_kwargs)
 
         if from_pt:
+            from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
+
             # Load from a PyTorch checkpoint
             return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)
@@ -613,7 +705,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             # 'by_name' allows us to do transfer learning by skipping/adding layers
             # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
             try:
-                model.load_weights(resolved_archive_file, by_name=True)
+                load_tf_weights(model, resolved_archive_file)
             except OSError:
                 raise OSError(
                     "Unable to load weights from h5 file. "
@@ -622,23 +714,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
 
         model(model.dummy_inputs, training=False)  # Make sure restore ops are run
 
-        # Check if the models are the same to output loading informations
-        with h5py.File(resolved_archive_file, "r") as f:
-            if "layer_names" not in f.attrs and "model_weights" in f:
-                f = f["model_weights"]
-            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
-        model_layer_names = set(layer.name for layer in model.layers)
-        missing_keys = list(model_layer_names - hdf5_layer_names)
-        unexpected_keys = list(hdf5_layer_names - model_layer_names)
-        error_msgs = []
+        missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)
 
         if cls.authorized_missing_keys is not None:
             for pat in cls.authorized_missing_keys:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
 
+        if cls.authorized_unexpected_keys is not None:
+            for pat in cls.authorized_unexpected_keys:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
         if len(unexpected_keys) > 0:
             logger.warning(
-                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when "
                 f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                 f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                 f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
@@ -646,25 +734,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
             )
         else:
-            logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+            logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")
 
         if len(missing_keys) > 0:
             logger.warning(
-                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                 f"and are newly initialized: {missing_keys}\n"
                 f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
             )
         else:
             logger.warning(
-                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+                f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                 f"If your task is similar to the task the model of the checkpoint was trained on, "
                 f"you can already use {model.__class__.__name__} for predictions without further training."
             )
 
-        if len(error_msgs) > 0:
-            raise RuntimeError(
-                "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
-            )
-
         if output_loading_info:
-            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
+            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
 
             return model, loading_info
 
         return model
...
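One subtlety in the helpers above is how weight names are normalized before matching: `load_tf_weights` strips the leading model scope with `split("/")[1:]`, while `detect_tf_missing_unexpected_layers` compares at the sub-layer level with `split("/")[2:]`. A quick sketch with an assumed TF variable name of the usual `<model>/<layer>/.../<weight>:0` form:

```python
# Assumed variable name; real names depend on how the model was instantiated.
name = "tf_bert_model/bert/pooler/dense/kernel:0"

print("/".join(name.split("/")[1:]))  # bert/pooler/dense/kernel:0 (load_tf_weights)
print("/".join(name.split("/")[2:]))  # pooler/dense/kernel:0 (detect_tf_missing_unexpected_layers)
```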
tests/test_modeling_tf_bert.py

@@ -17,7 +17,7 @@
 import unittest
 
 from transformers import BertConfig, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import require_tf
 
 from .test_configuration_common import ConfigTester
 from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@@ -317,9 +317,14 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
 
-    @slow
     def test_model_from_pretrained(self):
-        # for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-        for model_name in ["bert-base-uncased"]:
-            model = TFBertModel.from_pretrained(model_name)
-            self.assertIsNotNone(model)
+        model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random")
+        self.assertIsNotNone(model)
+
+    def test_custom_load_tf_weights(self):
+        model, output_loading_info = TFBertForTokenClassification.from_pretrained(
+            "jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
+        )
+        self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
+        for layer in output_loading_info["missing_keys"]:
+            self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])
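The new test exercises the whole custom loading path against a tiny random checkpoint, so it runs quickly and no longer needs the `@slow` marker. Assuming a standard development checkout, it can be run in isolation with pytest's `-k` filter, e.g. from Python:

```python
import pytest

# Equivalent to: python -m pytest tests/test_modeling_tf_bert.py -k test_custom_load_tf_weights
pytest.main(["tests/test_modeling_tf_bert.py", "-k", "test_custom_load_tf_weights"])
```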