Unverified commit 9dc8fe1b authored by Lysandre Debut, committed by GitHub

Default to msgpack for safetensors (#27460)



* Default to msgpack for safetensors

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 210e38d8
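In effect, `FlaxPreTrainedModel.from_pretrained` now prefers native msgpack (Flax) weights and treats safetensors as a fallback rather than the first choice; a safetensors file saved from PyTorch ("pt" format) cannot be converted for Flax without torch installed, so probing msgpack first keeps torch-free environments working. A minimal sketch of the reordered local lookup, assuming illustrative filename constants that mirror `transformers.utils` (the helper itself is hypothetical, not transformers API):

import os

# Filename constants mirror transformers.utils; the helper itself is purely
# illustrative and not part of transformers.
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"


def pick_local_weights(model_dir: str) -> str:
    """Sketch of the reordered local lookup: msgpack first, safetensors after."""
    for name in (FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME):
        path = os.path.join(model_dir, name)
        if os.path.isfile(path):
            return path
    raise EnvironmentError(f"No Flax or safetensors weights found in {model_dir}")

The remote (Hub) path applies the same priority inversion, as the hunks below show.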
@@ -50,7 +50,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
     """Load pytorch checkpoints in a flax model"""
     try:
         import torch  # noqa: F401
-    except ImportError:
+    except (ImportError, ModuleNotFoundError):
         logger.error(
             "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
             " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
@@ -150,7 +150,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
     try:
         import torch  # noqa: F401
-    except ImportError:
+    except (ImportError, ModuleNotFoundError):
         logger.error(
             "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
             " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
@@ -349,7 +349,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):

     try:
         import torch  # noqa: F401
-    except ImportError:
+    except (ImportError, ModuleNotFoundError):
         logger.error(
             "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
             " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
...
@@ -721,7 +721,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         is_local = os.path.isdir(pretrained_model_name_or_path)
         if os.path.isdir(pretrained_model_name_or_path):
-            if is_safetensors_available() and os.path.isfile(
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
+                # Load from a Flax checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
+                # Load from a sharded Flax checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
+                is_sharded = True
+            elif is_safetensors_available() and os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
             ):
                 # Load from a safetensors checkpoint
@@ -735,13 +742,6 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 # Load from a sharded pytorch checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
                 is_sharded = True
-            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
-                # Load from a Flax checkpoint
-                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
-            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
-                # Load from a sharded Flax checkpoint
-                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
-                is_sharded = True
             # At this stage we don't have a weight file so we will raise an error.
             elif is_safetensors_available() and os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
@@ -770,8 +770,6 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             else:
                 if from_pt:
                     filename = WEIGHTS_NAME
-                elif is_safetensors_available():
-                    filename = SAFE_WEIGHTS_NAME
                 else:
                     filename = FLAX_WEIGHTS_NAME
@@ -792,22 +790,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             }
             resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

-            # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
-            # result when internet is up, the repo and revision exist, but the file does not.
-            if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
-                # Did not find the safetensors file, let's fallback to Flax.
-                # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
-                filename = FLAX_WEIGHTS_NAME
-                resolved_archive_file = cached_file(
-                    pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **cached_file_kwargs
-                )
-            if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
-                # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+            # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+            if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
                 resolved_archive_file = cached_file(
                     pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
                 )
                 if resolved_archive_file is not None:
                     is_sharded = True
+
             # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
             if resolved_archive_file is None and from_pt:
                 resolved_archive_file = cached_file(
@@ -815,6 +805,17 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 )
                 if resolved_archive_file is not None:
                     is_sharded = True
+
+            # If we still haven't found anything, look for `safetensors`.
+            if resolved_archive_file is None:
+                # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
+                filename = SAFE_WEIGHTS_NAME
+                resolved_archive_file = cached_file(
+                    pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs
+                )
+
+            # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+            # result when internet is up, the repo and revision exist, but the file does not.
             if resolved_archive_file is None:
                 # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
                 # message.
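On the Hub path the same inversion applies: the loader first asks the cache for msgpack weights (or PyTorch weights when `from_pt=True`), then for sharded indexes, and only then for `model.safetensors`. A condensed sketch of the probe order, reusing the filename constants from the sketch above and treating `cached_file` as a callable that returns `None` for a missing file (which is what `_raise_exceptions_for_missing_entries=False` arranges); illustrative only, not the actual method body:

WEIGHTS_NAME = "pytorch_model.bin"  # mirrors transformers.utils
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"


def resolve_remote(repo_id, cached_file, from_pt=False):
    # The real method also tracks is_sharded and builds richer error
    # messages when nothing is found.
    filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
    resolved = cached_file(repo_id, filename)
    if resolved is None and filename == FLAX_WEIGHTS_NAME:
        resolved = cached_file(repo_id, FLAX_WEIGHTS_INDEX_NAME)  # sharded Flax?
    if resolved is None and from_pt:
        resolved = cached_file(repo_id, WEIGHTS_INDEX_NAME)  # sharded PyTorch?
    if resolved is None:
        # safetensors is now the last resort (sharded safetensors unsupported here)
        resolved = cached_file(repo_id, SAFE_WEIGHTS_NAME)
    return resolved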
...
@@ -19,8 +19,16 @@ import numpy as np
 from huggingface_hub import HfFolder, delete_repo, snapshot_download
 from requests.exceptions import HTTPError

-from transformers import BertConfig, BertModel, is_flax_available
-from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax, require_safetensors, require_torch
+from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
+from transformers.testing_utils import (
+    TOKEN,
+    USER,
+    is_pt_flax_cross_test,
+    is_staging_test,
+    require_flax,
+    require_safetensors,
+    require_torch,
+)
 from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
@@ -202,6 +210,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
     @require_flax
     @require_torch
+    @is_pt_flax_cross_test
     def test_safetensors_save_and_load_pt_to_flax(self):
         model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
         pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
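The new `@is_pt_flax_cross_test` marker keeps this PT-to-Flax conversion test out of default runs; such markers in `transformers.testing_utils` are environment-variable gates. A minimal sketch of the pattern, an assumed shape rather than the exact implementation:

import os
import unittest

# Assumed shape of an env-var-gated marker like is_pt_flax_cross_test;
# the real decorator lives in transformers.testing_utils.
def is_pt_flax_cross_test(test_case):
    if os.environ.get("RUN_PT_FLAX_CROSS_TESTS", "false").lower() != "true":
        return unittest.skip("test is PT+Flax cross test")(test_case)
    return test_case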
@@ -218,21 +227,114 @@ class FlaxModelUtilsTest(unittest.TestCase):
     @require_safetensors
     def test_safetensors_load_from_hub(self):
+        """
+        This test checks that we can load safetensors from a checkpoint that only has those on the Hub
+        """
         flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

         # Can load from the Flax-formatted checkpoint
         safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")
         self.assertTrue(check_models_equal(flax_model, safetensors_model))

+    @require_safetensors
+    def test_safetensors_load_from_local(self):
+        """
+        This test checks that we can load safetensors from a local checkpoint that only has those
+        """
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-flax-only", cache_dir=tmp)
+            flax_model = FlaxBertModel.from_pretrained(location)
+
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-flax-safetensors-only", cache_dir=tmp)
+            safetensors_model = FlaxBertModel.from_pretrained(location)
+
+        self.assertTrue(check_models_equal(flax_model, safetensors_model))
+
     @require_torch
     @require_safetensors
-    def test_safetensors_load_from_hub_flax_and_pt(self):
-        flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
+    @is_pt_flax_cross_test
+    def test_safetensors_load_from_hub_from_safetensors_pt(self):
+        """
+        This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
+        saved in the "pt" format.
+        """
+        flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack")
+
+        # Can load from the PyTorch-formatted checkpoint
+        safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
+        self.assertTrue(check_models_equal(flax_model, safetensors_model))
+
+    @require_torch
+    @require_safetensors
+    @is_pt_flax_cross_test
+    def test_safetensors_load_from_local_from_safetensors_pt(self):
+        """
+        This test checks that we can load safetensors from a local checkpoint that only has those,
+        saved in the "pt" format.
+        """
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
+            flax_model = FlaxBertModel.from_pretrained(location)

         # Can load from the PyTorch-formatted checkpoint
-        safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only", from_pt=True)
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
+            safetensors_model = FlaxBertModel.from_pretrained(location)
+
         self.assertTrue(check_models_equal(flax_model, safetensors_model))

+    @require_safetensors
+    def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
+        """
+        This test checks that we cannot load safetensors from a checkpoint that only has safetensors
+        saved in the "pt" format if torch isn't installed.
+        """
+        if is_torch_available():
+            # This test verifies the error message shown when loading from a pt safetensors checkpoint;
+            # PyTorch shouldn't be installed for it to run correctly.
+            return
+
+        # Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
+        with self.assertRaises(ModuleNotFoundError):
+            _ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
+
+    @require_safetensors
+    def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
+        """
+        This test checks that we cannot load safetensors from a local checkpoint that only has safetensors
+        saved in the "pt" format if torch isn't installed.
+        """
+        if is_torch_available():
+            # This test verifies the error message shown when loading from a pt safetensors checkpoint;
+            # PyTorch shouldn't be installed for it to run correctly.
+            return
+
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
+
+            # Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
+            with self.assertRaises(ModuleNotFoundError):
+                _ = FlaxBertModel.from_pretrained(location)
+
+    @require_safetensors
+    def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
+        """
+        This test checks that we'll first download msgpack weights before safetensors.
+        The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch.
+        """
+        FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
+
+    @require_safetensors
+    def test_safetensors_load_from_local_msgpack_before_safetensors(self):
+        """
+        This test checks that we'll first download msgpack weights before safetensors.
+        The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch.
+        """
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
+            FlaxBertModel.from_pretrained(location)
+
     @require_safetensors
     def test_safetensors_flax_from_flax(self):
         model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
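The user-facing behavior these tests pin down, as a short sketch (repo names taken from the tests above; assumes flax is installed and torch deliberately is not):

from transformers import FlaxBertModel

# The repo ships both flax_model.msgpack and a pt-format model.safetensors.
# With msgpack now preferred, this load succeeds without PyTorch installed.
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")

# A repo whose only weights are pt-format safetensors still needs torch:
# FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
# would raise ModuleNotFoundError when torch is absent.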
...
@@ -535,6 +535,71 @@ class TFModelUtilsTest(unittest.TestCase):
         # This should discard the safetensors weights in favor of the .h5 sharded weights
         TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")

+    @require_safetensors
+    def test_safetensors_load_from_local(self):
+        """
+        This test checks that we can load safetensors from a local checkpoint that only has those
+        """
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-tf-only", cache_dir=tmp)
+            tf_model = TFBertModel.from_pretrained(location)
+
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-only", cache_dir=tmp)
+            safetensors_model = TFBertModel.from_pretrained(location)
+
+        for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
+            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
+    @require_safetensors
+    def test_safetensors_load_from_hub_from_safetensors_pt(self):
+        """
+        This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
+        saved in the "pt" format.
+        """
+        tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-h5")
+
+        # Can load from the PyTorch-formatted checkpoint
+        safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
+        for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
+            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
+    @require_safetensors
+    def test_safetensors_load_from_local_from_safetensors_pt(self):
+        """
+        This test checks that we can load safetensors from a local checkpoint that only has those
+        saved in the "pt" format.
+        """
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-h5", cache_dir=tmp)
+            tf_model = TFBertModel.from_pretrained(location)
+
+        # Can load from the PyTorch-formatted checkpoint
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
+            safetensors_model = TFBertModel.from_pretrained(location)
+
+        for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
+            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
+    @require_safetensors
+    def test_safetensors_load_from_hub_h5_before_safetensors(self):
+        """
+        This test checks that we'll first download h5 weights before safetensors.
+        The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch.
+        """
+        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
+
+    @require_safetensors
+    def test_safetensors_load_from_local_h5_before_safetensors(self):
+        """
+        This test checks that we'll first download h5 weights before safetensors.
+        The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch.
+        """
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
+            TFBertModel.from_pretrained(location)
+
     @require_tf
     @is_staging_test
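The TensorFlow loader gets the same treatment: native weights are resolved before safetensors. A sketch of the user-facing effect, with the repo name taken from the test above and assuming, as its docstring states, that native TF weights are present alongside the pt-format safetensors (tensorflow installed, torch not):

from transformers import TFBertModel

# Per the tests above, native TF weights are resolved before the pt-format
# safetensors in this repo, so no PyTorch is needed for the load.
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")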
...