Unverified Commit 8d580779 authored by Sylvain Gugger, committed by GitHub

Skip batches fast with accelerate (#21390)

* Skip batches fast with Accelerate

* remove debug statement

* Hack seed reload at the right time

* Reorganize RNG sync

* Fix Accelerate version comparison
parent 77db257e
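
For context, `skip_first_batches` (available in Accelerate >= 0.16) wraps a dataloader so that resuming mid-epoch jumps past already-seen batches without materializing them. A minimal sketch of how it is used, assuming a plain PyTorch DataLoader; the dataset, batch size, and skip count are illustrative values, not taken from this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import skip_first_batches

# Toy dataset of 100 scalar examples, 10 batches of 10.
dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))
dataloader = DataLoader(dataset, batch_size=10)

# Suppose a previous run already consumed the first 3 batches of this epoch.
resumed_loader = skip_first_batches(dataloader, 3)

for step, (batch,) in enumerate(resumed_loader):
    # step 0 here corresponds to the 4th batch of the original dataloader.
    print(step, batch.flatten()[:3])
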
@@ -138,6 +138,7 @@ from .utils import (
     can_return_loss,
     find_labels,
     get_full_repo_name,
+    is_accelerate_available,
     is_apex_available,
     is_datasets_available,
     is_in_notebook,
@@ -193,6 +194,14 @@ else:
     IS_SAGEMAKER_MP_POST_1_10 = False

+skip_first_batches = None
+if is_accelerate_available():
+    from accelerate import __version__ as accelerate_version
+
+    if version.parse(accelerate_version) >= version.parse("0.16"):
+        from accelerate import skip_first_batches
+
+
 if TYPE_CHECKING:
     import optuna
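
The hunk above version-gates the import so that the rest of the file only has to test a single name against `None`. A self-contained sketch of the same pattern; `examplepkg` and `new_helper` are hypothetical stand-ins, not real names from this commit:

from packaging import version

new_helper = None  # sentinel: feature unavailable
try:
    import examplepkg

    # Only expose the helper when the installed version is new enough.
    if version.parse(examplepkg.__version__) >= version.parse("0.16"):
        from examplepkg import new_helper
except ImportError:
    pass

# Call sites then reduce to a single None check:
if new_helper is not None:
    ...  # fast path
else:
    ...  # fallback path
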
@@ -1691,12 +1700,20 @@ class Trainer:
             logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
-                logger.info(
-                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
-                    "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
-                    "flag to your launch command, but you will resume the training on data already seen by your model."
-                )
-            if self.is_local_process_zero() and not args.disable_tqdm:
+                if skip_first_batches is None:
+                    logger.info(
+                        f"  Will skip the first {epochs_trained} epochs then the first"
+                        f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
+                        " you can install the latest version of Accelerate with `pip install -U accelerate`. You can"
+                        " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
+                        " training on data already seen by your model."
+                    )
+                else:
+                    logger.info(
+                        f"  Will skip the first {epochs_trained} epochs then the first"
+                        f" {steps_trained_in_current_epoch} batches in the first epoch."
+                    )
+            if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
                 steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                 steps_trained_progress_bar.set_description("Skipping the first batches")
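
Before this change, the only way to resume mid-epoch was the slow path the progress bar above reports on: iterate over the dataloader and throw batches away until the checkpointed step is reached. A rough sketch of that fallback, simplified for illustration (`train_step` is a hypothetical stand-in for the trainer's training step); the variable names mirror the trainer's:

for step, inputs in enumerate(dataloader):
    # Slow skip: every already-trained batch is still loaded and collated,
    # then discarded, which can take a long time on large datasets.
    if steps_trained_in_current_epoch > 0:
        steps_trained_in_current_epoch -= 1
        continue
    train_step(inputs)  # hypothetical: actual training happens here

`skip_first_batches` replaces this by skipping at the sampler/index level, so the skipped batches are never fetched at all, which is why the progress bar is no longer needed when it is available.
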
@@ -1772,8 +1789,17 @@ class Trainer:
             if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                 self._load_rng_state(resume_from_checkpoint)

+            rng_to_sync = False
+            if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
+                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
+                steps_trained_in_current_epoch = 0
+                rng_to_sync = True
+
             step = -1
             for step, inputs in enumerate(epoch_iterator):
+                if rng_to_sync:
+                    self._load_rng_state(resume_from_checkpoint)
+                    rng_to_sync = False
+
                 # Skip past any already trained steps if resuming training
                 if steps_trained_in_current_epoch > 0:
...
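
The `rng_to_sync` flag (the "hack seed reload at the right time" from the commit message) exists because fetching the first post-skip batch can itself advance the RNG (e.g. through dataloader workers or random transforms), so the checkpointed RNG state must be restored after that first fetch, not before the loop. A minimal runnable sketch of the deferred-restore pattern; `save_rng_state` and `load_rng_state` are hypothetical helpers standing in for the trainer's `_load_rng_state` machinery:

import random
import torch

def save_rng_state():
    return {"python": random.getstate(), "torch": torch.get_rng_state()}

def load_rng_state(state):
    random.setstate(state["python"])
    torch.set_rng_state(state["torch"])

checkpoint_rng = save_rng_state()

rng_to_sync = True  # pretend we just fast-forwarded the dataloader
for step, batch in enumerate(range(5)):  # stand-in iterator
    if rng_to_sync:
        # Restore only after the first batch was fetched: the fetch itself
        # may have consumed random numbers that must not leak into training.
        load_rng_state(checkpoint_rng)
        rng_to_sync = False
    noise = torch.rand(1)  # training-side randomness now matches the checkpointed run
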