Unverified Commit 17099ebd authored by Chady Kamar's avatar Chady Kamar Committed by GitHub
Browse files

Add num workers cli arg (#7322)

* Add dataloader_num_workers to TrainingArguments

This argument is meant to be used to set the
number of workers for the PyTorch DataLoader.

* Pass num_workers argument on DataLoader init
parent 25b0463d
...@@ -352,6 +352,7 @@ class Trainer: ...@@ -352,6 +352,7 @@ class Trainer:
sampler=train_sampler, sampler=train_sampler,
collate_fn=self.data_collator, collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last, drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
) )
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]: def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
...@@ -391,6 +392,7 @@ class Trainer: ...@@ -391,6 +392,7 @@ class Trainer:
batch_size=self.args.eval_batch_size, batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator, collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last, drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
) )
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
......
...@@ -122,6 +122,8 @@ class TrainingArguments: ...@@ -122,6 +122,8 @@ class TrainingArguments:
eval_steps (:obj:`int`, `optional`): eval_steps (:obj:`int`, `optional`):
Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the
same value as :obj:`logging_steps` if not set. same value as :obj:`logging_steps` if not set.
dataloader_num_workers (:obj:`int`, `optional`, defaults to 0):
Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process.
past_index (:obj:`int`, `optional`, defaults to -1): past_index (:obj:`int`, `optional`, defaults to -1):
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
make use of the past hidden states for their predictions. If this argument is set to a positive int, the make use of the past hidden states for their predictions. If this argument is set to a positive int, the
...@@ -259,6 +261,10 @@ class TrainingArguments: ...@@ -259,6 +261,10 @@ class TrainingArguments:
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
) )
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
dataloader_num_workers: int = field(
default=0,
metadata={"help": "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process."}
)
past_index: int = field( past_index: int = field(
default=-1, default=-1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment