Unverified Commit f780557a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Safetensors] Add explicit flag to from pretrained (#22083)



* [Safetensors] Add explicit  flag to from pretrained

* add test

* remove @

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 3a35937e
......@@ -2086,6 +2086,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
if trust_remote_code is True:
logger.warning(
......@@ -2222,14 +2223,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
):
# Load from a Flax checkpoint in priority if from_flax
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif is_safetensors_available() and os.path.isfile(
elif use_safetensors is not False and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
):
# Load from a safetensors checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
)
elif is_safetensors_available() and os.path.isfile(
elif use_safetensors is not False and os.path.isfile(
os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
)
......@@ -2295,7 +2296,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
filename = TF2_WEIGHTS_NAME
elif from_flax:
filename = FLAX_WEIGHTS_NAME
elif is_safetensors_available():
elif use_safetensors is not False:
filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
else:
filename = _add_variant(WEIGHTS_NAME, variant)
......@@ -2328,6 +2329,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
if resolved_archive_file is not None:
is_sharded = True
elif use_safetensors:
raise EnvironmentError(
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
)
else:
# This repo has no safetensors file of any kind, we switch to PyTorch.
filename = _add_variant(WEIGHTS_NAME, variant)
......
......@@ -15,6 +15,7 @@
import copy
import gc
import glob
import inspect
import json
import os
......@@ -119,6 +120,7 @@ if is_torch_available():
AutoTokenizer,
BertConfig,
BertModel,
CLIPTextModel,
PreTrainedModel,
T5Config,
T5ForConditionalGeneration,
......@@ -3327,6 +3329,49 @@ class ModelUtilsTest(TestCasePlus):
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
)
@require_safetensors
def test_use_safetensors(self):
# test nice error message if no safetensor files available
with self.assertRaises(OSError) as env_error:
AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)
self.assertTrue(
"model.safetensors or model.safetensors.index.json and thus cannot be loaded with `safetensors`"
in str(env_error.exception)
)
# test that error if only safetensors is available
with self.assertRaises(OSError) as env_error:
BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)
self.assertTrue("does not appear to have a file named pytorch_model.bin" in str(env_error.exception))
# test that only safetensors if both available and use_safetensors=False
with tempfile.TemporaryDirectory() as tmp_dir:
CLIPTextModel.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
subfolder="text_encoder",
use_safetensors=False,
cache_dir=tmp_dir,
)
all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
self.assertTrue(any(f.endswith("bin") for f in all_downloaded_files))
self.assertFalse(any(f.endswith("safetensors") for f in all_downloaded_files))
# test that no safetensors if both available and use_safetensors=True
with tempfile.TemporaryDirectory() as tmp_dir:
CLIPTextModel.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
subfolder="text_encoder",
use_safetensors=True,
cache_dir=tmp_dir,
)
all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))
@require_safetensors
def test_safetensors_save_and_load(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment