"docs/source/en/task_summary.mdx" did not exist on "0118c4f6a85b3c3a454933e9bc3f35e95f5384ca"
Unverified Commit 8d580779 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Skip batches fast with accelerate (#21390)

* Skip batches fast with Accelerate

* remove debug statement

* Hack seed reload at the right time

* Reorganize RNG sync

* Fix accelerate version comp
parent 77db257e
......@@ -138,6 +138,7 @@ from .utils import (
can_return_loss,
find_labels,
get_full_repo_name,
is_accelerate_available,
is_apex_available,
is_datasets_available,
is_in_notebook,
......@@ -193,6 +194,14 @@ else:
IS_SAGEMAKER_MP_POST_1_10 = False
skip_first_batches = None
if is_accelerate_available():
from accelerate import __version__ as accelerate_version
if version.parse(accelerate_version) >= version.parse("0.16"):
from accelerate import skip_first_batches
if TYPE_CHECKING:
import optuna
......@@ -1691,12 +1700,20 @@ class Trainer:
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
if not args.ignore_data_skip:
if skip_first_batches is None:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
"batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
"flag to your launch command, but you will resume the training on data already seen by your model."
f" Will skip the first {epochs_trained} epochs then the first"
f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
" you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
" also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
" training on data already seen by your model."
)
if self.is_local_process_zero() and not args.disable_tqdm:
else:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first"
f" {steps_trained_in_current_epoch} batches in the first epoch."
)
if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
steps_trained_progress_bar.set_description("Skipping the first batches")
......@@ -1772,8 +1789,17 @@ class Trainer:
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
steps_trained_in_current_epoch = 0
rng_to_sync = True
step = -1
for step, inputs in enumerate(epoch_iterator):
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment