"examples/research_projects/bertabs/modeling_bertabs.py" did not exist on "c203509d5b0f002a7833382a03ffe7802aa14e91"
Unverified Commit 817a676b authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Don't default to other weights file when use_safetensors=True (#31874)

* Don't default to other weights file when use_safetensors=True

* Add tests

* Update tests/utils/test_modeling_utils.py

* Add clarifying comments to tests

* Update tests/utils/test_modeling_utils.py

* Update tests/utils/test_modeling_utils.py
parent 74d0eb3f
......@@ -3395,14 +3395,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True
elif os.path.isfile(
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
):
# Load from a PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
)
elif os.path.isfile(
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
):
# Load from a sharded PyTorch checkpoint
......@@ -3411,15 +3411,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
elif not use_safetensors and (
os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
):
raise EnvironmentError(
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
" `from_tf=True` to load this model from those weights."
)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
elif not use_safetensors and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
):
raise EnvironmentError(
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
......
......@@ -815,6 +815,72 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_checkpoint_loading_only_safetensors_available(self):
# Test that the loading behaviour is as expected when only safetensor checkpoints are available
# - We can load the model with use_safetensors=True
# - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
# preferring safetensors
# - We cannot load the model with use_safetensors=False
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=True)
weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
weights_index_file = os.path.join(tmp_dir, weights_index_name)
self.assertTrue(os.path.isfile(weights_index_file))
for i in range(1, 5):
weights_name = f"model-0000{i}-of-00005" + ".safetensors"
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))
# Setting use_safetensors=False should raise an error as the checkpoint was saved with safetensors=True
with self.assertRaises(OSError):
_ = BertModel.from_pretrained(tmp_dir, use_safetensors=False)
# We can load the model with use_safetensors=True
new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=True)
# We can load the model without specifying use_safetensors
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_checkpoint_loading_only_pytorch_bin_available(self):
# Test that the loading behaviour is as expected when only pytorch checkpoints are available
# - We can load the model with use_safetensors=False
# - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
# preferring safetensors but falling back to pytorch
# - We cannot load the model with use_safetensors=True
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=False)
weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
weights_index_file = os.path.join(tmp_dir, weights_index_name)
self.assertTrue(os.path.isfile(weights_index_file))
for i in range(1, 5):
weights_name = WEIGHTS_NAME.split(".")[0].split("_")[0] + f"_model-0000{i}-of-00005" + ".bin"
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))
# Setting use_safetensors=True should raise an error as the checkpoint was saved with safetensors=False
with self.assertRaises(OSError):
_ = BertModel.from_pretrained(tmp_dir, use_safetensors=True)
# We can load the model with use_safetensors=False
new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=False)
# We can load the model without specifying use_safetensors
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_checkpoint_variant_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(EnvironmentError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment