"git@developer.sourcefind.cn:OpenDAS/fast_rnnt.git" did not exist on "2ec4b08e0d968150636aeb066686ba7de721bd22"
Unverified commit bc109ae5, authored by abhishek thakur and committed by GitHub

pin_memory -> dataloader_pin_memory (#9874)

parent 80e4184f
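
For reference, a minimal usage sketch of the renamed option, assuming a transformers version that includes this commit; the output directory below is only a placeholder:

    from transformers import TrainingArguments

    # After this rename the option is spelled `dataloader_pin_memory`;
    # the old `pin_memory` keyword would no longer be recognized.
    args = TrainingArguments(
        output_dir="out",             # placeholder output directory
        dataloader_pin_memory=False,  # turn off pinned memory in the Trainer's DataLoaders
    )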
@@ -485,7 +485,7 @@ class Trainer:
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
             num_workers=self.args.dataloader_num_workers,
-            pin_memory=self.args.pin_memory,
+            pin_memory=self.args.dataloader_pin_memory,
         )

     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
@@ -523,7 +523,7 @@ class Trainer:
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
             num_workers=self.args.dataloader_num_workers,
-            pin_memory=self.args.pin_memory,
+            pin_memory=self.args.dataloader_pin_memory,
         )

     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
@@ -550,7 +550,7 @@ class Trainer:
             batch_size=self.args.eval_batch_size,
             collate_fn=self.data_collator,
             drop_last=self.args.dataloader_drop_last,
-            pin_memory=self.args.pin_memory,
+            pin_memory=self.args.dataloader_pin_memory,
         )

     def create_optimizer_and_scheduler(self, num_training_steps: int):
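
For context, an illustrative standalone sketch (not the Trainer's actual code) of what the hunks above ultimately do: the flag is forwarded to PyTorch's DataLoader as pin_memory, so the rename only changes which attribute on args the Trainer reads.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Toy dataset standing in for the Trainer's train/eval/test datasets.
    dataset = TensorDataset(torch.randn(8, 4))

    # Keyword arguments mirror the hunks above; pin_memory now comes from
    # args.dataloader_pin_memory rather than args.pin_memory.
    loader = DataLoader(
        dataset,
        batch_size=2,
        drop_last=False,
        num_workers=0,
        pin_memory=True,
    )

    for (batch,) in loader:
        print(batch.shape)  # torch.Size([2, 4])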
@@ -244,7 +244,7 @@ class TrainingArguments:
             When using distributed training, the value of the flag :obj:`find_unused_parameters` passed to
             :obj:`DistributedDataParallel`. Will default to :obj:`False` if gradient checkpointing is used, :obj:`True`
             otherwise.
-        pin_memory (:obj:`bool`, `optional`, defaults to :obj:`True`)):
+        dataloader_pin_memory (:obj:`bool`, `optional`, defaults to :obj:`True`)):
             Whether you want to pin memory in data loaders or not. Will default to :obj:`True`.
     """
@@ -438,7 +438,9 @@ class TrainingArguments:
             "`DistributedDataParallel`."
         },
     )
-    pin_memory: bool = field(default=True, metadata={"help": "Whether or not to pin memory for data loaders."})
+    dataloader_pin_memory: bool = field(
+        default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)

     def __post_init__(self):
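
As a side note, a small self-contained sketch of the dataclass/field pattern used above, with a hypothetical class name standing in for TrainingArguments; each option is a field whose metadata carries its help text:

    from dataclasses import dataclass, field

    @dataclass
    class ExampleArguments:  # hypothetical stand-in for TrainingArguments
        dataloader_pin_memory: bool = field(
            default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
        )

    args = ExampleArguments()
    print(args.dataloader_pin_memory)  # True by default, matching the field above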