"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1144d336b689d1710534b245697e41be7a168075"
Unverified Commit 2eb7bb15 authored by Rahul Huilgol, committed by GitHub

Updates in Trainer to support new features in SM Model Parallel library (#15877)



* Create optimizer after model creation for SMP

* Update dp_rank to rdp_rank for opt_state_dict

* Update world_size and process_index for smp

* Address comments

* Lint fix
Co-authored-by: Cavdar <dcavdar@a07817b12d7e.ant.amazon.com>
parent 05c237ea
@@ -1233,7 +1233,9 @@ class Trainer:
         else:
             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
+        delay_optimizer_creation = (
+            self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled()
+        )
         if args.deepspeed:
             deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                 self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
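
This hunk implements the "Create optimizer after model creation for SMP" bullet: when SageMaker Model Parallel is active, optimizer creation is now deferred until after the model has been wrapped, since wrapping is what partitions the parameters across devices. A minimal sketch of the resulting ordering, not the Trainer's actual method bodies; `create_optimizer_and_scheduler` and `_wrap_model` stand in for the surrounding training-loop steps, and `ShardedDDPOption` / `is_sagemaker_mp_enabled` are the names already used in the diff:

    def setup_for_training(trainer, max_steps):
        # Same condition as in the diff: delay when Sharded DDP (other than SIMPLE) or SMP is active.
        delay_optimizer_creation = (
            trainer.sharded_ddp is not None
            and trainer.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
        )
        if not delay_optimizer_creation:
            trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # Under SMP, wrapping is the step that partitions the model across devices.
        model = trainer._wrap_model(trainer.model_wrapped)

        if delay_optimizer_creation:
            # Built only now, so the optimizer binds to the partitioned parameters.
            trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)
        return model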
@@ -1690,8 +1692,8 @@ class Trainer:
             xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
-            if smp.dp_rank() == 0:
-                # Consolidate the state dict on all processed of dp_rank 0
+            if smp.rdp_rank() == 0:
+                # Consolidate the state dict on all processed of rdp_rank 0
                 opt_state_dict = self.optimizer.state_dict()
             # Save it and the scheduler on the main process
             if self.args.should_save:
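
This hunk corresponds to the "Update dp_rank to rdp_rank for opt_state_dict" bullet. The likely motivation is that, with tensor parallelism, optimizer state differs across tensor-parallel ranks, so gathering it only on dp_rank 0 would miss shards; running the consolidation on every rdp_rank-0 process covers each shard once. A sketch of the resulting save path, assuming `smp` is smdistributed.modelparallel.torch and the hypothetical helper takes the Trainer attributes shown in the diff:

    import os
    import torch
    import smdistributed.modelparallel.torch as smp

    def save_optimizer_state(trainer, output_dir, optimizer_file="optimizer.pt"):
        opt_state_dict = None
        if smp.rdp_rank() == 0:
            # One process per reduced-data-parallel group gathers its shard of the state.
            opt_state_dict = trainer.optimizer.state_dict()
        if trainer.args.should_save and opt_state_dict is not None:
            # Only the main process writes the consolidated file.
            torch.save(opt_state_dict, os.path.join(output_dir, optimizer_file))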
@@ -1126,7 +1126,7 @@ class TrainingArguments:
         if is_torch_tpu_available():
             return xm.xrt_world_size()
         elif is_sagemaker_mp_enabled():
-            return smp.dp_size()
+            return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
         elif is_sagemaker_dp_enabled():
             return sm_dist.get_world_size()
         elif self.local_rank != -1:
@@ -1142,7 +1142,7 @@ class TrainingArguments:
         if is_torch_tpu_available():
             return xm.get_ordinal()
         elif is_sagemaker_mp_enabled():
-            return smp.dp_rank()
+            return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
         elif is_sagemaker_dp_enabled():
             return sm_dist.get_rank()
         elif self.local_rank != -1:
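
These two hunks cover the "Update world_size and process_index for smp" bullet: when prescaled_batch is set in the SMP config, each tensor-parallel group receives a single already-scaled batch, so the reduced-data-parallel (rdp) dimension describes the effective data-parallel layout and rdp_size()/rdp_rank() replace dp_size()/dp_rank(). A sketch using only the names from the diff, with `smp` assumed to be smdistributed.modelparallel.torch:

    import smdistributed.modelparallel.torch as smp

    def smp_world_size() -> int:
        # Effective number of data-parallel workers under SMP.
        return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()

    def smp_process_index() -> int:
        # Index of this process among those workers.
        return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()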
@@ -1244,12 +1244,14 @@ class TrainingArguments:
         """
         if is_torch_available() and self.world_size > 1:
+            main_process_desc = "main process"
             if local:
                 is_main_process = self.local_process_index == 0
                 main_process_desc = "main local process"
+            elif is_sagemaker_mp_enabled():
+                is_main_process = smp.rank() == 0
             else:
                 is_main_process = self.process_index == 0
-                main_process_desc = "main process"
 
             try:
                 if not is_main_process:
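
The last hunk makes main_process_first aware of SMP: with model parallelism, more than one process can report process_index == 0 (one per model-parallel rank), so the non-local check now uses smp.rank() == 0, which is unique across the job, and the default main_process_desc is hoisted above the branches so the new elif path also gets a description. A sketch of the selection logic, with names following the diff:

    def is_main_for_sync(args, local: bool) -> bool:
        if local:
            return args.local_process_index == 0   # first process on each node
        if is_sagemaker_mp_enabled():
            return smp.rank() == 0                 # SMP's global rank 0
        return args.process_index == 0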