"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b1019d2a8e5725f4f72fc8abb4085fef8a60c7e4"
Unverified commit acfb714b, authored by Matt and committed by GitHub

Improve TF weight loading, especially PT crossloading (#21792)

* First commit for the improved PT-TF weight loading

* Remove workarounds from TFEncoderDecoder tests

* Allow a custom weight renaming function in from_pretrained and use that to clean up EncoderDecoder

* make fixup

* First attempt at visionencoderdecoder

* Disable tensorfloat32 in tests to get consistent outputs

* Quick fix to tf_vision_encoder_decoder tests

* make fixup

* Update Blenderbot tests

* Remove unused arg in modeling_tf_opt

* load_tf_sharded_weights had strict=True! This meant transfer learning was impossible, so I'm setting it to False.

* Support prefixes when loading sharded TF checkpoints

* make fixup

* Add test to load sharded models with a weight prefix

* Fix sharded weight loading test

* Add a test for transfer from a sharded checkpoint

* make fixup

* Add test to check that crossloading from PT with a prefix works

* Refactor from_pretrained in the encoderdecoder classes

* Refactor from_pretrained in the encoderdecoder classes

* missmatched -> mismatched

* Explicitly check for None

* No comments showing my very impressive and attractive knowledge of Py3.9+

* Disable TF32 across all TF tests
parent 871c31a6
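
As a usage-level illustration of the headline change, here is a minimal sketch (not code from this commit) of the new `tf_to_pt_weight_rename` hook that `from_pretrained` now accepts when crossloading a PyTorch checkpoint. The renaming rule and the prefix it strips are made-up placeholders; most models never need this hook at all.

from transformers import TFBertModel

def rename(pt_name):
    # Illustrative only: strip a hypothetical extra prefix before the weight is looked up
    # in the PyTorch state_dict. Names that don't contain the prefix pass through unchanged.
    return pt_name.replace("my_custom_prefix.", "")

model = TFBertModel.from_pretrained(
    "hf-internal-testing/tiny-random-bert", from_pt=True, tf_to_pt_weight_rename=rename
)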
@@ -39,7 +39,9 @@ class TransposeType(ExplicitEnum):
     CONV2D = "conv2d"

-def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", tf_weight_shape=None):
+def convert_tf_weight_name_to_pt_weight_name(
+    tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None
+):
     """
     Convert a TF 2.0 model variable name in a pytorch model weight name.
@@ -54,6 +56,14 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
         - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be
           transposed with regards to each other
     """
+    if name_scope is not None:
+        if not tf_name.startswith(name_scope):
+            raise ValueError(
+                f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error "
+                "in Transformers, so (unless you were doing something really evil) please open an issue to report it!"
+            )
+        tf_name = tf_name[len(name_scope) :]
+        tf_name = tf_name.lstrip("/")
     tf_name = tf_name.replace(":0", "")  # device ids
     tf_name = re.sub(
         r"/[^/]*___([^/]*)/", r"/\1/", tf_name
@@ -144,7 +154,13 @@ def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf
 def load_pytorch_checkpoint_in_tf2_model(
-    tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+    tf_model,
+    pytorch_checkpoint_path,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
 ):
     """Load pytorch checkpoints in a TF 2.0 model"""
     try:
@@ -176,6 +192,8 @@ def load_pytorch_checkpoint_in_tf2_model(
         tf_inputs=tf_inputs,
         allow_missing_keys=allow_missing_keys,
         output_loading_info=output_loading_info,
+        _prefix=_prefix,
+        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
     )
@@ -189,7 +207,13 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi
 def load_pytorch_weights_in_tf2_model(
-    tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+    tf_model,
+    pt_state_dict,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
 ):
     """Load pytorch state_dict in a TF 2.0 model."""
     try:
@@ -209,11 +233,19 @@ def load_pytorch_weights_in_tf2_model(
         tf_inputs=tf_inputs,
         allow_missing_keys=allow_missing_keys,
         output_loading_info=output_loading_info,
+        _prefix=_prefix,
+        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
     )

 def load_pytorch_state_dict_in_tf2_model(
-    tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+    tf_model,
+    pt_state_dict,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
 ):
     """Load a pytorch state_dict in a TF 2.0 model."""
     import tensorflow as tf
@@ -227,8 +259,11 @@ def load_pytorch_state_dict_in_tf2_model(
     if tf_inputs is None:
         tf_inputs = tf_model.dummy_inputs
+    if _prefix is None:
+        _prefix = ""
     if tf_inputs is not None:
-        tf_model(tf_inputs, training=False)  # Make sure model is built
+        with tf.name_scope(_prefix):
+            tf_model(tf_inputs, training=False)  # Make sure model is built
     # Adapt state dict - TODO remove this and update the AWS weights files instead
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
@@ -249,8 +284,10 @@ def load_pytorch_state_dict_in_tf2_model(
     for old_key, new_key in zip(old_keys, new_keys):
         pt_state_dict[new_key] = pt_state_dict.pop(old_key)
-    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
-    # TF models always have a prefix, some of PyTorch models (base ones) don't
+    # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
+    # In PT, the derived models (with heads) use the base model class as the stem instead, and the base model
+    # just contains the stem itself, and there is no MainLayer class. This means that TF base classes have one
+    # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
     start_prefix_to_remove = ""
     if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()):
         start_prefix_to_remove = tf_model.base_model_prefix + "."
@@ -263,8 +300,13 @@ def load_pytorch_state_dict_in_tf2_model(
     for symbolic_weight in symbolic_weights:
         sw_name = symbolic_weight.name
         name, transpose = convert_tf_weight_name_to_pt_weight_name(
-            sw_name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=symbolic_weight.shape
+            sw_name,
+            start_prefix_to_remove=start_prefix_to_remove,
+            tf_weight_shape=symbolic_weight.shape,
+            name_scope=_prefix,
         )
+        if tf_to_pt_weight_rename is not None:
+            name = tf_to_pt_weight_rename(name)
         # Find associated numpy array in pytorch model state dict
         if name not in pt_state_dict:
...
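
To make the new name_scope handling concrete, here is a small self-contained sketch of the same stripping logic as the added block above (a stand-alone illustration, not an import from the library):

def strip_name_scope(tf_name, name_scope):
    # Mirrors the added block: drop the variable-scope prefix (and any leading "/")
    # before the usual TF-name -> PyTorch-name conversion runs.
    if name_scope is not None:
        if not tf_name.startswith(name_scope):
            raise ValueError(f"Weight name {tf_name} does not start with name_scope {name_scope}")
        tf_name = tf_name[len(name_scope):].lstrip("/")
    return tf_name

print(strip_name_scope("a/b/bert/embeddings/word_embeddings/weight:0", "a/b"))
# -> bert/embeddings/word_embeddings/weight:0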
@@ -707,7 +707,7 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB"):
     return shards, index

-def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=True):
+def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None):
     """
     This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
     the TF weights from the shard file accordingly to their names and shapes.
@@ -729,32 +729,35 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
     """

     # Load the index
-    missing_keys = []
     unexpected_keys = set()
     saved_keys = set()
-    missmatched_keys = set()
+    mismatched_keys = set()

     # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
     # the weight, we have to get rid of the first prefix of the name of the layer.
     model_keys = set()
     model_layer_map = {}
     for i, k in enumerate(model.weights):
-        if "model." in k.name or len(k.name.split("/")) == 1:
-            layer_name = k.name
-        else:
-            layer_name = "/".join(k.name.split("/")[1:])
+        layer_name = k.name
+        if _prefix is not None and layer_name.startswith(_prefix):
+            layer_name = layer_name[len(_prefix) :]
+            layer_name = layer_name.lstrip("/")
+        if not ("model." in layer_name or len(layer_name.split("/")) == 1):
+            layer_name = "/".join(layer_name.split("/")[1:])
         model_keys.add(layer_name)
         model_layer_map[layer_name] = i

     for shard_file in shard_files:
-        state_dict = tf.io.read_file(shard_file)
-        saved_weight_names_set, unexpected_keys_set, missmatched_keys_set = load_tf_shard(
-            model, model_layer_map, shard_file, ignore_mismatched_sizes=ignore_mismatched_sizes
+        saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard(
+            model,
+            model_layer_map,
+            shard_file,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            _prefix=_prefix,
         )
         saved_keys.update(saved_weight_names_set)
         unexpected_keys.update(unexpected_keys_set)
-        missmatched_keys.update(missmatched_keys_set)
-        del state_dict
+        mismatched_keys.update(mismatched_keys_set)
         gc.collect()

     missing_keys = model_keys - saved_keys
@@ -768,10 +771,10 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
             error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)

-    return missing_keys, unexpected_keys, missmatched_keys
+    return missing_keys, unexpected_keys, mismatched_keys

-def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False):
+def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
     """
     Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.
@@ -783,11 +786,11 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
     Returns:
         `tf.keras.models.Model`: Three lists, one for the layers that were found and succesfully restored (from the
-        shard file), one for the missmatched layers, and another one for the unexpected layers.
+        shard file), one for the mismatched layers, and another one for the unexpected layers.
     """
     saved_weight_names_set = set()
     saved_weights = {}
-    missmatched_keys = set()
+    mismatched_keys = set()
     unexpected_keys = set()
     # Read the H5 file
     try:
@@ -822,7 +825,7 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
                        array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                    except ValueError as e:
                        if ignore_mismatched_sizes:
-                            missmatched_keys.add(
+                            mismatched_keys.add(
                                (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
                            )
                            continue
@@ -836,7 +839,7 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
        K.batch_set_value(weight_value_tuples)

-        return saved_weight_names_set, unexpected_keys, missmatched_keys
+        return saved_weight_names_set, unexpected_keys, mismatched_keys

    except Exception as e:
        try:
@@ -2458,6 +2461,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                 specify the folder name here.
+            tf_to_pt_weight_rename (`Callable`, *optional*):
+                A function that is called to transform the names of weights during the PyTorch to TensorFlow
+                crossloading process. This is not necessary for most models, but is useful to allow composite models to
+                be crossloaded correctly.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -2506,6 +2513,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        from_auto_class = kwargs.pop("_from_auto", False)
        subfolder = kwargs.pop("subfolder", "")
        commit_hash = kwargs.pop("_commit_hash", None)
+        tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
        if trust_remote_code is True:
            logger.warning(
@@ -2745,7 +2753,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(
-                model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
+                model,
+                resolved_archive_file,
+                allow_missing_keys=True,
+                output_loading_info=output_loading_info,
+                _prefix=load_weight_prefix,
+                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
            )

        # we might need to extend the variable scope for composite models
@@ -2761,7 +2774,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
            state_dict = safe_load_file(resolved_archive_file)
            # Load from a PyTorch checkpoint
            return load_pytorch_state_dict_in_tf2_model(
-                model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
+                model,
+                state_dict,
+                allow_missing_keys=True,
+                output_loading_info=output_loading_info,
+                _prefix=load_weight_prefix,
            )

        # 'by_name' allow us to do transfer learning by skipping/adding layers
@@ -2775,6 +2792,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                model,
                resolved_archive_file,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                _prefix=load_weight_prefix,
            )
        else:
            missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
...
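
The switch to strict=False matters because of a check that sits just after the hunks shown above: missing keys only abort loading when strict is set, which is what lets a model with a freshly initialized head load a base-model sharded checkpoint. A hedged sketch of that check (the exact error wording lives outside the displayed hunks):

missing_keys = model_keys - saved_keys
if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
    # With strict=True, a newly added (never-saved) head weight lands in missing_keys and
    # raises here, which is why transfer learning from sharded TF checkpoints used to fail.
    raise RuntimeError(f"Missing key(s): {missing_keys}. Unexpected key(s): {unexpected_keys}.")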
@@ -15,9 +15,7 @@
 """ Classes to support TF Encoder-Decoder architectures"""

-import gc
-import os
-import tempfile
+import re
 import warnings
 from typing import Optional, Tuple, Union
@@ -306,46 +304,23 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         >>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
         ```"""
-        from_pt = kwargs.pop("from_pt", False)
-        if from_pt:
-            import torch
-
-            from transformers import EncoderDecoderModel
-
-            # a workaround to load from pytorch checkpoint
-            _model = EncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            config = _model.config
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                encoder_dir = os.path.join(tmpdirname, "encoder")
-                decoder_dir = os.path.join(tmpdirname, "decoder")
-                _model.encoder.save_pretrained(encoder_dir)
-                _model.decoder.save_pretrained(decoder_dir)
-
-                if hasattr(_model, "enc_to_dec_proj"):
-                    enc_to_dec_proj_kernel = tf.transpose(
-                        tf.constant(_model.enc_to_dec_proj.weight.detach().to("cpu").numpy()), perm=(1, 0)
-                    )
-                    enc_to_dec_proj_bias = tf.constant(_model.enc_to_dec_proj.bias.detach().to("cpu").numpy())
-
-                del _model
-                gc.collect()
-                torch.cuda.empty_cache()
-
-                model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
-                    encoder_dir, decoder_dir, encoder_from_pt=True, decoder_from_pt=True
-                )
-                # This is only for copying some specific attributes of this particular model.
-                model.config = config
-
-                if hasattr(model, "enc_to_dec_proj"):
-                    model(model.dummy_inputs)
-                    model.enc_to_dec_proj.kernel.assign(enc_to_dec_proj_kernel)
-                    model.enc_to_dec_proj.bias.assign(enc_to_dec_proj_bias)
-
-                return model
+        # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
+        # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
+        # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
+        # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
+        # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
+
+        if kwargs.get("from_pt", False):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+            encoder_model_type = config.encoder.model_type
+
+            def tf_to_pt_weight_rename(tf_weight):
+                if "encoder" in tf_weight and "decoder" not in tf_weight:
+                    return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
+                else:
+                    return tf_weight
+
+            kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

     @classmethod
@@ -451,14 +426,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
             encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

-            # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
-            if kwargs_encoder.get("from_pt", None):
-                del kwargs_encoder["from_pt"]
-                with tempfile.TemporaryDirectory() as tmp_dirname:
-                    encoder.save_pretrained(tmp_dirname)
-                    del encoder
-                    encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)

         decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
             if decoder_pretrained_model_name_or_path is None:
@@ -493,14 +460,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
             decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

-            # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
-            if kwargs_decoder.get("from_pt", None):
-                del kwargs_decoder["from_pt"]
-                with tempfile.TemporaryDirectory() as tmp_dirname:
-                    decoder.save_pretrained(tmp_dirname)
-                    del decoder
-                    decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)

         # Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
         if encoder.name != "encoder":
             raise ValueError("encoder model must be created with the name `encoder`.")
...
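
For clarity, the closure registered above maps the TF-derived names (which carry the extra MainLayer level) back onto the PyTorch EncoderDecoderModel keys. A quick worked example, assuming a BERT encoder so that encoder_model_type is "bert":

import re

encoder_model_type = "bert"  # e.g. a bert2bert EncoderDecoder checkpoint

def tf_to_pt_weight_rename(tf_weight):
    if "encoder" in tf_weight and "decoder" not in tf_weight:
        return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
    else:
        return tf_weight

print(tf_to_pt_weight_rename("encoder.bert.embeddings.word_embeddings.weight"))
# -> encoder.embeddings.word_embeddings.weight (matches the PT key, since the PT encoder is a base model)
print(tf_to_pt_weight_rename("decoder.bert.embeddings.word_embeddings.weight"))
# -> unchanged: the PT decoder is a with-head model, so its keys already carry the extra level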
@@ -486,7 +486,7 @@ OPT_INPUTS_DOCSTRING = r"""
 class TFOPTDecoder(tf.keras.layers.Layer):
     config_class = OPTConfig

-    def __init__(self, config: OPTConfig, load_weight_prefix=None, **kwargs):
+    def __init__(self, config: OPTConfig, **kwargs):
         super().__init__(**kwargs)
         self.config = config
         self.padding_idx = config.pad_token_id
...
@@ -15,9 +15,7 @@
 """ Classes to support TF Vision-Encoder-Text-Decoder architectures"""

-import gc
-import os
-import tempfile
+import re
 import warnings
 from typing import Optional, Tuple, Union
@@ -320,46 +318,23 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
         >>> assert preds == ["a cat laying on top of a couch next to another cat"]
         ```"""
-        from_pt = kwargs.pop("from_pt", False)
-        if from_pt:
-            import torch
-
-            from transformers import VisionEncoderDecoderModel
-
-            # a workaround to load from pytorch checkpoint
-            _model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-            config = _model.config
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                encoder_dir = os.path.join(tmpdirname, "encoder")
-                decoder_dir = os.path.join(tmpdirname, "decoder")
-                _model.encoder.save_pretrained(encoder_dir)
-                _model.decoder.save_pretrained(decoder_dir)
-
-                if hasattr(_model, "enc_to_dec_proj"):
-                    enc_to_dec_proj_kernel = tf.transpose(
-                        tf.constant(_model.enc_to_dec_proj.weight.detach().to("cpu").numpy()), perm=(1, 0)
-                    )
-                    enc_to_dec_proj_bias = tf.constant(_model.enc_to_dec_proj.bias.detach().to("cpu").numpy())
-
-                del _model
-                gc.collect()
-                torch.cuda.empty_cache()
-
-                model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
-                    encoder_dir, decoder_dir, encoder_from_pt=True, decoder_from_pt=True
-                )
-                # This is only for copying some specific attributes of this particular model.
-                model.config = config
-
-                if hasattr(model, "enc_to_dec_proj"):
-                    model(model.dummy_inputs)
-                    model.enc_to_dec_proj.kernel.assign(enc_to_dec_proj_kernel)
-                    model.enc_to_dec_proj.bias.assign(enc_to_dec_proj_bias)
-
-                return model
+        # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
+        # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
+        # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
+        # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
+        # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
+
+        if kwargs.get("from_pt", False):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+            encoder_model_type = config.encoder.model_type
+
+            def tf_to_pt_weight_rename(tf_weight):
+                if "encoder" in tf_weight and "decoder" not in tf_weight:
+                    return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
+                else:
+                    return tf_weight
+
+            kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

     @classmethod
@@ -466,15 +441,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
             kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
             encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

-            # Necessary to make `save_pretrained -> from_pretrained` work correctly for the converted PT -> TF model.
-            # See https://github.com/huggingface/transformers/pull/14016#issuecomment-944046313
-            if kwargs_encoder.get("from_pt", None):
-                del kwargs_encoder["from_pt"]
-                with tempfile.TemporaryDirectory() as tmp_dirname:
-                    encoder.save_pretrained(tmp_dirname)
-                    del encoder
-                    encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)

         decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
             if decoder_pretrained_model_name_or_path is None:
@@ -509,15 +475,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
             kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
             decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

-            # Necessary to make `save_pretrained -> from_pretrained` work correctly for the converted PT -> TF model.
-            # See https://github.com/huggingface/transformers/pull/14016#issuecomment-944046313
-            if kwargs_decoder.get("from_pt", None):
-                del kwargs_decoder["from_pt"]
-                with tempfile.TemporaryDirectory() as tmp_dirname:
-                    decoder.save_pretrained(tmp_dirname)
-                    del decoder
-                    decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)

         # Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
         if encoder.name != "encoder":
             raise ValueError("encoder model must be created with the name `encoder`.")
...
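
The vision variant registers the same kind of closure; the only difference is that encoder_model_type comes from the vision encoder's config. For example, assuming a ViT encoder ("vit"):

import re
encoder_model_type = "vit"  # illustrative, e.g. a ViT encoder inside a TFVisionEncoderDecoderModel
print(re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.",
             "encoder.vit.embeddings.patch_embeddings.projection.weight"))
# -> encoder.embeddings.patch_embeddings.projection.weight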
@@ -925,16 +925,14 @@ class TFViT2GPT2ModelIntegrationTest(unittest.TestCase):
         self.assertLessEqual(max_diff, 1e-4)

         def generate_step(pixel_values):
-            outputs = model.generate(
-                pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True, output_scores=True
-            )
+            outputs = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True)
             output_ids = outputs.sequences
             preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
             preds = [pred.strip() for pred in preds]
-            return preds, outputs.scores.numpy()
+            return preds

-        preds, scores = generate_step(pixel_values)
+        preds = generate_step(pixel_values)

         # should produce
         # ["a cat laying on top of a couch next to another cat"]
...
@@ -90,6 +90,7 @@ if is_tf_available():
         TFAutoModel,
         TFAutoModelForSequenceClassification,
         TFBertForMaskedLM,
+        TFBertForSequenceClassification,
         TFBertModel,
         TFRagModel,
         TFSharedEmbeddings,
@@ -107,6 +108,8 @@ if is_tf_available():
     from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs
     from transformers.tf_utils import stable_softmax

+    tf.config.experimental.enable_tensor_float_32_execution(False)
+
     if _tf_gpu_memory_limit is not None:
         gpus = tf.config.list_physical_devices("GPU")
         for gpu in gpus:
@@ -2140,6 +2143,18 @@ class UtilsFunctionsTest(unittest.TestCase):
         for p1, p2 in zip(model.weights, ref_model.weights):
             assert np.allclose(p1.numpy(), p2.numpy())

+    def test_sharded_checkpoint_with_prefix(self):
+        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", load_weight_prefix="a/b")
+        sharded_model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded", load_weight_prefix="a/b")
+        for p1, p2 in zip(model.weights, sharded_model.weights):
+            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+            self.assertTrue(p1.name.startswith("a/b/"))
+            self.assertTrue(p2.name.startswith("a/b/"))
+
+    def test_sharded_checkpoint_transfer(self):
+        # If this doesn't throw an error then the test passes
+        TFBertForSequenceClassification.from_pretrained("ArthurZ/tiny-random-bert-sharded")
+
     @is_pt_tf_cross_test
     def test_checkpoint_sharding_local_from_pt(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -2150,6 +2165,16 @@ class UtilsFunctionsTest(unittest.TestCase):
         for p1, p2 in zip(model.weights, ref_model.weights):
             assert np.allclose(p1.numpy(), p2.numpy())

+    @is_pt_tf_cross_test
+    def test_checkpoint_loading_with_prefix_from_pt(self):
+        model = TFBertModel.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", from_pt=True, load_weight_prefix="a/b"
+        )
+        ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
+        for p1, p2 in zip(model.weights, ref_model.weights):
+            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+            self.assertTrue(p1.name.startswith("a/b/"))
+
     @is_pt_tf_cross_test
     def test_checkpoint_sharding_hub_from_pt(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
...