"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "aea7b23b57ea0bbe15c42114d35b3c7a4d79d7eb"
Unverified Commit 269078a7 authored by Ilya, committed by GitHub

Add `persistent_workers` parameter to `TrainingArguments` (#27189)



Added the `dataloader_persistent_workers` argument and wired it through the train, eval, and test data loaders.
Co-authored-by: Ilya Fedorov <ilyaf@nvidia.com>
parent a2b1e1df
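
For context, the new flag is forwarded to PyTorch's `torch.utils.data.DataLoader`, which already supports `persistent_workers`. A minimal sketch of the underlying behavior (the toy dataset and sizes are illustrative only):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 4))

# With persistent_workers=True, the worker processes spawned on the first
# iteration stay alive between epochs instead of being torn down and
# recreated, trading extra RAM for lower per-epoch startup cost.
# Note: persistent_workers requires num_workers > 0.
loader = DataLoader(dataset, batch_size=8, num_workers=2, persistent_workers=True)

for epoch in range(2):
    for batch in loader:
        pass  # the second epoch reuses the same worker processes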
@@ -793,6 +793,7 @@ class Trainer:
             "collate_fn": data_collator,
             "num_workers": self.args.dataloader_num_workers,
             "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
         }
 
         if not isinstance(train_dataset, torch.utils.data.IterableDataset):
@@ -850,6 +851,7 @@ class Trainer:
             "collate_fn": data_collator,
             "num_workers": self.args.dataloader_num_workers,
             "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
         }
 
         if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
@@ -881,6 +883,7 @@ class Trainer:
             "collate_fn": data_collator,
             "num_workers": self.args.dataloader_num_workers,
             "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
         }
 
         if not isinstance(test_dataset, torch.utils.data.IterableDataset):
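Each `dataloader_params` dict above is unpacked straight into the `DataLoader` constructor, so the new key is picked up without further plumbing. A hedged sketch of that pattern, with literal values standing in for the `self.args.dataloader_*` lookups:

import torch
from torch.utils.data import DataLoader, TensorDataset

train_dataset = TensorDataset(torch.randn(32, 4))

# Stand-ins for the corresponding self.args.dataloader_* attributes.
dataloader_params = {
    "batch_size": 8,
    "collate_fn": None,  # None falls back to PyTorch's default collation
    "num_workers": 2,
    "pin_memory": True,
    "persistent_workers": True,  # the new flag flows through unchanged
}
train_dataloader = DataLoader(train_dataset, **dataloader_params)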
@@ -529,6 +529,10 @@ class TrainingArguments:
             `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
         dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
             Whether you want to pin memory in data loaders or not. Will default to `True`.
+        dataloader_persistent_workers (`bool`, *optional*, defaults to `False`):
+            If `True`, the data loader will not shut down the worker processes after a dataset has been consumed
+            once, keeping the workers' `Dataset` instances alive. This can speed up training, but will increase
+            RAM usage. Will default to `False`.
         skip_memory_metrics (`bool`, *optional*, defaults to `True`):
             Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
             down the training and evaluation speed.
@@ -1136,6 +1140,12 @@ class TrainingArguments:
     dataloader_pin_memory: bool = field(
         default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
     )
+    dataloader_persistent_workers: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, the data loader will not shut down the worker processes after a dataset has been consumed once, keeping the workers' Dataset instances alive. This can speed up training, but will increase RAM usage."
+        },
+    )
     skip_memory_metrics: bool = field(
         default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
     )
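
The `field(..., metadata={"help": ...})` pattern is what exposes each option on the command line: `HfArgumentParser` reads the dataclass fields and turns the `help` metadata into argparse help text. A minimal sketch with a hypothetical stand-in dataclass (`MiniArgs` is not part of the library):

from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class MiniArgs:
    # Same shape as the new TrainingArguments field above.
    dataloader_persistent_workers: bool = field(
        default=False,
        metadata={"help": "Keep DataLoader worker processes alive between epochs."},
    )

parser = HfArgumentParser(MiniArgs)
(args,) = parser.parse_args_into_dataclasses(["--dataloader_persistent_workers", "True"])
print(args.dataloader_persistent_workers)  # True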
@@ -2643,6 +2653,7 @@ class TrainingArguments:
         drop_last: bool = False,
         num_workers: int = 0,
         pin_memory: bool = True,
+        persistent_workers: bool = False,
         auto_find_batch_size: bool = False,
         ignore_data_skip: bool = False,
         sampler_seed: Optional[int] = None,
@@ -2659,6 +2670,10 @@ class TrainingArguments:
                 the main process.
             pin_memory (`bool`, *optional*, defaults to `True`):
                 Whether you want to pin memory in data loaders or not. Will default to `True`.
+            persistent_workers (`bool`, *optional*, defaults to `False`):
+                If `True`, the data loader will not shut down the worker processes after a dataset has been consumed
+                once, keeping the workers' `Dataset` instances alive. This can speed up training, but will increase
+                RAM usage. Will default to `False`.
             auto_find_batch_size (`bool`, *optional*, defaults to `False`)
                 Whether to find a batch size that will fit into memory automatically through exponential decay,
                 avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
@@ -2688,6 +2703,7 @@ class TrainingArguments:
         self.dataloader_drop_last = drop_last
         self.dataloader_num_workers = num_workers
         self.dataloader_pin_memory = pin_memory
+        self.dataloader_persistent_workers = persistent_workers
         self.auto_find_batch_size = auto_find_batch_size
         self.ignore_data_skip = ignore_data_skip
         self.data_seed = sampler_seed
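
End to end, the option can be set either at construction time or through the `set_dataloader` helper that the last three hunks extend. A hedged usage sketch ("out" is a placeholder output directory; persistent workers only take effect when `dataloader_num_workers > 0`):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=True,
)

# Equivalent via the helper; like the other set_* helpers, set_dataloader
# returns the arguments object.
args = TrainingArguments(output_dir="out").set_dataloader(
    num_workers=4, persistent_workers=True
)
print(args.dataloader_persistent_workers)  # True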