Unverified Commit 5b5e71dc authored by Quentin Meeus, committed by GitHub

add dataloader prefetch factor in training args and trainer (#28498)



* add dataloader prefetch factor in training args and trainer

* remove trailing spaces

* prevent dataloader_num_workers == 0 together with dataloader_prefetch_factor != None

dataloader_prefetch_factor only takes effect when data is loaded in a process other than the main one, i.e. when dataloader_num_workers > 0. This commit adds the necessary check so that prefetch_factor cannot be set when no such worker process exists.

* Remove whitespace from empty line

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 582d104b
@@ -806,6 +806,7 @@ class Trainer:
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
@@ -863,6 +864,7 @@ class Trainer:
        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
        return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
@@ -895,6 +897,7 @@ class Trainer:
        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
        # We use the same batch_size as for eval.
        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
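The three Trainer hunks above only forward the new argument to `torch.utils.data.DataLoader`, which already supports `prefetch_factor`. As a rough, standalone sketch of what that parameter controls in plain PyTorch (the toy dataset and values below are illustrative, not part of the PR):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Illustrative dataset, not part of the PR."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.tensor([idx], dtype=torch.float32)


if __name__ == "__main__":
    # With num_workers > 0, each worker keeps `prefetch_factor` batches ready ahead of
    # the consumer, so roughly prefetch_factor * num_workers batches are buffered in total.
    loader = DataLoader(ToyDataset(), batch_size=8, num_workers=2, prefetch_factor=4)
    for batch in loader:
        pass  # the training step would consume `batch` here
```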
@@ -532,6 +532,9 @@ class TrainingArguments:
            If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
            This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will
            increase RAM usage. Will default to `False`.
        dataloader_prefetch_factor (`int`, *optional*):
            Number of batches loaded in advance by each worker.
            2 means there will be a total of 2 * num_workers batches prefetched across all workers.
        skip_memory_metrics (`bool`, *optional*, defaults to `True`):
            Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
            down the training and evaluation speed.
@@ -989,7 +992,16 @@ class TrainingArguments:
            )
        },
    )
    dataloader_prefetch_factor: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Number of batches loaded in advance by each worker. "
                "2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
                "Default is unset"
            )
        },
    )
    past_index: int = field(
        default=-1,
        metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
@@ -1737,6 +1749,12 @@ class TrainingArguments:
        if self.use_cpu:
            self.dataloader_pin_memory = False
        if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
            raise ValueError(
                "--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
                " when --dataloader_num_workers > 0."
            )
        if self.push_to_hub_token is not None:
            warnings.warn(
                "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
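The check added in `__post_init__` rejects the one combination PyTorch cannot honour: prefetching with no worker processes. A small sketch of the failure path, again assuming a transformers version that contains this change and using placeholder values:

```python
from transformers import TrainingArguments

try:
    # Invalid: prefetching happens in worker processes, so it needs num_workers > 0.
    TrainingArguments(
        output_dir="out",
        dataloader_num_workers=0,
        dataloader_prefetch_factor=2,
    )
except ValueError as err:
    print(err)  # explains that --dataloader_num_workers must be > 0
```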
@@ -2634,6 +2652,7 @@ class TrainingArguments:
        num_workers: int = 0,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        prefetch_factor: Optional[int] = None,
        auto_find_batch_size: bool = False,
        ignore_data_skip: bool = False,
        sampler_seed: Optional[int] = None,
@@ -2654,6 +2673,9 @@ class TrainingArguments:
                If True, the data loader will not shut down the worker processes after a dataset has been consumed
                once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training,
                but will increase RAM usage. Will default to `False`.
            prefetch_factor (`int`, *optional*):
                Number of batches loaded in advance by each worker.
                2 means there will be a total of 2 * num_workers batches prefetched across all workers.
            auto_find_batch_size (`bool`, *optional*, defaults to `False`)
                Whether to find a batch size that will fit into memory automatically through exponential decay,
                avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
@@ -2684,6 +2706,7 @@ class TrainingArguments:
        self.dataloader_num_workers = num_workers
        self.dataloader_pin_memory = pin_memory
        self.dataloader_persistent_workers = persistent_workers
        self.dataloader_prefetch_factor = prefetch_factor
        self.auto_find_batch_size = auto_find_batch_size
        self.ignore_data_skip = ignore_data_skip
        self.data_seed = sampler_seed
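The same option is also exposed through the `set_dataloader` convenience setter whose signature is shown above. A short sketch with placeholder values, using only the keywords visible in these hunks:

```python
from transformers import TrainingArguments

args = TrainingArguments(output_dir="out")
# Keyword names follow the set_dataloader signature in the hunk above.
args.set_dataloader(num_workers=4, prefetch_factor=2, persistent_workers=True)
assert args.dataloader_prefetch_factor == 2
assert args.dataloader_persistent_workers is True
```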