Unverified Commit 2b5d5ead authored by Yeounoh Chung's avatar Yeounoh Chung Committed by GitHub
Browse files

[Hot-Fix][XLA] Re-enable broken _tpu_save for XLATensors (#27799)

* [XLA] Re-enable broken _tpu_save for XLATensors, by explicitly moving to cpu

* linter-fix
parent 1da1302e
......@@ -2834,7 +2834,7 @@ class Trainer:
xm.rendezvous("saving_checkpoint")
if not isinstance(self.model, PreTrainedModel):
if isinstance(unwrap_model(self.model), PreTrainedModel):
unwrap_model(self.model).save_pretrained(
unwrap_model(self.model).to("cpu").save_pretrained(
output_dir,
is_main_process=self.args.should_save,
state_dict=self.model.state_dict(),
......@@ -2842,10 +2842,12 @@ class Trainer:
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
state_dict = self.model.state_dict().to("cpu")
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
self.model.to("cpu").save_pretrained(
output_dir, is_main_process=self.args.should_save, save_function=xm.save
)
if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment