"...asr/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "e502b10c83e17ad22f68b50365862ff8bef2dec8"
Unverified Commit 5b5e71dc authored by Quentin Meeus, committed by GitHub

add dataloader prefetch factor in training args and trainer (#28498)



* add dataloader prefetch factor in training args and trainer

* remove trailing spaces

* prevent dataloader_num_workers == 0 and dataloader_prefetch_factor != None

dataloader_prefetch_factor works only when data is loaded in a separate process from the main one. This commit adds the necessary checks to avoid having prefetch_factor set when there is no such process.
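A minimal usage sketch of the new argument, using the names introduced in this diff (output_dir and the concrete values are placeholders, not taken from the commit):

    from transformers import TrainingArguments

    # Prefetching needs worker processes, so dataloader_num_workers must be > 0.
    args = TrainingArguments(
        output_dir="out",                # placeholder path
        dataloader_num_workers=2,        # load batches in two background processes
        dataloader_prefetch_factor=4,    # each worker keeps up to 4 batches ready
    )

    # With num_workers == 0 there is no loader process to prefetch into, so after
    # this commit the following raises a ValueError:
    # TrainingArguments(output_dir="out", dataloader_num_workers=0, dataloader_prefetch_factor=2)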

* Remove whitespaces in empty line

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 582d104b
@@ -806,6 +806,7 @@ class Trainer:
             dataloader_params["sampler"] = self._get_train_sampler()
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
             dataloader_params["worker_init_fn"] = seed_worker
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
@@ -863,6 +864,7 @@ class Trainer:
         if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
             dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
         return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
@@ -895,6 +897,7 @@ class Trainer:
         if not isinstance(test_dataset, torch.utils.data.IterableDataset):
             dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
         # We use the same batch_size as for eval.
         return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
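These dataloader_params are handed straight to torch.utils.data.DataLoader, whose prefetch_factor argument does the actual work. A standalone sketch of that underlying parameter (the dataset and sizes below are placeholders, not from this commit):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
    # prefetch_factor only applies when num_workers > 0; each of the 2 workers
    # keeps up to 4 batches ready ahead of the training loop.
    loader = DataLoader(dataset, batch_size=8, num_workers=2, prefetch_factor=4)
    for (batch,) in loader:
        pass  # consume batches; prefetching happens in the worker processes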
@@ -532,6 +532,9 @@ class TrainingArguments:
             If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
             This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will
             increase RAM usage. Will default to `False`.
+        dataloader_prefetch_factor (`int`, *optional*):
+            Number of batches loaded in advance by each worker.
+            2 means there will be a total of 2 * num_workers batches prefetched across all workers.
         skip_memory_metrics (`bool`, *optional*, defaults to `True`):
             Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
             down the training and evaluation speed.
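A worked example of the formula in that docstring, with values chosen purely for illustration: with dataloader_num_workers=4 and dataloader_prefetch_factor=2, up to 4 * 2 = 8 batches can sit ready across the worker processes at any time.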
@@ -989,7 +992,16 @@ class TrainingArguments:
             )
         },
     )
+    dataloader_prefetch_factor: int = field(
+        default=None,
+        metadata={
+            "help": (
+                "Number of batches loaded in advance by each worker. "
+                "2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
+                "Default is unset"
+            )
+        },
+    )
     past_index: int = field(
         default=-1,
         metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
@@ -1737,6 +1749,12 @@ class TrainingArguments:
         if self.use_cpu:
             self.dataloader_pin_memory = False
+        if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
+            raise ValueError(
+                "--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
+                " when --dataloader_num_workers > 1."
+            )
         if self.push_to_hub_token is not None:
             warnings.warn(
                 "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
@@ -2634,6 +2652,7 @@ class TrainingArguments:
         num_workers: int = 0,
         pin_memory: bool = True,
         persistent_workers: bool = False,
+        prefetch_factor: Optional[int] = None,
         auto_find_batch_size: bool = False,
         ignore_data_skip: bool = False,
         sampler_seed: Optional[int] = None,
@@ -2654,6 +2673,9 @@ class TrainingArguments:
                 If True, the data loader will not shut down the worker processes after a dataset has been consumed
                 once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training,
                 but will increase RAM usage. Will default to `False`.
+            prefetch_factor (`int`, *optional*):
+                Number of batches loaded in advance by each worker.
+                2 means there will be a total of 2 * num_workers batches prefetched across all workers.
             auto_find_batch_size (`bool`, *optional*, defaults to `False`)
                 Whether to find a batch size that will fit into memory automatically through exponential decay,
                 avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
@@ -2684,6 +2706,7 @@ class TrainingArguments:
         self.dataloader_num_workers = num_workers
         self.dataloader_pin_memory = pin_memory
         self.dataloader_persistent_workers = persistent_workers
+        self.dataloader_prefetch_factor = prefetch_factor
         self.auto_find_batch_size = auto_find_batch_size
         self.ignore_data_skip = ignore_data_skip
         self.data_seed = sampler_seed
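The grouped setter gains the same knob. A short usage sketch restricted to the parameters visible in this diff (output_dir and the values are illustrative):

    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="out")
    # set_dataloader bundles the dataloader-related options in one call;
    # prefetch_factor is only meaningful together with num_workers > 0.
    args.set_dataloader(num_workers=2, persistent_workers=True, prefetch_factor=4)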