"tests/models/ctrl/test_modeling_ctrl.py" did not exist on "aa925a52fad9d6b98dac4c1b27f881bef7e88dad"
Unverified Commit 269078a7, authored by Ilya, committed by GitHub

Add `persistent_workers` parameter to `TrainingArguments` (#27189)



added param
Co-authored-by: Ilya Fedorov <ilyaf@nvidia.com>
parent a2b1e1df
src/transformers/trainer.py

@@ -793,6 +793,7 @@ class Trainer:
             "collate_fn": data_collator,
             "num_workers": self.args.dataloader_num_workers,
             "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
         }
         if not isinstance(train_dataset, torch.utils.data.IterableDataset):

@@ -850,6 +851,7 @@ class Trainer:
             "collate_fn": data_collator,
             "num_workers": self.args.dataloader_num_workers,
             "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
         }
         if not isinstance(eval_dataset, torch.utils.data.IterableDataset):

@@ -881,6 +883,7 @@ class Trainer:
             "collate_fn": data_collator,
             "num_workers": self.args.dataloader_num_workers,
             "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
         }
         if not isinstance(test_dataset, torch.utils.data.IterableDataset):
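These dataloader_params are forwarded by the Trainer to torch.utils.data.DataLoader, so the new key maps directly onto PyTorch's own persistent_workers flag. A minimal standalone sketch of what that flag does, using an illustrative toy dataset (not taken from the Trainer code):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(64, 4))

    # persistent_workers=True keeps the worker processes spawned on the first
    # iteration alive between epochs instead of re-creating them each epoch,
    # trading extra RAM for lower per-epoch startup cost. PyTorch rejects
    # persistent_workers=True when num_workers == 0.
    loader = DataLoader(
        dataset,
        batch_size=8,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
    )

    for epoch in range(3):
        for (batch,) in loader:  # the same two workers serve every epoch
            pass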
src/transformers/training_args.py

@@ -529,6 +529,10 @@ class TrainingArguments:
             `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
         dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
             Whether you want to pin memory in data loaders or not. Will default to `True`.
+        dataloader_persistent_workers (`bool`, *optional*, defaults to `False`):
+            If True, the data loader will not shut down the worker processes after a dataset has been consumed
+            once, keeping the workers' Dataset instances alive. Can potentially speed up training, but will
+            increase RAM usage. Will default to `False`.
         skip_memory_metrics (`bool`, *optional*, defaults to `True`):
             Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
             down the training and evaluation speed.

@@ -1136,6 +1140,12 @@ class TrainingArguments:
     dataloader_pin_memory: bool = field(
         default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
     )
+    dataloader_persistent_workers: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, the data loader will not shut down the worker processes after a dataset has been consumed once, keeping the workers' Dataset instances alive. Can potentially speed up training, but will increase RAM usage."
+        },
+    )
     skip_memory_metrics: bool = field(
         default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
     )
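With the field in place, enabling the behavior is a one-liner when constructing TrainingArguments. A short usage sketch (the output_dir value is illustrative):

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",                    # illustrative path
        dataloader_num_workers=4,            # persistent workers require num_workers > 0
        dataloader_persistent_workers=True,  # the flag added by this commit
    )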
@@ -2643,6 +2653,7 @@ class TrainingArguments:
         drop_last: bool = False,
         num_workers: int = 0,
         pin_memory: bool = True,
+        persistent_workers: bool = False,
         auto_find_batch_size: bool = False,
         ignore_data_skip: bool = False,
         sampler_seed: Optional[int] = None,

@@ -2659,6 +2670,10 @@ class TrainingArguments:
                 the main process.
             pin_memory (`bool`, *optional*, defaults to `True`):
                 Whether you want to pin memory in data loaders or not. Will default to `True`.
+            persistent_workers (`bool`, *optional*, defaults to `False`):
+                If True, the data loader will not shut down the worker processes after a dataset has been
+                consumed once, keeping the workers' Dataset instances alive. Can potentially speed up training,
+                but will increase RAM usage. Will default to `False`.
             auto_find_batch_size (`bool`, *optional*, defaults to `False`):
                 Whether to find a batch size that will fit into memory automatically through exponential decay,
                 avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)

@@ -2688,6 +2703,7 @@ class TrainingArguments:
         self.dataloader_drop_last = drop_last
         self.dataloader_num_workers = num_workers
         self.dataloader_pin_memory = pin_memory
+        self.dataloader_persistent_workers = persistent_workers
         self.auto_find_batch_size = auto_find_batch_size
         self.ignore_data_skip = ignore_data_skip
         self.data_seed = sampler_seed
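The set_dataloader helper shown above exposes the same flag. A sketch, assuming the method keeps returning self like the other set_* helpers on TrainingArguments:

    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="out")  # illustrative path
    args = args.set_dataloader(num_workers=4, pin_memory=True, persistent_workers=True)
    assert args.dataloader_persistent_workers is True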