chenpangpang / transformers · Commits · 16c6eb7c

Unverified commit 16c6eb7c, authored Jun 22, 2022 by Arthur, committed by GitHub on Jun 22, 2022

Flax sharded (#17760)

parent 3b00b623
Showing 4 changed files with 305 additions and 22 deletions (+305, -22)
- src/transformers/modeling_flax_utils.py (+246, -17)
- src/transformers/utils/__init__.py (+1, -0)
- src/transformers/utils/hub.py (+2, -5)
- tests/test_modeling_flax_common.py (+56, -0)
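Before the per-file hunks, a quick editorial sketch of what this commit enables end to end: `save_pretrained` gains a `max_shard_size` argument for Flax models, and `from_pretrained` transparently reloads the resulting sharded checkpoint. This is not part of the diff; it assumes a transformers build that includes this commit and uses the tiny checkpoint referenced by the new tests.

```python
# Editorial sketch (not part of the diff). Assumes a transformers build that includes this commit
# and that flax/jax are installed; the checkpoint name comes from the new tests below.
import tempfile

import numpy as np
from flax.traverse_util import flatten_dict

from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

with tempfile.TemporaryDirectory() as tmp_dir:
    # A deliberately tiny limit so even this small model gets split into several .msgpack shards.
    model.save_pretrained(tmp_dir, max_shard_size="150kB")
    reloaded = FlaxBertModel.from_pretrained(tmp_dir)  # the index file triggers the sharded loading path

# The round trip preserves every parameter.
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(reloaded.params).values()):
    assert np.allclose(np.array(p1), np.array(p2))
```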
src/transformers/modeling_flax_utils.py
...
@@ -13,11 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import os
import re
from functools import partial
from pickle import UnpicklingError
from typing import Any, Dict, Set, Tuple, Union

import numpy as np

import flax.linen as nn
import jax
import jax.numpy as jnp
...
@@ -33,6 +39,7 @@ from .dynamic_module_utils import custom_object_save
from .generation_flax_utils import FlaxGenerationMixin
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import (
    FLAX_WEIGHTS_INDEX_NAME,
    FLAX_WEIGHTS_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    WEIGHTS_NAME,
...
@@ -51,6 +58,7 @@ from .utils import (
    logging,
    replace_return_docstrings,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files


logger = logging.get_logger(__name__)
...
@@ -70,6 +78,88 @@ ACT2FN = {
}


def dtype_byte_size(dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:

    ```py
    >>> dtype_byte_size(np.float32)
    4
    ```
    """
    if dtype == np.bool:
        return 1 / 8
    bit_search = re.search("[^\d](\d+)$", dtype.name)
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


def flax_shard_checkpoint(params, max_shard_size="10GB"):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
    there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
    example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
    [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    # flatten the weights to chunk
    weights = flatten_dict(params, sep="/")
    for item in weights:
        weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)

        # If this weight is going to tip up over the maximal size, we split.
        if current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0

        current_block[item] = weights[item]
        current_block_size += weight_size
        total_size += weight_size

    # Add the last block
    sharded_state_dicts.append(current_block)

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
        shards[shard_file] = shard
        for weight_name in shard.keys():
            weight_map[weight_name] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
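The greedy packing described in the docstring is easy to see on toy data. The sketch below is editorial, not part of the commit; it assumes a transformers build containing `flax_shard_checkpoint` and a numpy version where `np.bool` (used by `dtype_byte_size` above) still resolves, and the layer names are invented.

```python
# Editorial sketch (not part of the diff): how flax_shard_checkpoint splits weights.
import numpy as np
from transformers.modeling_flax_utils import flax_shard_checkpoint

params = {
    "layer_1": {"kernel": np.zeros((256, 256), dtype=np.float32)},  # 262144 bytes
    "layer_2": {"kernel": np.zeros((256, 256), dtype=np.float32)},  # 262144 bytes
}

# "300KB" = 300_000 bytes, so the two ~262 KB tensors cannot share a block.
shards, index = flax_shard_checkpoint(params, max_shard_size="300KB")

print(sorted(shards))      # ['flax_model-00001-of-00002.msgpack', 'flax_model-00002-of-00002.msgpack']
print(index["metadata"])   # {'total_size': 524288}
print(index["weight_map"]) # {'layer_1/kernel': 'flax_model-00001-of-00002.msgpack', ...}
```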
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    r"""
    Base class for all models.
...
@@ -333,6 +423,53 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
        ```"""
        return self._cast_floating_to(params, jnp.float16, mask)

    @classmethod
    def load_flax_sharded_weights(cls, shard_files):
        """
        This is the same as [`flax.serialization.from_bytes`]
        (https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.

        This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
        loaded in the model.

        Args:
            shard_files (`List[str]`):
                The list of shard files to load.

        Returns:
            `Dict`: A nested dictionary of the model parameters, in the expected format for flax models: `{'model':
            {'params': {'...'}}}`.
        """

        # Load the index
        state_sharded_dict = dict()

        for shard_file in shard_files:
            # load using msgpack utils
            try:
                with open(shard_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                with open(shard_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")

            state = flatten_dict(state, sep="/")
            state_sharded_dict.update(state)
            del state
            gc.collect()

        # the state dict is unflattened to match the format of model.params
        return unflatten_dict(state_sharded_dict, sep="/")
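For readers who want to inspect a sharded checkpoint outside `from_pretrained`, this editorial sketch mirrors the loop above, driving it from the index file directly. It is not part of the commit; the local directory path is hypothetical, and `flax.serialization.msgpack_restore` stands in for `from_bytes` to obtain the raw nested dict.

```python
# Editorial sketch (not part of the diff): reassembling a sharded Flax checkpoint by hand,
# mirroring load_flax_sharded_weights. The checkpoint directory is hypothetical.
import gc
import json
import os

from flax.serialization import msgpack_restore
from flax.traverse_util import flatten_dict, unflatten_dict

checkpoint_dir = "./my-sharded-flax-model"  # hypothetical: produced by save_pretrained(..., max_shard_size=...)

with open(os.path.join(checkpoint_dir, "flax_model.msgpack.index.json"), encoding="utf-8") as f:
    index = json.load(f)

state_sharded_dict = {}
for shard_file in sorted(set(index["weight_map"].values())):
    with open(os.path.join(checkpoint_dir, shard_file), "rb") as state_f:
        shard = msgpack_restore(state_f.read())  # raw nested dict, like from_bytes without a matching target
    state_sharded_dict.update(flatten_dict(shard, sep="/"))
    del shard
    gc.collect()  # free each shard before reading the next, as the method above does

params = unflatten_dict(state_sharded_dict, sep="/")
```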
    @classmethod
    def from_pretrained(
        cls,
...
@@ -489,6 +626,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
        # index of the files.
        is_sharded = False

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
...
@@ -498,6 +639,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)):
                    # Load from a sharded Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
                    is_sharded = True
                # At this stage we don't have a weight file so we will raise an error.
                elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
                    raise EnvironmentError(
...
@@ -521,6 +666,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
...
@@ -548,18 +694,37 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                )
            except EntryNotFoundError:
                if filename == FLAX_WEIGHTS_NAME:
                    try:
                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                        archive_file = hf_bucket_url(
                            pretrained_model_name_or_path,
                            filename=FLAX_WEIGHTS_INDEX_NAME,
                            revision=revision,
                        )
                        resolved_archive_file = cached_path(
                            archive_file,
                            cache_dir=cache_dir,
                            force_download=force_download,
                            proxies=proxies,
                            resume_download=resume_download,
                            local_files_only=local_files_only,
                            use_auth_token=use_auth_token,
                            user_agent=user_agent,
                        )
                        is_sharded = True
                    except EntryNotFoundError:
                        has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
                        if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                            raise EnvironmentError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
                                " load this model from those weights."
                            )
                        else:
                            raise EnvironmentError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                            )
                else:
                    raise EnvironmentError(
                        f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
...
@@ -592,15 +757,35 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
        else:
            resolved_archive_file = None

        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
        if is_sharded:
            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
            resolved_archive_file, _ = get_checkpoint_shard_files(
                pretrained_model_name_or_path,
                resolved_archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                revision=revision,
            )

        # init random models
        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
        else:
            if is_sharded:
                state = cls.load_flax_sharded_weights(resolved_archive_file)
            else:
                try:
                    with open(resolved_archive_file, "rb") as state_f:
                        state = from_bytes(cls, state_f.read())
                except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                    try:
                        with open(resolved_archive_file) as f:
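Seen from the user side, none of the above changes the call site. An editorial sketch, using the sharded Hub checkpoint that the new `test_checkpoint_sharding_from_hub` test relies on:

```python
# Editorial sketch (not part of the diff). The repo below is the tiny sharded checkpoint used by the new tests.
from transformers import FlaxBertModel

# from_pretrained resolves flax_model.msgpack.index.json, downloads every shard listed in its
# weight_map, and rebuilds the parameters with load_flax_sharded_weights.
model = FlaxBertModel.from_pretrained("ArthurZ/flax-tiny-random-bert-sharded")
print(type(model.params))  # the usual nested parameter dict, regardless of how it was stored on the Hub
```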
...
@@ -742,7 +927,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
        else:
            return model, unflatten_dict(state)

    def save_pretrained(
        self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, max_shard_size="10GB", **kwargs
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~FlaxPreTrainedModel.from_pretrained`] class method
...
@@ -761,6 +948,17 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                </Tip>

            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
                The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be smaller
                than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).

                <Tip warning={true}>

                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
                which will be bigger than `max_shard_size`.

                </Tip>

            kwargs:
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
...
@@ -788,10 +986,41 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
        # save model
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)

        shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
        # Clean the folder from a previous save
        for filename in os.listdir(save_directory):
            full_filename = os.path.join(save_directory, filename)
            if (
                filename.startswith(FLAX_WEIGHTS_NAME[:-4])
                and os.path.isfile(full_filename)
                and filename not in shards.keys()
            ):
                os.remove(full_filename)

        if index is None:
            with open(output_model_file, "wb") as f:
                params = params if params is not None else self.params
                model_bytes = to_bytes(params)
                f.write(model_bytes)
        else:
            save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
            # Save the index as well
            with open(save_index_file, "w", encoding="utf-8") as f:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                f.write(content)
            logger.info(
                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
                f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
                f"index located at {save_index_file}."
            )
            for shard_file, shard in shards.items():
                # the shard items are flattened; to save them we need to unflatten them again
                with open(os.path.join(save_directory, shard_file), mode="wb") as f:
                    params = unflatten_dict(shard, sep="/")
                    shard_bytes = to_bytes(params)
                    f.write(shard_bytes)

        logger.info(f"Model weights saved in {output_model_file}")
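The `index is None` branch keeps the old single-file behaviour; otherwise `flax_model.msgpack.index.json` plus numbered shards are written. An editorial sketch of inspecting that output with a small randomly initialised model (the config values are invented, and exact shard counts depend on them and on `max_shard_size`):

```python
# Editorial sketch (not part of the diff): inspecting what save_pretrained writes when sharding kicks in.
import json
import os
import tempfile

from transformers import BertConfig, FlaxBertModel

config = BertConfig(vocab_size=128, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = FlaxBertModel(config)  # randomly initialised toy model

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="50kB")
    print(sorted(f for f in os.listdir(tmp_dir) if f.endswith(".msgpack")))
    # e.g. ['flax_model-00001-of-0000N.msgpack', ..., 'flax_model-0000N-of-0000N.msgpack']

    with open(os.path.join(tmp_dir, "flax_model.msgpack.index.json"), encoding="utf-8") as f:
        index = json.load(f)
    print(index["metadata"])  # {'total_size': <parameter bytes counted by flax_shard_checkpoint>}
    some_weight = next(iter(index["weight_map"]))
    print(some_weight, "->", index["weight_map"][some_weight])  # flattened weight name -> shard file
```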
...

src/transformers/utils/__init__.py
...
@@ -151,6 +151,7 @@ TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
MODEL_CARD_NAME = "modelcard.json"
...
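The only change in this file is the new `FLAX_WEIGHTS_INDEX_NAME` constant next to the existing `FLAX_WEIGHTS_NAME`. A tiny editorial illustration of the file names they imply:

```python
# Editorial illustration (not part of the diff): the two file names involved in Flax checkpoints.
from transformers.utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME

print(FLAX_WEIGHTS_NAME)        # flax_model.msgpack            -- single-file checkpoint
print(FLAX_WEIGHTS_INDEX_NAME)  # flax_model.msgpack.index.json -- index written for a sharded checkpoint
```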
src/transformers/utils/hub.py
...
@@ -937,7 +937,7 @@ class PushToHubMixin:
                use_auth_token=use_auth_token,
            )
            # Save the files in the cloned repo
            self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
            if hasattr(self, "history") and hasattr(self, "create_model_card"):
                # This is a Keras model and we might be able to fish out its History and make a model card out of it
...
@@ -947,9 +947,7 @@ class PushToHubMixin:
                }
                base_model_card_args.update(model_card_kwargs)
                self.create_model_card(**base_model_card_args)

            # Commit and push!
            url = self._push_to_hub(repo, commit_message=commit_message)
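Because `save_pretrained` now shards Flax checkpoints too, `push_to_hub` simply forwards `max_shard_size`, and the old fallback branch whose comment read "FLAX does not support sharding yet, will come in next PR" is removed by this hunk. A hedged editorial sketch of the call (the repo name is invented and pushing requires Hub credentials):

```python
# Editorial sketch (not part of the diff): the repo name is invented and this requires a
# Hugging Face Hub login; max_shard_size is forwarded to save_pretrained as shown above.
from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
model.push_to_hub("tiny-bert-flax-sharded", max_shard_size="150kB")
```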
...
@@ -1090,7 +1088,6 @@ def convert_file_size_to_int(size: Union[int, str]):
        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.

    Example:

    ```py
    >>> convert_file_size_to_int("1MiB")
    1048576
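The shard-size strings passed to `save_pretrained` go through this helper, which distinguishes decimal from binary suffixes; the new Flax test exercises both. A small editorial illustration:

```python
# Editorial illustration (not part of the diff): decimal vs. binary size suffixes, matching the
# arithmetic the new test_checkpoint_sharding_local test performs on "kB"/"kiB" strings.
from transformers.utils.hub import convert_file_size_to_int

print(convert_file_size_to_int("150kB"))   # 150000  -> 150 * 10**3
print(convert_file_size_to_int("150kiB"))  # 153600  -> 150 * 2**10
```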
...

tests/test_modeling_flax_common.py
...
@@ -14,6 +14,7 @@
import copy
import inspect
import json
import random
import tempfile
import unittest
...
@@ -45,6 +46,7 @@ if is_flax_available():
    import jax
    import jax.numpy as jnp
    from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
    from flax.serialization import from_bytes
    from flax.traverse_util import flatten_dict, unflatten_dict

    from transformers import (
        FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
...
@@ -58,6 +60,7 @@ if is_flax_available():
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )
    from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME

    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
...
@@ -1043,6 +1046,59 @@ class FlaxModelTesterMixin:
        # Check if all required params are loaded
        _assert_all_params_initialised(model, params)

    def test_checkpoint_sharding_from_hub(self):
        model = FlaxBertModel.from_pretrained("ArthurZ/flax-tiny-random-bert-sharded")
        # the model above is the same as the model below, just a sharded version.
        ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
        for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
            assert np.allclose(np.array(p1), np.array(p2))

    def test_checkpoint_sharding_local(self):
        model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
            for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
                model.save_pretrained(tmp_dir, max_shard_size=max_size)

                # Get each shard file and its size
                shard_to_size = {}
                for shard in os.listdir(tmp_dir):
                    if shard.endswith(".msgpack"):
                        shard_file = os.path.join(tmp_dir, shard)
                        shard_to_size[shard_file] = os.path.getsize(shard_file)

                index_file = os.path.join(tmp_dir, FLAX_WEIGHTS_INDEX_NAME)
                # Check there is an index but no regular weight file
                self.assertTrue(os.path.isfile(index_file))
                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))

                # Check a file is bigger than max_size only when it has a single weight
                for shard_file, size in shard_to_size.items():
                    if max_size.endswith("kiB"):
                        max_size_int = int(max_size[:-3]) * 2**10
                    else:
                        max_size_int = int(max_size[:-2]) * 10**3
                    # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
                    # the size asked for (since we count parameters)
                    if size >= max_size_int + 50000:
                        with open(shard_file, "rb") as state_f:
                            state_file = from_bytes(FlaxBertModel, state_f.read())
                            self.assertEqual(len(state_file), 1)

                # Check the index and the shard files found match
                with open(index_file, "r", encoding="utf-8") as f:
                    index = json.loads(f.read())

                all_shards = set(index["weight_map"].values())
                shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".msgpack"))
                self.assertSetEqual(all_shards, shards_found)

                # Finally, check the model can be reloaded
                new_model = FlaxBertModel.from_pretrained(tmp_dir)
                for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
                    self.assertTrue(np.allclose(np.array(p1), np.array(p2)))


@require_flax
@is_staging_test
...