"test/vscode:/vscode.git/clone" did not exist on "95479e6767b3d254fbc9b030bae28d4863ea48eb"
Unverified Commit 29a2b142 authored by Siddartha Naidu's avatar Siddartha Naidu Committed by GitHub
Browse files

Change progress logging to once across all nodes (#28373)

parent 2382706a
...@@ -489,17 +489,17 @@ class ProgressCallback(TrainerCallback): ...@@ -489,17 +489,17 @@ class ProgressCallback(TrainerCallback):
self.prediction_bar = None self.prediction_bar = None
def on_train_begin(self, args, state, control, **kwargs): def on_train_begin(self, args, state, control, **kwargs):
if state.is_local_process_zero: if state.is_world_process_zero:
self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True) self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True)
self.current_step = 0 self.current_step = 0
def on_step_end(self, args, state, control, **kwargs): def on_step_end(self, args, state, control, **kwargs):
if state.is_local_process_zero: if state.is_world_process_zero:
self.training_bar.update(state.global_step - self.current_step) self.training_bar.update(state.global_step - self.current_step)
self.current_step = state.global_step self.current_step = state.global_step
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if state.is_local_process_zero and has_length(eval_dataloader): if state.is_world_process_zero and has_length(eval_dataloader):
if self.prediction_bar is None: if self.prediction_bar is None:
self.prediction_bar = tqdm( self.prediction_bar = tqdm(
total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True
...@@ -507,24 +507,24 @@ class ProgressCallback(TrainerCallback): ...@@ -507,24 +507,24 @@ class ProgressCallback(TrainerCallback):
self.prediction_bar.update(1) self.prediction_bar.update(1)
def on_evaluate(self, args, state, control, **kwargs): def on_evaluate(self, args, state, control, **kwargs):
if state.is_local_process_zero: if state.is_world_process_zero:
if self.prediction_bar is not None: if self.prediction_bar is not None:
self.prediction_bar.close() self.prediction_bar.close()
self.prediction_bar = None self.prediction_bar = None
def on_predict(self, args, state, control, **kwargs): def on_predict(self, args, state, control, **kwargs):
if state.is_local_process_zero: if state.is_world_process_zero:
if self.prediction_bar is not None: if self.prediction_bar is not None:
self.prediction_bar.close() self.prediction_bar.close()
self.prediction_bar = None self.prediction_bar = None
def on_log(self, args, state, control, logs=None, **kwargs): def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_local_process_zero and self.training_bar is not None: if state.is_world_process_zero and self.training_bar is not None:
_ = logs.pop("total_flos", None) _ = logs.pop("total_flos", None)
self.training_bar.write(str(logs)) self.training_bar.write(str(logs))
def on_train_end(self, args, state, control, **kwargs): def on_train_end(self, args, state, control, **kwargs):
if state.is_local_process_zero: if state.is_world_process_zero:
self.training_bar.close() self.training_bar.close()
self.training_bar = None self.training_bar = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment