Unverified Commit da857beb authored by hlky's avatar hlky Committed by GitHub
Browse files

Revert `save_model` in ModelMixin save_pretrained and use safe_serialization=False in test (#11196)

parent 52b460fe
......@@ -714,10 +714,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
try:
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
except RuntimeError:
safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"})
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
......
......@@ -2293,7 +2293,7 @@ class PipelineTesterMixin:
specified_key = next(iter(components.keys()))
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe.save_pretrained(tmpdirname, safe_serialization=False)
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment