Unverified Commit 31a81109 authored by Sylvain Gugger, committed by GitHub

Add option to save on each training node (#12421)



* Add option to save on each training node

* Apply suggestions from code review
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 990540b7
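For reference, a minimal sketch (not part of this commit) of how the new option is switched on from user code; the output directory name is hypothetical, and the same switch is exposed on the command line as `--save_on_each_node`:

```python
from transformers import TrainingArguments

# Save models and checkpoints on the main process of every node, not only on the
# global main process. Useful when the nodes do not share a filesystem, so each
# node keeps a local copy of the checkpoints (e.g. for `load_best_model_at_end`).
args = TrainingArguments(
    output_dir="my_model",  # hypothetical path
    save_on_each_node=True,
)
```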
src/transformers/trainer.py
@@ -393,7 +393,7 @@ class Trainer:
# Create clone of distant repo and output directory if needed
if self.args.push_to_hub:
self.init_git_repo()
-if self.is_world_process_zero():
+if self.args.should_save:
os.makedirs(self.args.output_dir, exist_ok=True)
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
@@ -899,7 +899,7 @@ class Trainer:
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir)
-if self.is_world_process_zero():
+if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
@@ -1357,10 +1357,18 @@ class Trainer:
logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)
-# We load the model state dict on the CPU to avoid an OOM error.
-state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
-# If the model is on the GPU, it still works!
-self._load_state_dict_in_model(state_dict)
+best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+if os.path.exists(best_model_path):
+    # We load the model state dict on the CPU to avoid an OOM error.
+    state_dict = torch.load(best_model_path, map_location="cpu")
+    # If the model is on the GPU, it still works!
+    self._load_state_dict_in_model(state_dict)
+else:
+    logger.warn(
+        f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+        "on multiple nodes, you should activate `--save_on_each_node`."
+    )
if self.deepspeed:
self.deepspeed.load_checkpoint(
@@ -1500,14 +1508,14 @@ class Trainer:
# Consolidate the state dict on all processed of dp_rank 0
opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process
-if self.is_world_process_zero():
+if self.args.should_save:
torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
if self.use_amp:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt"))
-elif self.is_world_process_zero() and not self.deepspeed:
+elif self.args.should_save and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
@@ -1533,7 +1541,7 @@ class Trainer:
self.state.best_model_checkpoint = output_dir
# Save the Trainer state
-if self.is_world_process_zero():
+if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
# Save RNG state in non-distributed training
@@ -1562,7 +1570,7 @@ class Trainer:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
# Maybe delete some older checkpoints.
-if self.is_world_process_zero():
+if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def _load_optimizer_and_scheduler(self, checkpoint):
@@ -1831,19 +1839,19 @@ class Trainer:
elif is_sagemaker_mp_enabled():
# Calling the state_dict needs to be done on the wrapped model and on all processes.
state_dict = self.model_wrapped.state_dict()
-if self.is_world_process_zero():
+if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif (
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
):
state_dict = self.model.state_dict()
-if self.is_world_process_zero():
+if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif self.deepspeed:
# this takes care of everything as long as we aren't under zero3
-if self.is_world_process_zero():
+if self.args.should_save:
self._save(output_dir)
if is_deepspeed_zero3_enabled():
@@ -1851,7 +1859,7 @@ class Trainer:
# saved, so since under zero3 the file is bogus, simply delete it. The user should
# either user deepspeed checkpoint to resume or to recover full weights use
# zero_to_fp32.py stored in the checkpoint.
-if self.is_world_process_zero():
+if self.args.should_save:
file = os.path.join(output_dir, WEIGHTS_NAME)
if os.path.isfile(file):
# logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
@@ -1862,7 +1870,7 @@ class Trainer:
# This must be called on all ranks
self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
-elif self.is_world_process_zero():
+elif self.args.should_save:
self._save(output_dir)
def _save_tpu(self, output_dir: Optional[str] = None):
@@ -1880,7 +1888,7 @@ class Trainer:
if isinstance(unwrap_model(self.model), PreTrainedModel):
unwrap_model(self.model).save_pretrained(
output_dir,
-save_config=self.is_world_process_zero(),
+save_config=self.args.should_save,
state_dict=self.model.state_dict(),
save_function=xm.save,
)
@@ -1889,8 +1897,8 @@ class Trainer:
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
-self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save)
-if self.tokenizer is not None and self.is_world_process_zero():
+self.model.save_pretrained(output_dir, save_config=self.args.should_save, save_function=xm.save)
+if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
@@ -1960,7 +1968,7 @@ class Trainer:
if len(checkpoints_sorted) <= self.args.save_total_limit:
return
-# If save_total_limit=1 with load_best_mode_at_end=True, we could end up deleting the last checkpoint, which
+# If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
# we don't do to allow resuming.
save_total_limit = self.args.save_total_limit
if (
@@ -2436,7 +2444,7 @@ class Trainer:
"""
Initializes a git repo in :obj:`self.args.push_to_hub_model_id`.
"""
-if not self.is_world_process_zero():
+if not self.args.should_save:
return
use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
repo_url = PushToHubMixin._get_repo_url_from_name(
@@ -2494,11 +2502,16 @@ class Trainer:
Returns:
The url of the commit of your model in the given repository.
"""
-if not self.is_world_process_zero():
+if not self.args.should_save:
return
self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
self.save_model()
+# Only push from one node.
+if not self.is_world_process_zero():
+    return
return self.repo.push_to_hub(commit_message=commit_message)
#
......
src/transformers/training_args.py
@@ -183,6 +183,12 @@ class TrainingArguments:
save_total_limit (:obj:`int`, `optional`):
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
:obj:`output_dir`.
+save_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`False`):
+    When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
+    the main one.
+    This should not be activated when the different nodes use the same storage as the files will be saved with
+    the same names for each node.
no_cuda (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to not use CUDA even when it is available or not.
seed (:obj:`int`, `optional`, defaults to 42):
@@ -456,6 +462,12 @@ class TrainingArguments:
)
},
)
+save_on_each_node: bool = field(
+    default=False,
+    metadata={
+        "help": "When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one"
+    },
+)
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
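For context, a sketch (not part of the diff) of how the new field surfaces as a command-line flag through `HfArgumentParser`, which is how the example scripts consume `TrainingArguments`; the argument list below is illustrative:

```python
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
# Passing the bare flag sets the boolean field to True; omitting it keeps the
# default of False.
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--save_on_each_node"]
)
print(training_args.save_on_each_node)  # True
```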
@@ -937,6 +949,19 @@ class TrainingArguments:
else:
return self.process_index == 0
+@property
+def should_save(self):
+    """
+    Whether or not the current process should write to disk, e.g., to save models and checkpoints.
+    """
+    if self.save_on_each_node:
+        return self.local_process_index == 0
+    else:
+        if is_sagemaker_mp_enabled():
+            return smp.rank() == 0
+        else:
+            return self.process_index == 0
def get_process_log_level(self):
"""
Returns the log level to be used depending on whether this process is the main process of node 0, main process
......
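As a usage note (not part of the diff), downstream code that writes files alongside the Trainer can guard its writes with the new property in the same way the Trainer now does; the callback below is a hypothetical sketch:

```python
import json
import os

from transformers import TrainerCallback


class SaveMetadataCallback(TrainerCallback):
    """Hypothetical callback writing an extra metadata file whenever a checkpoint is saved."""

    def on_save(self, args, state, control, **kwargs):
        # `args.should_save` is True on the main process of every node when
        # `save_on_each_node` is enabled, and only on the global main process otherwise.
        if args.should_save:
            path = os.path.join(args.output_dir, "extra_metadata.json")
            with open(path, "w") as f:
                json.dump({"global_step": state.global_step}, f)
```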