Unverified Commit 0549000c authored by jeffhataws's avatar jeffhataws Committed by GitHub
Browse files

Use save_safetensor to disable safe serialization for XLA (#28669)

* Use save_safetensor to disable safe serialization for XLA

https://github.com/huggingface/transformers/issues/28438

* Style fixup
parent c5c69096
...@@ -2910,13 +2910,19 @@ class Trainer: ...@@ -2910,13 +2910,19 @@ class Trainer:
is_main_process=self.args.should_save, is_main_process=self.args.should_save,
state_dict=model.state_dict(), state_dict=model.state_dict(),
save_function=xm.save, save_function=xm.save,
safe_serialization=self.args.save_safetensors,
) )
else: else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = model.state_dict() state_dict = model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else: else:
model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) model.save_pretrained(
output_dir,
is_main_process=self.args.should_save,
save_function=xm.save,
safe_serialization=self.args.save_safetensors,
)
if self.tokenizer is not None and self.args.should_save: if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment