Unverified Commit b20147a3 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Variant] Make sure variant files are not incorrectly deleted (#21562)

* [Variant] Make sure variant files are not incorrectly deleted

* Apply suggestions from code review

* fix
parent 51c3f42d
......@@ -1718,11 +1718,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
reg = re.compile("(.*?)-\d{5}-of-\d{5}")
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in shards.keys()
and is_main_process
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)
......
......@@ -3119,6 +3119,27 @@ class ModelUtilsTest(TestCasePlus):
)
self.assertIsNotNone(model)
def test_checkpoint_variant_save_load(self):
with tempfile.TemporaryDirectory() as tmp_dir:
model = BertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
)
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
model.save_pretrained(tmp_dir, variant="v2")
# saving will create a variant checkpoint
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
model.save_pretrained(tmp_dir)
# saving shouldn't delete variant checkpoints
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
# there should be a normal checkpoint
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
self.assertIsNotNone(model)
@require_accelerate
def test_from_pretrained_low_cpu_mem_usage_functional(self):
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment