Unverified Commit 2eb7bb15 authored by Rahul Huilgol, committed by GitHub

Updates in Trainer to support new features in SM Model Parallel library (#15877)



* Create optimizer after model creation for SMP

* update dp_rank to rdp_rank for opt_state_dict

* update world_size and process_index for smp

* Address comments

* Lint fix

Co-authored-by: Cavdar <dcavdar@a07817b12d7e.ant.amazon.com>
parent 05c237ea
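For context: in the SMP library, `dp_rank()`/`dp_size()` describe the plain data-parallel group, while `rdp_rank()`/`rdp_size()` describe the reduced-data-parallel group left after factoring out tensor-parallel ranks. When `prescaled_batch` is enabled, each reduced-data-parallel group consumes a full batch, so the rdp coordinates are the ones the Trainer should report. A minimal sketch of the selection logic this PR adopts (the helper name is hypothetical; the `smp` calls are from the SMP PyTorch API):

```python
# Sketch only; assumes smdistributed.modelparallel is installed and initialized.
import smdistributed.modelparallel.torch as smp

def effective_dp_coords():
    """Return the (rank, size) pair the Trainer should treat as data-parallel."""
    if smp.state.cfg.prescaled_batch:
        # Batches are pre-scaled per reduced-data-parallel group.
        return smp.rdp_rank(), smp.rdp_size()
    return smp.dp_rank(), smp.dp_size()
```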
@@ -1233,7 +1233,9 @@ class Trainer:
         else:
             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
-        delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
+        delay_optimizer_creation = (
+            self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled()
+        )
         if args.deepspeed:
             deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                 self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
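This first hunk makes the Trainer delay optimizer creation whenever SMP is enabled: SMP partitions the model across devices when it is wrapped, so the optimizer must be built from the wrapped model's parameters. A hedged sketch of that ordering (`TinyNet` is a placeholder module, not from this PR):

```python
import torch
import torch.nn as nn
import smdistributed.modelparallel.torch as smp

class TinyNet(nn.Module):  # placeholder model for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        return self.fc(x)

smp.init()
# Wrap first: SMP partitions the module across devices at wrap time.
model = smp.DistributedModel(TinyNet())
# Only now do model.parameters() reflect this rank's partition, so the
# optimizer is created after the model, as the hunk above enforces.
optimizer = smp.DistributedOptimizer(torch.optim.AdamW(model.parameters(), lr=5e-5))
```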
@@ -1690,8 +1692,8 @@ class Trainer:
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
-            if smp.dp_rank() == 0:
-                # Consolidate the state dict on all processes of dp_rank 0
+            if smp.rdp_rank() == 0:
+                # Consolidate the state dict on all processes of rdp_rank 0
                 opt_state_dict = self.optimizer.state_dict()
             # Save it and the scheduler on the main process
             if self.args.should_save:
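A minimal standalone sketch of the consolidated-save pattern above (function name and filename are hypothetical; the real code lives inside the Trainer):

```python
import os
import torch
import smdistributed.modelparallel.torch as smp

def save_consolidated_optimizer(optimizer, output_dir, should_save):
    # state_dict() on rdp_rank 0 gathers the optimizer state across that
    # group's tensor-parallel shards, so only those ranks materialize it.
    if smp.rdp_rank() == 0:
        opt_state_dict = optimizer.state_dict()
        if should_save:
            # Hypothetical filename; the Trainer uses its own constants.
            torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
```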
@@ -1126,7 +1126,7 @@ class TrainingArguments:
         if is_torch_tpu_available():
             return xm.xrt_world_size()
         elif is_sagemaker_mp_enabled():
-            return smp.dp_size()
+            return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
         elif is_sagemaker_dp_enabled():
             return sm_dist.get_world_size()
         elif self.local_rank != -1:
@@ -1142,7 +1142,7 @@ class TrainingArguments:
         if is_torch_tpu_available():
             return xm.get_ordinal()
         elif is_sagemaker_mp_enabled():
-            return smp.dp_rank()
+            return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
         elif is_sagemaker_dp_enabled():
             return sm_dist.get_rank()
         elif self.local_rank != -1:
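With both `world_size` and `process_index` patched this way, downstream batch-size arithmetic needs no SMP-specific branches. A hypothetical usage, with illustrative numbers:

```python
per_device_train_batch_size = 8
world_size = 4  # e.g. smp.rdp_size() when prescaled_batch is enabled
total_train_batch_size = per_device_train_batch_size * world_size
assert total_train_batch_size == 32
```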
@@ -1244,12 +1244,14 @@ class TrainingArguments:
         """
         if is_torch_available() and self.world_size > 1:
+            main_process_desc = "main process"
             if local:
                 is_main_process = self.local_process_index == 0
                 main_process_desc = "main local process"
+            elif is_sagemaker_mp_enabled():
+                is_main_process = smp.rank() == 0
             else:
                 is_main_process = self.process_index == 0
-                main_process_desc = "main process"
             try:
                 if not is_main_process:
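For illustration, the branch logic of `main_process_first` after this change, flattened into a standalone helper (all names hypothetical):

```python
def pick_main_process(local, local_process_index, process_index,
                      smp_enabled, smp_global_rank):
    """Mirror of the branch above: decide which process counts as 'main'."""
    if local:
        return local_process_index == 0  # "main local process"
    if smp_enabled:
        # Under SMP the global main process is smp.rank() == 0.
        return smp_global_rank == 0
    return process_index == 0
```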