Unverified commit 2eb7bb15, authored by Rahul Huilgol and committed by GitHub
Browse files

Updates in Trainer to support new features in SM Model Parallel library (#15877)



* Create optimizer after model creation for SMP

* update dp_rank to rdp_rank for opt_state_dict

* update world_size and process_index for smp

* Address comments

* Lint fix
Co-authored-by: Cavdar <dcavdar@a07817b12d7e.ant.amazon.com>
parent 05c237ea
...@@ -1233,7 +1233,9 @@ class Trainer: ...@@ -1233,7 +1233,9 @@ class Trainer:
else: else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa debug_overflow = DebugUnderflowOverflow(self.model) # noqa
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE delay_optimizer_creation = (
self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled()
)
if args.deepspeed: if args.deepspeed:
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
...@@ -1690,8 +1692,8 @@ class Trainer: ...@@ -1690,8 +1692,8 @@ class Trainer:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings) reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
if smp.dp_rank() == 0: if smp.rdp_rank() == 0:
# Consolidate the state dict on all processed of dp_rank 0 # Consolidate the state dict on all processed of rdp_rank 0
opt_state_dict = self.optimizer.state_dict() opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process # Save it and the scheduler on the main process
if self.args.should_save: if self.args.should_save:
......
...@@ -1126,7 +1126,7 @@ class TrainingArguments: ...@@ -1126,7 +1126,7 @@ class TrainingArguments:
if is_torch_tpu_available(): if is_torch_tpu_available():
return xm.xrt_world_size() return xm.xrt_world_size()
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
return smp.dp_size() return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
elif is_sagemaker_dp_enabled(): elif is_sagemaker_dp_enabled():
return sm_dist.get_world_size() return sm_dist.get_world_size()
elif self.local_rank != -1: elif self.local_rank != -1:
...@@ -1142,7 +1142,7 @@ class TrainingArguments: ...@@ -1142,7 +1142,7 @@ class TrainingArguments:
if is_torch_tpu_available(): if is_torch_tpu_available():
return xm.get_ordinal() return xm.get_ordinal()
elif is_sagemaker_mp_enabled(): elif is_sagemaker_mp_enabled():
return smp.dp_rank() return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
elif is_sagemaker_dp_enabled(): elif is_sagemaker_dp_enabled():
return sm_dist.get_rank() return sm_dist.get_rank()
elif self.local_rank != -1: elif self.local_rank != -1:
...@@ -1244,12 +1244,14 @@ class TrainingArguments: ...@@ -1244,12 +1244,14 @@ class TrainingArguments:
""" """
if is_torch_available() and self.world_size > 1: if is_torch_available() and self.world_size > 1:
main_process_desc = "main process"
if local: if local:
is_main_process = self.local_process_index == 0 is_main_process = self.local_process_index == 0
main_process_desc = "main local process" main_process_desc = "main local process"
elif is_sagemaker_mp_enabled():
is_main_process = smp.rank() == 0
else: else:
is_main_process = self.process_index == 0 is_main_process = self.process_index == 0
main_process_desc = "main process"
try: try:
if not is_main_process: if not is_main_process:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment