"vscode:/vscode.git/clone" did not exist on "576e2823a397942421e1724e79f51a12122ef49e"
Unverified Commit 9dc8fe1b authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Default to msgpack for safetensors (#27460)



* Default to msgpack for safetensors

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 210e38d8
......@@ -50,7 +50,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
"""Load pytorch checkpoints in a flax model"""
try:
import torch # noqa: F401
except ImportError:
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
......@@ -150,7 +150,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
try:
import torch # noqa: F401
except ImportError:
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
......@@ -349,7 +349,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
try:
import torch # noqa: F401
except ImportError:
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
......
......@@ -721,7 +721,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if is_safetensors_available() and os.path.isfile(
if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
# Load from a sharded Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
is_sharded = True
elif is_safetensors_available() and os.path.isfile(
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
):
# Load from a safetensors checkpoint
......@@ -735,13 +742,6 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# Load from a sharded pytorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
is_sharded = True
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
# Load from a sharded Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif is_safetensors_available() and os.path.isfile(
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
......@@ -770,8 +770,6 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
else:
if from_pt:
filename = WEIGHTS_NAME
elif is_safetensors_available():
filename = SAFE_WEIGHTS_NAME
else:
filename = FLAX_WEIGHTS_NAME
......@@ -792,22 +790,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
}
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
# Did not find the safetensors file, let's fallback to Flax.
# No support for sharded safetensors yet, so we'll raise an error if that's all we find.
filename = FLAX_WEIGHTS_NAME
resolved_archive_file = cached_file(
pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **cached_file_kwargs
)
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
resolved_archive_file = cached_file(
pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
if resolved_archive_file is not None:
is_sharded = True
# Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
if resolved_archive_file is None and from_pt:
resolved_archive_file = cached_file(
......@@ -815,6 +805,17 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
)
if resolved_archive_file is not None:
is_sharded = True
# If we still haven't found anything, look for `safetensors`.
if resolved_archive_file is None:
# No support for sharded safetensors yet, so we'll raise an error if that's all we find.
filename = SAFE_WEIGHTS_NAME
resolved_archive_file = cached_file(
pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs
)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
# message.
......
......@@ -19,8 +19,16 @@ import numpy as np
from huggingface_hub import HfFolder, delete_repo, snapshot_download
from requests.exceptions import HTTPError
from transformers import BertConfig, BertModel, is_flax_available
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax, require_safetensors, require_torch
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
from transformers.testing_utils import (
TOKEN,
USER,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
require_safetensors,
require_torch,
)
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
......@@ -202,6 +210,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
@require_flax
@require_torch
@is_pt_flax_cross_test
def test_safetensors_save_and_load_pt_to_flax(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
......@@ -218,21 +227,114 @@ class FlaxModelUtilsTest(unittest.TestCase):
@require_safetensors
def test_safetensors_load_from_hub(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
"""
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
# Can load from the Flax-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
def test_safetensors_load_from_local(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-flax-only", cache_dir=tmp)
flax_model = FlaxBertModel.from_pretrained(location)
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-flax-safetensors-only", cache_dir=tmp)
safetensors_model = FlaxBertModel.from_pretrained(location)
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_torch
@require_safetensors
def test_safetensors_load_from_hub_flax_and_pt(self):
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack")
# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_torch
@require_safetensors
@is_pt_flax_cross_test
def test_safetensors_load_from_local_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
flax_model = FlaxBertModel.from_pretrained(location)
# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only", from_pt=True)
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
safetensors_model = FlaxBertModel.from_pretrained(location)
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
"""
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
saved in the "pt" format if torch isn't installed.
"""
if is_torch_available():
# This test verifies that a correct error message is shown when loading from a pt safetensors
# PyTorch shouldn't be installed for this to work correctly.
return
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
with self.assertRaises(ModuleNotFoundError):
_ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
@require_safetensors
def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
"""
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
saved in the "pt" format if torch isn't installed.
"""
if is_torch_available():
# This test verifies that a correct error message is shown when loading from a pt safetensors
# PyTorch shouldn't be installed for this to work correctly.
return
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
with self.assertRaises(ModuleNotFoundError):
_ = FlaxBertModel.from_pretrained(location)
@require_safetensors
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
"""
This test checks that we'll first download msgpack weights before safetensors
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
"""
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
@require_safetensors
def test_safetensors_load_from_local_msgpack_before_safetensors(self):
"""
This test checks that we'll first download msgpack weights before safetensors
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
FlaxBertModel.from_pretrained(location)
@require_safetensors
def test_safetensors_flax_from_flax(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
......
......@@ -535,6 +535,71 @@ class TFModelUtilsTest(unittest.TestCase):
# This should discard the safetensors weights in favor of the .h5 sharded weights
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
@require_safetensors
def test_safetensors_load_from_local(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-tf-only", cache_dir=tmp)
tf_model = TFBertModel.from_pretrained(location)
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-only", cache_dir=tmp)
safetensors_model = TFBertModel.from_pretrained(location)
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_load_from_hub_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-h5")
# Can load from the PyTorch-formatted checkpoint
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_load_from_local_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a local checkpoint that only has those
saved in the "pt" format.
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-h5", cache_dir=tmp)
tf_model = TFBertModel.from_pretrained(location)
# Can load from the PyTorch-formatted checkpoint
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
safetensors_model = TFBertModel.from_pretrained(location)
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_load_from_hub_h5_before_safetensors(self):
"""
This test checks that we'll first download h5 weights before safetensors
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
"""
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
@require_safetensors
def test_safetensors_load_from_local_h5_before_safetensors(self):
"""
This test checks that we'll first download h5 weights before safetensors
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
TFBertModel.from_pretrained(location)
@require_tf
@is_staging_test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment