Unverified commit 5ba8ac54, authored by Baole Ai and committed by GitHub
Browse files

Fix _save_tpu: use xm._maybe_convert_to_cpu on the state dict instead of model.to("cpu"). (#31264)

* Fix _save_tpu: use xm._maybe_convert_to_cpu on the state dict instead of model.to("cpu").

* fix lint
parent 14ff5dd9
...@@ -3407,8 +3407,6 @@ class Trainer: ...@@ -3407,8 +3407,6 @@ class Trainer:
logger.info(f"Saving model checkpoint to {output_dir}") logger.info(f"Saving model checkpoint to {output_dir}")
model = self.model model = self.model
xm.mark_step() xm.mark_step()
if self.args.save_safetensors:
model.to("cpu")
if xm.is_master_ordinal(): if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
...@@ -3423,13 +3421,13 @@ class Trainer: ...@@ -3423,13 +3421,13 @@ class Trainer:
self.accelerator.unwrap_model(model).save_pretrained( self.accelerator.unwrap_model(model).save_pretrained(
output_dir, output_dir,
is_main_process=self.args.should_save, is_main_process=self.args.should_save,
state_dict=model.state_dict(), state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
save_function=xm.save, save_function=xm.save,
safe_serialization=self.args.save_safetensors, safe_serialization=self.args.save_safetensors,
) )
else: else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = model.state_dict() state_dict = xm._maybe_convert_to_cpu(model.state_dict())
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else: else:
model.save_pretrained( model.save_pretrained(
...@@ -3437,15 +3435,11 @@ class Trainer: ...@@ -3437,15 +3435,11 @@ class Trainer:
is_main_process=self.args.should_save, is_main_process=self.args.should_save,
save_function=xm.save, save_function=xm.save,
safe_serialization=self.args.save_safetensors, safe_serialization=self.args.save_safetensors,
state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
) )
if self.tokenizer is not None and self.args.should_save: if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir)
# We moved the model from TPU -> CPU for saving the weights.
# Now we should move it back to subsequent compute still works.
if self.args.save_safetensors:
model.to(self.args.device)
def _save(self, output_dir: Optional[str] = None, state_dict=None): def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that. # If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir output_dir = output_dir if output_dir is not None else self.args.output_dir
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment