Unverified Commit 22fe73c3 authored by Matt's avatar Matt Committed by GitHub
Browse files

TF safetensors reduced mem usage (#24404)

* Slight comment cleanup

* Reduce peak mem usage when loading TF-format safetensor weights

* Tweak the PyTorch loading code to support lazy loading from safetensors

* Pass safe_open objects to the PyTorch loading function

* Do GPU transposes for speed

* One more tweak to reduce peak usage further

* One-line hasattr

* Fix bug when there's a shape mismatch

* Rename state_dict in the loading code to be clearer

* Use TF format everywhere for consistency
parent 7e03e469
...@@ -248,7 +248,8 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -248,7 +248,8 @@ def load_pytorch_state_dict_in_tf2_model(
tf_to_pt_weight_rename=None, tf_to_pt_weight_rename=None,
ignore_mismatched_sizes=False, ignore_mismatched_sizes=False,
): ):
"""Load a pytorch state_dict in a TF 2.0 model.""" """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
safetensors archive created with the safe_open() function."""
import tensorflow as tf import tensorflow as tf
from packaging.version import parse from packaging.version import parse
...@@ -262,13 +263,11 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -262,13 +263,11 @@ def load_pytorch_state_dict_in_tf2_model(
if _prefix is None: if _prefix is None:
_prefix = "" _prefix = ""
if tf_inputs is not None: if tf_inputs:
with tf.name_scope(_prefix): with tf.name_scope(_prefix):
tf_model(tf_inputs, training=False) # Make sure model is built tf_model(tf_inputs, training=False) # Make sure model is built
# Adapt state dict - TODO remove this and update the AWS weights files instead
# Convert old format to new format if needed from a PyTorch state_dict # Convert old format to new format if needed from a PyTorch state_dict
old_keys = [] tf_keys_to_pt_keys = {}
new_keys = []
for key in pt_state_dict.keys(): for key in pt_state_dict.keys():
new_key = None new_key = None
if "gamma" in key: if "gamma" in key:
...@@ -279,26 +278,24 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -279,26 +278,24 @@ def load_pytorch_state_dict_in_tf2_model(
new_key = key.replace("running_var", "moving_variance") new_key = key.replace("running_var", "moving_variance")
if "running_mean" in key: if "running_mean" in key:
new_key = key.replace("running_mean", "moving_mean") new_key = key.replace("running_mean", "moving_mean")
if new_key: if new_key is None:
old_keys.append(key) new_key = key
new_keys.append(new_key) tf_keys_to_pt_keys[new_key] = key
for old_key, new_key in zip(old_keys, new_keys):
pt_state_dict[new_key] = pt_state_dict.pop(old_key)
# Matt: All TF models store the actual model stem in a MainLayer class, including the base model. # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
# In PT, the derived models (with heads) use the base model class as the stem instead, and the base model # In PT, the derived models (with heads) use the base model class as the stem instead,
# just contains the stem itself, and there is no MainLayer class. This means that TF base classes have one # and there is no MainLayer class. This means that TF base classes have one
# extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that. # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
start_prefix_to_remove = "" start_prefix_to_remove = ""
if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()): if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys.keys()):
start_prefix_to_remove = tf_model.base_model_prefix + "." start_prefix_to_remove = tf_model.base_model_prefix + "."
symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
tf_loaded_numel = 0 tf_loaded_numel = 0
weight_value_tuples = [] all_pytorch_weights = set(tf_keys_to_pt_keys.keys())
all_pytorch_weights = set(pt_state_dict.keys())
missing_keys = [] missing_keys = []
mismatched_keys = [] mismatched_keys = []
is_safetensor_archive = hasattr(pt_state_dict, "get_tensor")
for symbolic_weight in symbolic_weights: for symbolic_weight in symbolic_weights:
sw_name = symbolic_weight.name sw_name = symbolic_weight.name
name, transpose = convert_tf_weight_name_to_pt_weight_name( name, transpose = convert_tf_weight_name_to_pt_weight_name(
...@@ -311,7 +308,7 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -311,7 +308,7 @@ def load_pytorch_state_dict_in_tf2_model(
name = tf_to_pt_weight_rename(name) name = tf_to_pt_weight_rename(name)
# Find associated numpy array in pytorch model state dict # Find associated numpy array in pytorch model state dict
if name not in pt_state_dict: if name not in tf_keys_to_pt_keys:
if allow_missing_keys: if allow_missing_keys:
missing_keys.append(name) missing_keys.append(name)
continue continue
...@@ -320,9 +317,13 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -320,9 +317,13 @@ def load_pytorch_state_dict_in_tf2_model(
if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing): if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
continue continue
raise AttributeError(f"{name} not found in PyTorch model") raise AttributeError(f"{name} not found in PyTorch model")
state_dict_name = tf_keys_to_pt_keys[name]
if is_safetensor_archive:
array = pt_state_dict.get_tensor(state_dict_name)
else:
array = pt_state_dict[state_dict_name]
try: try:
array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape) array = apply_transpose(transpose, array, symbolic_weight.shape)
except tf.errors.InvalidArgumentError as e: except tf.errors.InvalidArgumentError as e:
if not ignore_mismatched_sizes: if not ignore_mismatched_sizes:
error_msg = str(e) error_msg = str(e)
...@@ -331,16 +332,15 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -331,16 +332,15 @@ def load_pytorch_state_dict_in_tf2_model(
) )
raise tf.errors.InvalidArgumentError(error_msg) raise tf.errors.InvalidArgumentError(error_msg)
else: else:
mismatched_keys.append((name, pt_state_dict[name].shape, symbolic_weight.shape)) mismatched_keys.append((name, array.shape, symbolic_weight.shape))
continue continue
tf_loaded_numel += tensor_size(array) tf_loaded_numel += tensor_size(array)
weight_value_tuples.append((symbolic_weight, array)) K.set_value(symbolic_weight, array)
del array # Immediately free memory to keep peak usage as low as possible
all_pytorch_weights.discard(name) all_pytorch_weights.discard(name)
K.batch_set_value(weight_value_tuples)
logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.") logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")
unexpected_keys = list(all_pytorch_weights) unexpected_keys = list(all_pytorch_weights)
......
...@@ -87,7 +87,6 @@ else: ...@@ -87,7 +87,6 @@ else:
if is_safetensors_available(): if is_safetensors_available():
from safetensors import safe_open from safetensors import safe_open
from safetensors.tensorflow import load_file as safe_load_file
from safetensors.tensorflow import save_file as safe_save_file from safetensors.tensorflow import save_file as safe_save_file
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -1000,42 +999,33 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size ...@@ -1000,42 +999,33 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    """Load weights from a safetensors archive into a TF model, one tensor at a time.

    The archive is opened lazily with ``safe_open`` instead of materializing the whole
    state dict up front, which keeps peak memory usage low, and each tensor is assigned
    to its weight immediately after being read.

    Args:
        model: TF model whose weights are set in place.
        resolved_archive_file: path to the ``.safetensors`` checkpoint on disk.
        ignore_mismatched_sizes: when ``True``, stored tensors that cannot be reshaped to
            the model weight's shape are skipped and reported instead of raising.
        _prefix: optional name-scope prefix applied when normalizing weight names.

    Returns:
        Tuple ``(missing_layers, unexpected_layers, mismatched_layers)`` where the first two
        are lists of weight names and the last is a list of ``(name, stored_shape, model_shape)``.
    """
    # Read the safetensors file lazily: tensors are fetched one at a time via get_tensor().
    with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
        mismatched_layers = []
        weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
        loaded_weight_names = list(safetensors_archive.keys())
        # Find the missing layers from the high level list of layers
        missing_layers = list(set(weight_names) - set(loaded_weight_names))
        # Find the unexpected layers from the high level list of layers
        unexpected_layers = list(set(loaded_weight_names) - set(weight_names))

        # Set for O(1) membership tests inside the loop below.
        loaded_weight_name_set = set(loaded_weight_names)
        for weight in model.weights:
            weight_name = format_weight_name(weight.name, _prefix=_prefix)
            if weight_name in loaded_weight_name_set:
                weight_value = safetensors_archive.get_tensor(weight_name)
                # Check if the shape of the current weight and the one from the checkpoint differ
                if K.int_shape(weight) != weight_value.shape:
                    # If yes we reshape the weight from the checkpoint accordingly to the current weight
                    # If the two shapes are not compatible we raise an issue
                    try:
                        weight_value = tf.reshape(weight_value, K.int_shape(weight))
                    except ValueError as e:
                        if ignore_mismatched_sizes:
                            mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
                            continue
                        else:
                            raise e

                # Assign immediately and let weight_value be freed before the next tensor is read,
                # keeping peak memory low. weight.assign() might break if weight is a DTensor.
                K.set_value(weight, weight_value)

    return missing_layers, unexpected_layers, mismatched_layers
...@@ -2921,16 +2911,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2921,16 +2911,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if safetensors_from_pt: if safetensors_from_pt:
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
state_dict = safe_load_file(resolved_archive_file) with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
return load_pytorch_state_dict_in_tf2_model( # We load in TF format here because PT weights often need to be transposed, and this is much
model, # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
state_dict, return load_pytorch_state_dict_in_tf2_model(
allow_missing_keys=True, model,
output_loading_info=output_loading_info, safetensors_archive,
_prefix=load_weight_prefix, tf_inputs=False, # No need to build the model again
ignore_mismatched_sizes=ignore_mismatched_sizes, allow_missing_keys=True,
) output_loading_info=output_loading_info,
_prefix=load_weight_prefix,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)
# 'by_name' allow us to do transfer learning by skipping/adding layers # 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment