Unverified commit 95037a16 authored by Sylvain Gugger, committed by GitHub

[Trainer] Add a progress bar for batches skipped (#11324)

parent 95ffbe16
@@ -29,6 +29,8 @@ from logging import StreamHandler
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
+from tqdm.auto import tqdm
+
 
 # Integrations must be imported before ML frameworks:
 from .integrations import (  # isort: split
@@ -1097,6 +1099,7 @@ class Trainer:
         start_time = time.time()
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
+        steps_trained_progress_bar = None
 
         # Check if continuing training from a checkpoint
         if resume_from_checkpoint is not None and os.path.isfile(
@@ -1116,8 +1119,12 @@ class Trainer:
             if not args.ignore_data_skip:
                 logger.info(
                     f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
-                    "batches in the first epoch."
+                    "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
+                    "flag to your launch command, but you will resume the training on data already seen by your model."
                 )
+                if self.is_local_process_zero() and not args.disable_tqdm:
+                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
+                    steps_trained_progress_bar.set_description("Skipping the first batches")
 
         # Update the references
         self.callback_handler.model = self.model
@@ -1176,7 +1183,12 @@ class Trainer:
                 # Skip past any already trained steps if resuming training
                 if steps_trained_in_current_epoch > 0:
                     steps_trained_in_current_epoch -= 1
+                    if steps_trained_progress_bar is not None:
+                        steps_trained_progress_bar.update(1)
                     continue
+                elif steps_trained_progress_bar is not None:
+                    steps_trained_progress_bar.close()
+                    steps_trained_progress_bar = None
 
                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
...
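
To see the pattern in isolation, here is a minimal, self-contained sketch of what the diff adds: when resuming mid-epoch, already-seen batches are consumed from the dataloader while a tqdm bar (created only on the main process when progress bars are enabled) counts the skipping, and the bar is closed as soon as real training resumes. The function `resume_epoch`, the `train_step` placeholder, and the `range(10)` stand-in for a dataloader are made up for illustration and are not part of the Transformers API.

```python
from tqdm.auto import tqdm


def train_step(batch):
    # Placeholder: the real Trainer runs forward/backward and optimizer steps here.
    pass


def resume_epoch(dataloader, steps_trained_in_current_epoch, is_main_process=True, disable_tqdm=False):
    """Sketch of the skip-with-progress-bar pattern added in this commit."""
    steps_trained_progress_bar = None
    if steps_trained_in_current_epoch > 0 and is_main_process and not disable_tqdm:
        steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
        steps_trained_progress_bar.set_description("Skipping the first batches")

    for step, batch in enumerate(dataloader):
        # Skip past any already trained steps if resuming training
        if steps_trained_in_current_epoch > 0:
            steps_trained_in_current_epoch -= 1
            if steps_trained_progress_bar is not None:
                steps_trained_progress_bar.update(1)
            continue
        elif steps_trained_progress_bar is not None:
            # First non-skipped batch: the bar has served its purpose.
            steps_trained_progress_bar.close()
            steps_trained_progress_bar = None

        train_step(batch)  # hypothetical training step


if __name__ == "__main__":
    # Pretend the checkpoint says 3 batches of this epoch were already trained.
    resume_epoch(range(10), steps_trained_in_current_epoch=3)
```

Note that the bar is only a progress indicator: the data is still iterated over batch by batch, which is why the expanded log message points users to `--ignore_data_skip` if the skipping itself takes too long.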