Unverified Commit 7fc686ef authored by Mansi Mane, committed by GitHub

SageMaker Model Parallel TensorBoard writing fix (#10403)

* Added tb fix

* Removed local rank condition

* Updated reference to args
parent 83d2d55c
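The "Updated reference to args" change moves the flag computation after `super().__init__()` so it reads `self.args` rather than the raw `args` argument. Below is a minimal sketch of that pattern; the `Fake*` classes are hypothetical stand-ins, not library code, and it assumes the base trainer fills in default arguments when `args` is `None`:

```python
from dataclasses import dataclass


@dataclass
class FakeTrainingArguments:
    # Hypothetical stand-in for the SageMaker TrainingArguments.
    mp_parameters: str = ""


class FakeTrainer:
    """Stand-in for Trainer: fills in default args when none are passed."""

    def __init__(self, args=None):
        self.args = args if args is not None else FakeTrainingArguments()


class FakeSageMakerTrainer(FakeTrainer):
    def __init__(self, args=None):
        super().__init__(args=args)
        # Reading self.args after super().__init__() is safe even when args was None.
        self.is_model_parallel_enabled = self.args.mp_parameters != ""


print(FakeSageMakerTrainer().is_model_parallel_enabled)  # False, no AttributeError
print(FakeSageMakerTrainer(FakeTrainingArguments("partitions=2")).is_model_parallel_enabled)  # True
```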
@@ -71,11 +71,21 @@ if is_smdistributed_available():
 class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
-        self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
         super().__init__(args=args, **kwargs)
+        self.is_model_parallel_enabled = is_smdistributed_available() and self.args.mp_parameters != ""
         if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
             raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")

+    def is_world_process_zero(self) -> bool:
+        """
+        Whether or not this process is the global main process (when training in a distributed fashion on several
+        machines, this is only going to be :obj:`True` for one process).
+        """
+        if self.is_model_parallel_enabled:
+            return smp.rank() == 0 and smp.local_rank() == 0 and smp.mp_rank() == 0 and smp.dp_rank() == 0
+        else:
+            return super().is_world_process_zero()
+
     def _get_train_sampler(self):
         if self.is_model_parallel_enabled:
             if self.args.group_by_length:
...
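The TensorBoard fix follows from the new `is_world_process_zero()` override: logging code that gates writer creation on this check now creates a writer on exactly one rank when model parallelism is enabled. A minimal usage sketch of that gating (the helper name and `log_dir` parameter are hypothetical, not the library's callback code):

```python
from torch.utils.tensorboard import SummaryWriter


def maybe_create_tb_writer(trainer, log_dir="runs"):
    # `trainer` is assumed to be a SageMakerTrainer; with the override above,
    # is_world_process_zero() returns True on exactly one rank under SageMaker
    # model parallelism, so only that rank writes TensorBoard event files.
    if trainer.is_world_process_zero():
        return SummaryWriter(log_dir=log_dir)
    return None
```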