"tests/utils/test_modeling_tf_utils.py" did not exist on "3060899be51fe1a96b12de97376f2e2b8315bc4c"
Unverified Commit cf7bed98 authored by David Xue's avatar David Xue Committed by GitHub
Browse files

Add safetensors to model not found error msg for default use_safetensors value (#30602)

* add safetensors to model not found error for default use_safetensors=None case

* format code w/ ruff

* fix assert true typo
parent 884e3b1c
......@@ -3270,8 +3270,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
else:
raise EnvironmentError(
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME},"
f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
f" {pretrained_model_name_or_path}."
)
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
......@@ -3417,8 +3417,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
f" {FLAX_WEIGHTS_NAME}."
f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
......
......@@ -1001,6 +1001,26 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))
# test no model file found when use_safetensors=None (default when safetensors package available)
with self.assertRaises(OSError) as missing_model_file_error:
BertModel.from_pretrained("hf-internal-testing/config-no-model")
self.assertTrue(
"does not appear to have a file named pytorch_model.bin, model.safetensors,"
in str(missing_model_file_error.exception)
)
with self.assertRaises(OSError) as missing_model_file_error:
with tempfile.TemporaryDirectory() as tmp_dir:
with open(os.path.join(tmp_dir, "config.json"), "w") as f:
f.write("{}")
f.close()
BertModel.from_pretrained(tmp_dir)
self.assertTrue(
"Error no file named pytorch_model.bin, model.safetensors" in str(missing_model_file_error.exception)
)
@require_safetensors
def test_safetensors_save_and_load(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment