Unverified commit 16c6eb7c authored by Arthur, committed by GitHub

Flax sharded (#17760)

parent 3b00b623
......@@ -13,11 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import os
import re
from functools import partial
from pickle import UnpicklingError
from typing import Any, Dict, Set, Tuple, Union
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
......@@ -33,6 +39,7 @@ from .dynamic_module_utils import custom_object_save
from .generation_flax_utils import FlaxGenerationMixin
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import (
FLAX_WEIGHTS_INDEX_NAME,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
......@@ -51,6 +58,7 @@ from .utils import (
logging,
replace_return_docstrings,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
logger = logging.get_logger(__name__)
......@@ -70,6 +78,88 @@ ACT2FN = {
}
def dtype_byte_size(dtype):
"""
Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:
```py
>>> dtype_byte_size(np.float32)
4
```
"""
if dtype == bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", dtype.name)
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
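For reference, a quick sanity check of this helper on a few NumPy dtypes (a minimal sketch; it assumes `dtype_byte_size` is importable from `transformers.modeling_flax_utils` once this change is merged):

```py
import numpy as np

from transformers.modeling_flax_utils import dtype_byte_size

# One float32 parameter takes 4 bytes, one float16 takes 2, one int8 takes 1.
assert dtype_byte_size(np.dtype("float32")) == 4
assert dtype_byte_size(np.dtype("float16")) == 2
assert dtype_byte_size(np.dtype("int8")) == 1
```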
def flax_shard_checkpoint(params, max_shard_size="10GB"):
"""
Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
[6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
<Tip warning={true}>
If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint, which will
have a size greater than `max_shard_size`.
</Tip>
Args:
params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
(like `"5MB"`).
"""
max_shard_size = convert_file_size_to_int(max_shard_size)
sharded_state_dicts = []
current_block = {}
current_block_size = 0
total_size = 0
# flatten the weights to chunk
weights = flatten_dict(params, sep="/")
for item in weights:
weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
# If adding this weight would tip the current block over the maximal size, we split.
if current_block_size + weight_size > max_shard_size:
sharded_state_dicts.append(current_block)
current_block = {}
current_block_size = 0
current_block[item] = weights[item]
current_block_size += weight_size
total_size += weight_size
# Add the last block
sharded_state_dicts.append(current_block)
# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None
# Otherwise, let's build the index
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
shards[shard_file] = shard
for weight_name in shard.keys():
weight_map[weight_name] = shard_file
# Add the metadata
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
return shards, index
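To make the greedy splitting concrete, here is a hedged usage sketch on a toy parameter tree (the parameter names, shapes and the `"500kB"` limit are illustrative and not taken from the PR):

```py
import numpy as np

from transformers.modeling_flax_utils import flax_shard_checkpoint

# Two ~0.41 MB float32 weights with a 0.5 MB limit: the greedy split puts each
# weight in its own shard and returns an index mapping weight names to files.
toy_params = {
    "dense": {"kernel": np.zeros((320, 320), dtype=np.float32)},  # ~0.41 MB
    "head": {"kernel": np.zeros((320, 320), dtype=np.float32)},   # ~0.41 MB
}
shards, index = flax_shard_checkpoint(toy_params, max_shard_size="500kB")

print(sorted(shards))
# ['flax_model-00001-of-00002.msgpack', 'flax_model-00002-of-00002.msgpack']
print(index["weight_map"])
# {'dense/kernel': 'flax_model-00001-of-00002.msgpack',
#  'head/kernel': 'flax_model-00002-of-00002.msgpack'}
```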
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
r"""
Base class for all models.
......@@ -333,6 +423,53 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
```"""
return self._cast_floating_to(params, jnp.float16, mask)
@classmethod
def load_flax_sharded_weights(cls, shard_files):
"""
This is the same as [`flax.serialization.from_bytes`]
(https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.
This load is performed efficiently: each checkpoint shard is loaded one by one into RAM and deleted after being
loaded into the model.
Args:
shard_files (`List[str]`):
The list of shard files to load.
Returns:
`Dict`: A nested dictionary of the model parameters, in the expected format for flax models: `{'model':
{'params': {'...'}}}`.
"""
# Load the index
state_sharded_dict = dict()
for shard_file in shard_files:
# load using msgpack utils
try:
with open(shard_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
with open(shard_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")
state = flatten_dict(state, sep="/")
state_sharded_dict.update(state)
del state
gc.collect()
# the state dict is unflattened to match the format of model.params
return unflatten_dict(state_sharded_dict, sep="/")
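A hedged sketch of calling this helper directly (the checkpoint directory and the way the shard list is built are hypothetical; `from_pretrained` normally assembles this list for you):

```py
import os

from transformers import FlaxBertModel

# Hypothetical directory produced by save_pretrained(..., max_shard_size=...).
checkpoint_dir = "./flax-bert-sharded"
shard_files = sorted(
    os.path.join(checkpoint_dir, f)
    for f in os.listdir(checkpoint_dir)
    if f.startswith("flax_model-") and f.endswith(".msgpack")
)

# Shards are read one at a time and merged into a single nested dict
# matching the layout of model.params.
state = FlaxBertModel.load_flax_sharded_weights(shard_files)
print(list(state.keys()))
```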
@classmethod
def from_pretrained(
cls,
......@@ -489,6 +626,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# Add the dtype to model_kwargs
model_kwargs["dtype"] = dtype
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
# Load model
if pretrained_model_name_or_path is not None:
if os.path.isdir(pretrained_model_name_or_path):
......@@ -498,6 +639,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)):
# Load from a sharded Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
......@@ -521,6 +666,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(
archive_file,
......@@ -548,18 +694,37 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
)
except EntryNotFoundError:
if filename == FLAX_WEIGHTS_NAME:
    try:
        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
        archive_file = hf_bucket_url(
            pretrained_model_name_or_path,
            filename=FLAX_WEIGHTS_INDEX_NAME,
            revision=revision,
        )
        resolved_archive_file = cached_path(
            archive_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            user_agent=user_agent,
        )
        is_sharded = True
    except EntryNotFoundError:
        has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
        if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named"
                f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
                " load this model from those weights."
            )
        else:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named"
                f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
            )
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
......@@ -592,15 +757,35 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
else:
resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, _ = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
)
# init random models
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
if from_pt:
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
else:
    if is_sharded:
        state = cls.load_flax_sharded_weights(resolved_archive_file)
    else:
        try:
            with open(resolved_archive_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
try:
with open(resolved_archive_file) as f:
......@@ -742,7 +927,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
else:
return model, unflatten_dict(state)
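From the user's point of view nothing changes in the loading API: `from_pretrained` picks up the index file and fetches each shard transparently. A short sketch using the sharded test repository referenced in the tests below:

```py
from transformers import FlaxBertModel

# The resolved archive file is flax_model.msgpack.index.json, so is_sharded is True
# and every shard listed in the index is downloaded, cached and merged.
model = FlaxBertModel.from_pretrained("ArthurZ/flax-tiny-random-bert-sharded")
print(model.params.keys())
```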
def save_pretrained(
    self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, max_shard_size="10GB", **kwargs
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
[`~FlaxPreTrainedModel.from_pretrained`] class method.
......@@ -761,6 +948,17 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
</Tip>
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then have a size
lower than this limit. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
<Tip warning={true}>
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
which will be bigger than `max_shard_size`.
</Tip>
kwargs:
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
......@@ -788,10 +986,41 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# save model
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
if (
filename.startswith(FLAX_WEIGHTS_NAME[:-4])
and os.path.isfile(full_filename)
and filename not in shards.keys()
):
os.remove(full_filename)
if index is None:
with open(output_model_file, "wb") as f:
params = params if params is not None else self.params
model_bytes = to_bytes(params)
f.write(model_bytes)
else:
save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
for shard_file, shard in shards.items():
# the shard items are flattened; to save them we need to unflatten them again
with open(os.path.join(save_directory, shard_file), mode="wb") as f:
params = unflatten_dict(shard, sep="/")
shard_bytes = to_bytes(params)
f.write(shard_bytes)
logger.info(f"Model weights saved in {output_model_file}")
......
......@@ -151,6 +151,7 @@ TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
MODEL_CARD_NAME = "modelcard.json"
......
......@@ -937,7 +937,7 @@ class PushToHubMixin:
use_auth_token=use_auth_token,
)
# Save the files in the cloned repo
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
if hasattr(self, "history") and hasattr(self, "create_model_card"):
# This is a Keras model and we might be able to fish out its History and make a model card out of it
......@@ -947,9 +947,7 @@ class PushToHubMixin:
}
base_model_card_args.update(model_card_kwargs)
self.create_model_card(**base_model_card_args)
# Commit and push!
url = self._push_to_hub(repo, commit_message=commit_message)
......@@ -1090,7 +1088,6 @@ def convert_file_size_to_int(size: Union[int, str]):
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
Example:
```py
>>> convert_file_size_to_int("1MiB")
1048576
......
......@@ -14,6 +14,7 @@
import copy
import inspect
import json
import random
import tempfile
import unittest
......@@ -45,6 +46,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.serialization import from_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import (
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
......@@ -58,6 +60,7 @@ if is_flax_available():
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
......@@ -1043,6 +1046,59 @@ class FlaxModelTesterMixin:
# Check if all required params are loaded
_assert_all_params_initialised(model, params)
def test_checkpoint_sharding_from_hub(self):
model = FlaxBertModel.from_pretrained("ArthurZ/flax-tiny-random-bert-sharded")
# the model above is the same as the model below, just a sharded version.
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
assert np.allclose(np.array(p1), np.array(p2))
def test_checkpoint_sharding_local(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
with tempfile.TemporaryDirectory() as tmp_dir:
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
model.save_pretrained(tmp_dir, max_shard_size=max_size)
# Get each shard file and its size
shard_to_size = {}
for shard in os.listdir(tmp_dir):
if shard.endswith(".msgpack"):
shard_file = os.path.join(tmp_dir, shard)
shard_to_size[shard_file] = os.path.getsize(shard_file)
index_file = os.path.join(tmp_dir, FLAX_WEIGHTS_INDEX_NAME)
# Check there is an index but no regular weight file
self.assertTrue(os.path.isfile(index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))
# Check a file is bigger than max_size only when it has a single weight
for shard_file, size in shard_to_size.items():
if max_size.endswith("kiB"):
max_size_int = int(max_size[:-3]) * 2**10
else:
max_size_int = int(max_size[:-2]) * 10**3
# Note: serialization adds some overhead, so the file on disk can end up slightly bigger than
# the size asked for (since we only count the raw parameter bytes)
if size >= max_size_int + 50000:
with open(shard_file, "rb") as state_f:
state_file = from_bytes(FlaxBertModel, state_f.read())
self.assertEqual(len(state_file), 1)
# Check the index and the shard files found match
with open(index_file, "r", encoding="utf-8") as f:
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".msgpack"))
self.assertSetEqual(all_shards, shards_found)
# Finally, check the model can be reloaded
new_model = FlaxBertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
@require_flax
@is_staging_test
......