"test/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "6e0af71353a88b1d6e378b4d97a0b7a608600169"
Commit b0167632 (unverified), authored by Cola, committed by GitHub
Browse files

Shuffle train subset for summarization example (#3909)

* Shuffle train subset

* Cleaner shuffle
parent c53cc018
......@@ -102,13 +102,13 @@ class SummarizationTrainer(BaseTransformer):
         return self.test_end(outputs)
-    def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader:
+    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
-        dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
+        dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle)
         return dataloader
     def train_dataloader(self) -> DataLoader:
-        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size)
+        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
         t_total = (
             (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
             // self.hparams.gradient_accumulation_steps
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment