Unverified Commit 7cced021 authored by Arthur, committed by GitHub

TF Sharded (#17713)



* initial commit

* update modeling tf utils

* quality

* clean and update args

* update

* remove potential bug

* code quality

* update

* update max shard

* update tests for sharding from pretrained

* fix remaining test

* make style

* h5py if tf available

* update and fix test

* fix test

* style

* modified push to hub to support shard for TF

* quick fix

* update code

* merge branch main and style

* Apply suggestions from code review
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update based on reviews

* update doc

* update and style

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update based on reviews

* fix typo

* style
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent f47afefb
@@ -16,7 +16,9 @@
"""TF general model utils."""
import functools
import gc
import inspect
import json
import os
import pickle
import re
@@ -33,7 +35,9 @@ from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format
from huggingface_hub import Repository, list_repo_files
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
@@ -44,6 +48,7 @@ from .tf_utils import shape_list
from .utils import (
DUMMY_INPUTS,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
EntryNotFoundError,
@@ -554,9 +559,243 @@ def input_processing(func, config, input_ids, **kwargs):
return output
def dtype_byte_size(dtype):
"""
Returns the size (in bytes) occupied by one parameter of type `dtype`.
Example:
```py
>>> dtype_byte_size(tf.float32)
4
```
"""
if dtype == tf.bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", dtype.name)
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
def tf_shard_checkpoint(weights, max_shard_size="10GB"):
"""
Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
The sub-checkpoints are determined by iterating through the `weights` in the order of their definition, so there is
no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if
the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
[6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
<Tip warning={true}>
If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
have a size greater than `max_shard_size`.
</Tip>
Args:
weights (`List[tf.ResourceVariable]`): The list of `tf.ResourceVariable`s of a model to save.
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
(like `"5MB"`).
"""
max_shard_size = convert_file_size_to_int(max_shard_size)
sharded_state_dicts = []
current_block = []
current_block_size = 0
total_size = 0
for item in weights:
weight_size = item.numpy().size * dtype_byte_size(item.dtype)
# If this weight is going to tip over the maximum size, we split.
if current_block_size + weight_size > max_shard_size:
sharded_state_dicts.append(current_block)
current_block = []
current_block_size = 0
current_block.append(item)
current_block_size += weight_size
total_size += weight_size
# Add the last block
sharded_state_dicts.append(current_block)
# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None
# Otherwise, let's build the index
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
shards[shard_file] = shard
for weight in shard:
weight_name = weight.name
weight_map[weight_name] = shard_file
# Add the metadata
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
return shards, index
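For illustration, a minimal sketch of the return convention (the model and sizes are hypothetical; when everything fits in a single shard, no index is produced):

```py
>>> import tensorflow as tf

>>> model = tf.keras.Sequential([tf.keras.layers.Dense(100, use_bias=False)])
>>> model.build((None, 200))  # one kernel of shape (200, 100): 20,000 float32 params = 80,000 bytes

>>> shards, index = tf_shard_checkpoint(model.weights, max_shard_size="200kB")
>>> list(shards)  # everything fits in a single shard, so no index is built
['tf_model.h5']
>>> index is None
True
```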
def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=True):
"""
This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
the TF weights from the shard files according to their names and shapes.
This load is performed efficiently: the checkpoint shards are loaded one by one in RAM and deleted after being
loaded into the model.
Args:
model (`tf.keras.models.Model`): The model in which to load the checkpoint.
shard_files (`List[str]`): A list containing the sharded checkpoint file names.
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to ignore mismatches between the sizes of the checkpoint weights and the model weights.
strict (`bool`, *optional*, defaults to `True`):
Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
Returns:
Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
mismatched layers.
"""
missing_keys = []
unexpected_keys = set()
saved_keys = set()
mismatched_keys = set()
# Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
# the weight, we have to get rid of the first prefix of the name of the layer.
model_keys = set("/".join(k.name.split("/")[1:]) for k in model.weights)
model_layer_map = {"/".join(k.name.split("/")[1:]): i for i, k in enumerate(model.weights)}
for shard_file in shard_files:
saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard(
model, model_layer_map, shard_file, ignore_mismatched_sizes=ignore_mismatched_sizes
)
saved_keys.update(saved_weight_names_set)
unexpected_keys.update(unexpected_keys_set)
mismatched_keys.update(mismatched_keys_set)
gc.collect()
missing_keys = model_keys - saved_keys
if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
if len(missing_keys) > 0:
str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
error_message += f"\nMissing key(s): {str_missing_keys}."
if len(unexpected_keys) > 0:
str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
return missing_keys, unexpected_keys, mismatched_keys
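A hedged usage sketch (the shard file names are hypothetical; the model must be built, e.g. on its dummy inputs, before the shards can be assigned to its variables):

```py
>>> model(model.dummy_inputs)  # build the network so model.weights is populated
>>> shard_files = ["tf_model-00001-of-00002.h5", "tf_model-00002-of-00002.h5"]
>>> missing, unexpected, mismatched = load_tf_sharded_weights(model, shard_files)
```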
def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False):
"""
Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.
Args:
model (`tf.keras.models.Model`): Model in which the weights are loaded
model_layer_map (`Dict`): A dictionary mapping the layer name to the index of the layer in the model.
resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore weights with mismatched shapes
Returns:
Three sets: one with the names of the layers that were found and successfully restored (from the
shard file), one for the mismatched layers, and another one for the unexpected layers.
"""
saved_weight_names_set = set()
saved_weights = {}
mismatched_keys = set()
unexpected_keys = set()
# Read the H5 file
try:
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
# Retrieve the name of each layer from the H5 file
saved_h5_model_layers_name = set(
hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
)
weight_value_tuples = []
# Compute missing and unexpected sub layers
# Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
for layer_name in saved_h5_model_layers_name:
h5_layer_object = sharded_checkpoint_file[layer_name]
saved_weights[layer_name] = np.asarray(h5_layer_object)
saved_weight_names_set.add(layer_name)
if layer_name not in model_layer_map:
unexpected_keys.add(layer_name)
else:
symbolic_weight = model.weights[model_layer_map[layer_name]]
saved_weight_value = saved_weights[layer_name]
# If the current weight is found
if saved_weight_value is not None:
# Check if the shape of the current weight and the one from the H5 file are different
if K.int_shape(symbolic_weight) != saved_weight_value.shape:
# If yes we reshape the weight from the H5 file according to the shape of the current weight
# If the two shapes are not compatible we raise an error
try:
array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
except ValueError as e:
if ignore_mismatched_sizes:
mismatched_keys.add(
(layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
)
continue
else:
raise e
else:
array = saved_weight_value
# We create the tuple that will be loaded and add it to the final list
weight_value_tuples.append((symbolic_weight, array))
K.batch_set_value(weight_value_tuples)
return saved_weight_names_set, unexpected_keys, mismatched_keys
except Exception as e:
try:
with open(resolved_archive_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError(
f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained"
" model. Make sure you have saved the model properly."
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' "
f"at '{resolved_archive_file}'. "
"If you tried to load a TF model from a sharded checkpoint, you should try converting the model"
"by loading it in pytorch and saving it localy. A convertion script should be realeased soon."
)
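For reference, a sketch of the `model_layer_map` this function expects, mirroring the prefix-stripping convention used above (the shard file name is hypothetical):

```py
>>> # TF prefixes each weight name with the model class scope, e.g.
>>> # "tf_bert_model/bert/embeddings/word_embeddings/weight:0" -> "bert/embeddings/...";
>>> # the map therefore keys the stripped names to indices into model.weights.
>>> model_layer_map = {"/".join(w.name.split("/")[1:]): i for i, w in enumerate(model.weights)}
>>> saved, unexpected, mismatched = load_tf_shard(model, model_layer_map, "tf_model-00001-of-00002.h5")
```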
def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
"""
Detect missing and unexpected layers and load the TF weights from the shard file according to their names and
shapes.
Args:
model (`tf.keras.models.Model`):
@@ -575,9 +814,11 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
mismatched_layers = []
# Read the H5 file
with h5py.File(resolved_archive_file, "r") as f:
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
# Retrieve the name of each layer from the H5 file
saved_h5_model_layers_name = set(
hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
)
# Find the missing layers from the high level list of layers
missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)
@@ -594,7 +835,7 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
# if layer_name from the H5 file belongs to the layers from the instantiated model
if layer.name in saved_h5_model_layers_name:
# Get the H5 layer object from its name
h5_layer_object = sharded_checkpoint_file[layer.name]
# Get all the weights as a list from the layer object
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
saved_weights = {}
@@ -1624,7 +1865,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"""
raise NotImplementedError
def save_pretrained(
self,
save_directory,
saved_model=False,
version=1,
push_to_hub=False,
max_shard_size: Union[int, str] = "10GB",
**kwargs
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
[`~TFPreTrainedModel.from_pretrained`] class method.
@@ -1649,6 +1898,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
</Tip>
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size for a checkpoint before being sharded. Checkpoint shards will then each be of a size
lower than this one. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
<Tip warning={true}>
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
which will be bigger than `max_shard_size`.
</Tip>
kwargs:
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
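Example (a minimal sketch; the directory name is illustrative):
```py
>>> model.save_pretrained("./my-tf-model", max_shard_size="200MB")
>>> # If the weights exceed 200MB in total, the directory will contain shard files named
>>> # tf_model-00001-of-XXXXX.h5 plus a tf_model.h5.index.json index, instead of tf_model.h5.
```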
"""
@@ -1679,8 +1939,48 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
shards, index = tf_shard_checkpoint(self.weights, max_shard_size)
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
if (
filename.startswith(TF2_WEIGHTS_NAME[:-4])
and os.path.isfile(full_filename)
and filename not in shards.keys()
):
os.remove(full_filename)
if index is None:
self.save_weights(output_model_file)
logger.info(f"Model weights saved in {output_model_file}")
else:
save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as index_file:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
index_file.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
for shard_file, shard in shards.items():
# Write each shard: store the stripped layer names, then one dataset per weight
with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_handle:
save_attributes_to_hdf5_group(
shard_handle,
"layer_names",
["/".join(layer.name.split("/")[1:]).encode("utf8") for layer in shard],
)
for layer in sorted(shard, key=lambda x: x.name):
param_dset = shard_handle.create_dataset(
"/".join(layer.name.split("/")[1:]), layer.numpy().shape, dtype=layer.numpy().dtype
)
param_dset[:] = layer.numpy()
if push_to_hub:
url = self._push_to_hub(repo, commit_message=commit_message)
@@ -1844,6 +2144,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
else:
model_kwargs = kwargs
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
sharded_metadata = None
# Load model
if pretrained_model_name_or_path is not None:
if os.path.isdir(pretrained_model_name_or_path):
@@ -1853,6 +2157,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
# Load from a sharded TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
@@ -1906,23 +2214,45 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
)
except EntryNotFoundError:
if filename == TF2_WEIGHTS_NAME:
try:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=TF2_WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
is_sharded = True
except EntryNotFoundError:
# Otherwise, maybe there is a PyTorch weights file on the Hub. We try it to give a helpful error
# message.
has_file_kwargs = {
"revision": revision,
"mirror": mirror,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
" load this model from those weights."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
@@ -1955,6 +2285,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
else:
resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
mirror=mirror,
)
config.name_or_path = pretrained_model_name_or_path
# composed models, *e.g.* TFRag, require special treatment when it comes to loading
@@ -1978,16 +2325,25 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
else:
model(model.dummy_inputs) # build the network with dummy inputs
# 'by_name' allows us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
try:
if is_sharded:
for file in resolved_archive_file:
assert os.path.isfile(file), f"Error retrieving file {file}"
missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
model,
resolved_archive_file,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)
else:
missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
model,
resolved_archive_file,
ignore_mismatched_sizes=ignore_mismatched_sizes,
_prefix=load_weight_prefix,
)
except OSError as e:
try:
with open(resolved_archive_file) as f:
@@ -148,6 +148,7 @@ from .import_utils import (
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
CONFIG_NAME = "config.json"
@@ -861,6 +861,7 @@ class PushToHubMixin:
organization: Optional[str] = None,
private: Optional[bool] = None,
use_auth_token: Optional[Union[bool, str]] = None,
max_shard_size: Optional[Union[int, str]] = "10GB",
**model_card_kwargs
) -> str:
"""
@@ -936,8 +937,9 @@ class PushToHubMixin:
use_auth_token=use_auth_token,
)
# Save the files in the cloned repo
if hasattr(self, "history") and hasattr(self, "create_model_card"):
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
# This is a Keras model and we might be able to fish out its History and make a model card out of it
base_model_card_args = {
"output_dir": repo_path_or_name,
@@ -945,6 +947,9 @@
}
base_model_card_args.update(model_card_kwargs)
self.create_model_card(**base_model_card_args)
else:
# FLAX does not support sharding yet, will come in next PR
self.save_pretrained(repo_path_or_name)
# Commit and push!
url = self._push_to_hub(repo, commit_message=commit_message)
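A hedged usage sketch of the new argument (the repository name is illustrative):

```py
>>> model.push_to_hub("my-sharded-bert", max_shard_size="500MB")
```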
@@ -1075,3 +1080,114 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"):
except Exception:
# We don't want to error in case of connection errors of any kind.
pass
def convert_file_size_to_int(size: Union[int, str]):
"""
Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).
Args:
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
Example:
```py
>>> convert_file_size_to_int("1MiB")
1048576
```
"""
if isinstance(size, int):
return size
if size.upper().endswith("GIB"):
return int(size[:-3]) * (2**30)
if size.upper().endswith("MIB"):
return int(size[:-3]) * (2**20)
if size.upper().endswith("KIB"):
return int(size[:-3]) * (2**10)
if size.upper().endswith("GB"):
int_size = int(size[:-2]) * (10**9)
return int_size // 8 if size.endswith("b") else int_size
if size.upper().endswith("MB"):
int_size = int(size[:-2]) * (10**6)
return int_size // 8 if size.endswith("b") else int_size
if size.upper().endswith("KB"):
int_size = int(size[:-2]) * (10**3)
return int_size // 8 if size.endswith("b") else int_size
raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
def get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_filename,
cache_dir=None,
force_download=False,
proxies=None,
resume_download=False,
local_files_only=False,
use_auth_token=None,
user_agent=None,
revision=None,
mirror=None,
):
"""
For a given model:
- downloads and caches all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on
the Hub
- returns the list of paths to all the shards, as well as some metadata.
For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
"""
import json
if not os.path.isfile(index_filename):
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
with open(index_filename, "r") as f:
index = json.loads(f.read())
shard_filenames = sorted(list(set(index["weight_map"].values())))
sharded_metadata = index["metadata"]
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
# First, let's deal with local folder.
if os.path.isdir(pretrained_model_name_or_path):
shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
return shard_filenames, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
cached_filenames = []
for shard_filename in shard_filenames:
shard_url = hf_bucket_url(
pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
)
try:
# Load from URL
cached_filename = cached_path(
shard_url,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here.
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
"required according to the checkpoint index."
)
except HTTPError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
" again after checking your internet connection."
)
cached_filenames.append(cached_filename)
return cached_filenames, sharded_metadata
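A hedged sketch of resolving a sharded Hub checkpoint end to end (the repository is the one used in the tests below; `hf_bucket_url` and `cached_path` are the helpers already used in this module):

```py
>>> index_url = hf_bucket_url("ArthurZ/tiny-random-bert-sharded", filename="tf_model.h5.index.json")
>>> index_file = cached_path(index_url)
>>> shard_files, sharded_metadata = get_checkpoint_shard_files(
...     "ArthurZ/tiny-random-bert-sharded", index_file
... )
>>> sharded_metadata["total_size"]  # doctest: +SKIP
```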
@@ -53,6 +53,7 @@ logger = logging.get_logger(__name__)
if is_tf_available():
import h5py
import numpy as np
import tensorflow as tf
@@ -85,7 +86,12 @@ if is_tf_available():
TFSampleDecoderOnlyOutput,
TFSampleEncoderDecoderOutput,
)
from transformers.modeling_tf_utils import unpack_inputs
from transformers.modeling_tf_utils import (
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
tf_shard_checkpoint,
unpack_inputs,
)
from transformers.tf_utils import stable_softmax
if _tf_gpu_memory_limit is not None:
@@ -1867,6 +1873,129 @@ class UtilsFunctionsTest(unittest.TestCase):
out = masked_softmax(x, boolean_mask)
assert tf.experimental.numpy.allclose(xla_out, out)
def test_checkpoint_sharding_from_hub(self):
model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
# the model above is the same as the model below, just a sharded version.
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
def test_shard_checkpoint(self):
# This is the model we will use, total size 340,000 bytes.
model = tf.keras.Sequential(
[
tf.keras.layers.Dense(200, use_bias=False), # size 80,000
tf.keras.layers.Dense(200, use_bias=False), # size 160,000
tf.keras.layers.Dense(100, use_bias=False), # size 80,000
tf.keras.layers.Dense(50, use_bias=False), # size 20,000
]
)
inputs = tf.zeros((1, 100), dtype=tf.float32)
model(inputs)
weights = model.weights
weights_dict = {w.name: w for w in weights}
with self.subTest("No shard when max size is bigger than model size"):
shards, index = tf_shard_checkpoint(weights)
self.assertIsNone(index)
self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})
with self.subTest("Test sharding, no weights bigger than max size"):
shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
# Split is first two layers then last two.
self.assertDictEqual(
index,
{
"metadata": {"total_size": 340000},
"weight_map": {
"dense/kernel:0": "tf_model-00001-of-00002.h5",
"dense_1/kernel:0": "tf_model-00001-of-00002.h5",
"dense_2/kernel:0": "tf_model-00002-of-00002.h5",
"dense_3/kernel:0": "tf_model-00002-of-00002.h5",
},
},
)
shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})
with self.subTest("Test sharding with weights bigger than max size"):
shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
# Split is first layer, second layer then last 2.
self.assertDictEqual(
index,
{
"metadata": {"total_size": 340000},
"weight_map": {
"dense/kernel:0": "tf_model-00001-of-00003.h5",
"dense_1/kernel:0": "tf_model-00002-of-00003.h5",
"dense_2/kernel:0": "tf_model-00003-of-00003.h5",
"dense_3/kernel:0": "tf_model-00003-of-00003.h5",
},
},
)
shard1 = [weights_dict["dense/kernel:0"]]
shard2 = [weights_dict["dense_1/kernel:0"]]
shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
self.assertDictEqual(
shards,
{
"tf_model-00001-of-00003.h5": shard1,
"tf_model-00002-of-00003.h5": shard2,
"tf_model-00003-of-00003.h5": shard3,
},
)
def test_checkpoint_sharding_local(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
model.save_pretrained(tmp_dir, max_shard_size=max_size)
# Get each shard file and its size
shard_to_size = {}
for shard in os.listdir(tmp_dir):
if shard.endswith(".h5"):
shard_file = os.path.join(tmp_dir, shard)
shard_to_size[shard_file] = os.path.getsize(shard_file)
index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
# Check there is an index but no regular weight file
self.assertTrue(os.path.isfile(index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
# Check a file is bigger than max_size only when it has a single weight
for shard_file, size in shard_to_size.items():
if max_size.endswith("kiB"):
max_size_int = int(max_size[:-3]) * 2**10
else:
max_size_int = int(max_size[:-2]) * 10**3
# Note: the HDF5 format adds some overhead so the size of the file can end up being slightly
# bigger than the size asked for (since we only count parameter sizes)
if size >= max_size_int + 50000:
with h5py.File(shard_file, "r") as state_file:
self.assertEqual(len(state_file), 1)
# Check the index and the shard files found match
with open(index_file, "r", encoding="utf-8") as f:
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".h5"))
self.assertSetEqual(all_shards, shards_found)
# Finally, check the model can be reloaded
new_model = TFBertModel.from_pretrained(tmp_dir)
model(model.dummy_inputs)
new_model(model.dummy_inputs)
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_tf
@is_staging_test