chenpangpang / transformers · commit 16c6eb7c (unverified)
Flax sharded (#17760)
Authored Jun 22, 2022 by Arthur; committed via GitHub on Jun 22, 2022
Parent: 3b00b623

Showing 4 changed files with 305 additions and 22 deletions (+305 / -22)
src/transformers/modeling_flax_utils.py   +246 / -17
src/transformers/utils/__init__.py        +1 / -0
src/transformers/utils/hub.py             +2 / -5
tests/test_modeling_flax_common.py        +56 / -0
src/transformers/modeling_flax_utils.py
@@ -13,11 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
+import json
 import os
+import re
 from functools import partial
 from pickle import UnpicklingError
 from typing import Any, Dict, Set, Tuple, Union

+import numpy as np
+
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -33,6 +39,7 @@ from .dynamic_module_utils import custom_object_save
 from .generation_flax_utils import FlaxGenerationMixin
 from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
 from .utils import (
+    FLAX_WEIGHTS_INDEX_NAME,
     FLAX_WEIGHTS_NAME,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     WEIGHTS_NAME,
@@ -51,6 +58,7 @@ from .utils import (
     logging,
     replace_return_docstrings,
 )
+from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files

 logger = logging.get_logger(__name__)
@@ -70,6 +78,88 @@ ACT2FN = {
 }


+def dtype_byte_size(dtype):
+    """
+    Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:
+
+    ```py
+    >>> dtype_byte_size(np.float32)
+    4
+    ```
+    """
+    if dtype == np.bool:
+        return 1 / 8
+    bit_search = re.search("[^\d](\d+)$", dtype.name)
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
+
+
+def flax_shard_checkpoint(params, max_shard_size="10GB"):
+    """
+    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
+    given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
+    there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
+    example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
+    [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
+
+    <Tip warning={true}>
+
+    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
+    have a size greater than `max_shard_size`.
+
+    </Tip>
+
+    Args:
+        params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
+        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
+            (like `"5MB"`).
+    """
+    max_shard_size = convert_file_size_to_int(max_shard_size)
+
+    sharded_state_dicts = []
+    current_block = {}
+    current_block_size = 0
+    total_size = 0
+
+    # flatten the weights to chunk
+    weights = flatten_dict(params, sep="/")
+    for item in weights:
+        weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
+
+        # If this weight is going to tip up over the maximal size, we split.
+        if current_block_size + weight_size > max_shard_size:
+            sharded_state_dicts.append(current_block)
+            current_block = {}
+            current_block_size = 0
+
+        current_block[item] = weights[item]
+        current_block_size += weight_size
+        total_size += weight_size
+
+    # Add the last block
+    sharded_state_dicts.append(current_block)
+
+    # If we only have one shard, we return it
+    if len(sharded_state_dicts) == 1:
+        return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None
+
+    # Otherwise, let's build the index
+    weight_map = {}
+    shards = {}
+    for idx, shard in enumerate(sharded_state_dicts):
+        shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
+        shards[shard_file] = shard
+        for weight_name in shard.keys():
+            weight_map[weight_name] = shard_file
+
+    # Add the metadata
+    metadata = {"total_size": total_size}
+    index = {"metadata": metadata, "weight_map": weight_map}
+    return shards, index
+
+
 class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
     r"""
     Base class for all models.
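To make the greedy splitting rule in `flax_shard_checkpoint`'s docstring concrete, here is a minimal, self-contained sketch of the same strategy applied to toy integer sizes (the `toy_split` helper and the sizes are illustrative only, not part of this commit):

```py
# Minimal sketch of the greedy split used by `flax_shard_checkpoint`: weights are
# visited in order and a new shard is started as soon as adding the next weight
# would exceed the budget. Sizes here are plain integers standing in for bytes.
def toy_split(sizes, max_shard_size):
    shards, current, current_size = [], [], 0
    for size in sizes:
        if current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0
        current.append(size)
        current_size += size
    shards.append(current)  # the last, possibly partial, block
    return shards

# With a limit of 10 and sizes [6, 6, 2, 6, 2, 2], the shards come out as
# [6], [6, 2], [6, 2, 2] — exactly the example in the docstring above.
print(toy_split([6, 6, 2, 6, 2, 2], 10))  # [[6], [6, 2], [6, 2, 2]]
```

Starting a new shard only when the next weight would overflow the budget is what produces the [6GB], [6+2GB], [6+2+2GB] split rather than a better-packed alternative.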
@@ -333,6 +423,53 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         ```"""
         return self._cast_floating_to(params, jnp.float16, mask)

+    @classmethod
+    def load_flax_sharded_weights(cls, shard_files):
+        """
+        This is the same as [`flax.serialization.from_bytes`]
+        (https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.
+
+        This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+        loaded in the model.
+
+        Args:
+            shard_files (`List[str]`):
+                The list of shard files to load.
+
+        Returns:
+            `Dict`: A nested dictionary of the model parameters, in the expected format for flax models: `{'model':
+            {'params': {'...'}}}`.
+        """
+
+        # Load the index
+        state_sharded_dict = dict()
+
+        for shard_file in shard_files:
+            # load using msgpack utils
+            try:
+                with open(shard_file, "rb") as state_f:
+                    state = from_bytes(cls, state_f.read())
+            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+                with open(shard_file) as f:
+                    if f.read().startswith("version"):
+                        raise OSError(
+                            "You seem to have cloned a repository without having git-lfs installed. Please"
+                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+                            " folder you cloned."
+                        )
+                    else:
+                        raise ValueError from e
+            except (UnicodeDecodeError, ValueError):
+                raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object.")
+
+            state = flatten_dict(state, sep="/")
+            state_sharded_dict.update(state)
+            del state
+            gc.collect()
+
+        # the state dict is unflattened to match the format of model.params
+        return unflatten_dict(state_sharded_dict, sep="/")
+
     @classmethod
     def from_pretrained(
         cls,
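`load_flax_sharded_weights` merges shards through Flax's dict-flattening utilities. A small hedged sketch of that merge on a toy parameter tree (the tree and array shapes are made up for illustration; only `flax.traverse_util` and NumPy are assumed):

```py
import numpy as np
from flax.traverse_util import flatten_dict, unflatten_dict

# Two toy "shards", already flattened with sep="/" the way load_flax_sharded_weights expects.
shard_1 = flatten_dict({"encoder": {"kernel": np.zeros((2, 2))}}, sep="/")
shard_2 = flatten_dict({"decoder": {"bias": np.zeros((2,))}}, sep="/")

# Shards accumulate into one flat dict and are unflattened once at the end.
state_sharded_dict = {}
state_sharded_dict.update(shard_1)
state_sharded_dict.update(shard_2)

params = unflatten_dict(state_sharded_dict, sep="/")
print(sorted(params.keys()))  # ['decoder', 'encoder']
```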
@@ -489,6 +626,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         # Add the dtype to model_kwargs
         model_kwargs["dtype"] = dtype

+        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+        # index of the files.
+        is_sharded = False
+
         # Load model
         if pretrained_model_name_or_path is not None:
             if os.path.isdir(pretrained_model_name_or_path):
@@ -498,6 +639,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                     # Load from a Flax checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)):
+                    # Load from a sharded Flax checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
+                    is_sharded = True
                 # At this stage we don't have a weight file so we will raise an error.
                 elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
                     raise EnvironmentError(
@@ -521,6 +666,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 )

             # redirect to the cache, if necessary
             try:
                 resolved_archive_file = cached_path(
                     archive_file,
@@ -548,18 +694,37 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             except EntryNotFoundError:
                 if filename == FLAX_WEIGHTS_NAME:
-                    has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
-                    if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
-                        raise EnvironmentError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named"
-                            f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to load"
-                            " this model from those weights."
-                        )
-                    else:
-                        raise EnvironmentError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named"
-                            f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
-                        )
+                    try:
+                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                        archive_file = hf_bucket_url(
+                            pretrained_model_name_or_path,
+                            filename=FLAX_WEIGHTS_INDEX_NAME,
+                            revision=revision,
+                        )
+                        resolved_archive_file = cached_path(
+                            archive_file,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            proxies=proxies,
+                            resume_download=resume_download,
+                            local_files_only=local_files_only,
+                            use_auth_token=use_auth_token,
+                            user_agent=user_agent,
+                        )
+                        is_sharded = True
+                    except EntryNotFoundError:
+                        has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
+                        if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+                            raise EnvironmentError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
+                                " load this model from those weights."
+                            )
+                        else:
+                            raise EnvironmentError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
+                            )
                 else:
                     raise EnvironmentError(
                         f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
@@ -592,15 +757,35 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         else:
             resolved_archive_file = None

+        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
+        if is_sharded:
+            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
+            resolved_archive_file, _ = get_checkpoint_shard_files(
+                pretrained_model_name_or_path,
+                resolved_archive_file,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                user_agent=user_agent,
+                revision=revision,
+            )
+
         # init random models
         model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

         if from_pt:
             state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
         else:
-            with open(resolved_archive_file, "rb") as state_f:
+            if is_sharded:
+                state = cls.load_flax_sharded_weights(resolved_archive_file)
+            else:
                 try:
-                    state = from_bytes(cls, state_f.read())
+                    with open(resolved_archive_file, "rb") as state_f:
+                        state = from_bytes(cls, state_f.read())
                 except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                     try:
                         with open(resolved_archive_file) as f:
@@ -742,7 +927,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         else:
             return model, unflatten_dict(state)

-    def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
+    def save_pretrained(
+        self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, max_shard_size="10GB", **kwargs
+    ):
         """
         Save a model and its configuration file to a directory, so that it can be re-loaded using the
         `[`~FlaxPreTrainedModel.from_pretrained`]` class method
@@ -761,6 +948,17 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             </Tip>

+            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+                The maximum size for a checkpoint before being sharded. Checkpoint shards will then each be of a size
+                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
+
+                <Tip warning={true}>
+
+                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
+                which will be bigger than `max_shard_size`.
+
+                </Tip>
+
             kwargs:
                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
@@ -788,10 +986,41 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         # save model
         output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
-        with open(output_model_file, "wb") as f:
-            params = params if params is not None else self.params
-            model_bytes = to_bytes(params)
-            f.write(model_bytes)
+
+        shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_directory):
+            full_filename = os.path.join(save_directory, filename)
+            if (
+                filename.startswith(FLAX_WEIGHTS_NAME[:-4])
+                and os.path.isfile(full_filename)
+                and filename not in shards.keys()
+            ):
+                os.remove(full_filename)
+
+        if index is None:
+            with open(output_model_file, "wb") as f:
+                params = params if params is not None else self.params
+                model_bytes = to_bytes(params)
+                f.write(model_bytes)
+        else:
+            save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
+                f"index located at {save_index_file}."
+            )
+            for shard_file, shard in shards.items():
+                # the shard items are flattened dicts; unflatten each one back before serializing it
+                with open(os.path.join(save_directory, shard_file), mode="wb") as f:
+                    params = unflatten_dict(shard, sep="/")
+                    shard_bytes = to_bytes(params)
+                    f.write(shard_bytes)

         logger.info(f"Model weights saved in {output_model_file}")
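Putting the save and load paths together, here is a hedged usage sketch of the new `max_shard_size` argument, using the tiny checkpoint exercised by the tests further down (the `"150kB"` budget is borrowed from those tests; the listed filenames are illustrative):

```py
import os
import tempfile

from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

with tempfile.TemporaryDirectory() as tmp_dir:
    # A budget far below the model size forces several shards plus an index file.
    model.save_pretrained(tmp_dir, max_shard_size="150kB")
    print(sorted(f for f in os.listdir(tmp_dir) if "msgpack" in f))
    # e.g. ['flax_model-00001-of-0000N.msgpack', ..., 'flax_model.msgpack.index.json']

    # from_pretrained spots flax_model.msgpack.index.json and reloads the shards transparently.
    reloaded = FlaxBertModel.from_pretrained(tmp_dir)
```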
src/transformers/utils/__init__.py

@@ -151,6 +151,7 @@ TF2_WEIGHTS_NAME = "tf_model.h5"
 TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
 TF_WEIGHTS_NAME = "model.ckpt"
 FLAX_WEIGHTS_NAME = "flax_model.msgpack"
+FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
 CONFIG_NAME = "config.json"
 FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
 MODEL_CARD_NAME = "modelcard.json"
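The new `FLAX_WEIGHTS_INDEX_NAME` file is the JSON index built by `flax_shard_checkpoint`. An illustrative (not real) example of its shape, with hypothetical weight names and a made-up total size:

```py
import json

# Illustrative only: the shape of flax_model.msgpack.index.json for a hypothetical
# two-shard checkpoint. Weight names are flattened with "/"; the total size is made up.
index = {
    "metadata": {"total_size": 28500000},
    "weight_map": {
        "embeddings/word_embeddings/embedding": "flax_model-00001-of-00002.msgpack",
        "encoder/layer_0/attention/query/kernel": "flax_model-00002-of-00002.msgpack",
    },
}
print(json.dumps(index, indent=2, sort_keys=True))
```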
src/transformers/utils/hub.py

@@ -937,7 +937,7 @@ class PushToHubMixin:
             use_auth_token=use_auth_token,
         )
         # Save the files in the cloned repo
+        self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
         if hasattr(self, "history") and hasattr(self, "create_model_card"):
-            self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
             # This is a Keras model and we might be able to fish out its History and make a model card out of it

@@ -947,9 +947,7 @@ class PushToHubMixin:
             }
             base_model_card_args.update(model_card_kwargs)
             self.create_model_card(**base_model_card_args)
-        else:
-            # FLAX does not support sharding yet, will come in next PR
-            self.save_pretrained(repo_path_or_name)
         # Commit and push!
         url = self._push_to_hub(repo, commit_message=commit_message)
@@ -1090,7 +1088,6 @@ def convert_file_size_to_int(size: Union[int, str]):
         size (`int` or `str`): The size to convert. Will be directly returned if an `int`.

     Example:

     ```py
     >>> convert_file_size_to_int("1MiB")
     1048576
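`max_shard_size` values are normalized through `convert_file_size_to_int`, whose docstring example appears in the hunk above. A small sketch of the expected conversions (only the `"1MiB"` result is taken verbatim from the docstring; the `"150kB"` and integer cases follow the decimal-unit convention assumed by the tests below):

```py
from transformers.utils.hub import convert_file_size_to_int

print(convert_file_size_to_int("1MiB"))     # 1048576 (binary unit, from the docstring above)
print(convert_file_size_to_int("150kB"))    # assumed: 150 * 10**3 = 150000 (decimal unit)
print(convert_file_size_to_int(2_000_000))  # integers are returned unchanged
```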
tests/test_modeling_flax_common.py
@@ -14,6 +14,7 @@
 import copy
 import inspect
+import json
 import random
 import tempfile
 import unittest
@@ -45,6 +46,7 @@ if is_flax_available():
     import jax
     import jax.numpy as jnp
     from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+    from flax.serialization import from_bytes
     from flax.traverse_util import flatten_dict, unflatten_dict

     from transformers import (
         FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -58,6 +60,7 @@ if is_flax_available():
         convert_pytorch_state_dict_to_flax,
         load_flax_weights_in_pytorch_model,
     )
+    from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME

     os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
@@ -1043,6 +1046,59 @@ class FlaxModelTesterMixin:
             # Check if all required parmas are loaded
             _assert_all_params_initialised(model, params)

+    def test_checkpoint_sharding_from_hub(self):
+        model = FlaxBertModel.from_pretrained("ArthurZ/flax-tiny-random-bert-sharded")
+        # the model above is the same as the model below, just a sharded version.
+        ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
+        for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
+            assert np.allclose(np.array(p1), np.array(p2))
+
+    def test_checkpoint_sharding_local(self):
+        model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
+            for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
+                model.save_pretrained(tmp_dir, max_shard_size=max_size)
+
+                # Get each shard file and its size
+                shard_to_size = {}
+                for shard in os.listdir(tmp_dir):
+                    if shard.endswith(".msgpack"):
+                        shard_file = os.path.join(tmp_dir, shard)
+                        shard_to_size[shard_file] = os.path.getsize(shard_file)
+
+                index_file = os.path.join(tmp_dir, FLAX_WEIGHTS_INDEX_NAME)
+                # Check there is an index but no regular weight file
+                self.assertTrue(os.path.isfile(index_file))
+                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))
+
+                # Check a file is bigger than max_size only when it has a single weight
+                for shard_file, size in shard_to_size.items():
+                    if max_size.endswith("kiB"):
+                        max_size_int = int(max_size[:-3]) * 2**10
+                    else:
+                        max_size_int = int(max_size[:-2]) * 10**3
+                    # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
+                    # the size asked for (since we count parameters)
+                    if size >= max_size_int + 50000:
+                        with open(shard_file, "rb") as state_f:
+                            state_file = from_bytes(FlaxBertModel, state_f.read())
+                            self.assertEqual(len(state_file), 1)
+
+                # Check the index and the shard files found match
+                with open(index_file, "r", encoding="utf-8") as f:
+                    index = json.loads(f.read())
+
+                all_shards = set(index["weight_map"].values())
+                shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".msgpack"))
+                self.assertSetEqual(all_shards, shards_found)
+
+                # Finally, check the model can be reloaded
+                new_model = FlaxBertModel.from_pretrained(tmp_dir)
+                for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
+                    self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
+

 @require_flax
 @is_staging_test
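A hedged way to run just the two new tests locally (the pytest invocation is not part of the commit):

```py
import pytest

# Not part of the diff: select only the new sharding tests added above by name.
pytest.main(["tests/test_modeling_flax_common.py", "-k", "checkpoint_sharding", "-q"])
```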