Commit 6296de82 authored by Myle Ott

Add --upsample-primary

parent 5852d3a0
@@ -41,6 +41,8 @@ class TranslationTask(FairseqTask):
                            help='max number of tokens in the source sequence')
        parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the target sequence')
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args)

@@ -120,12 +122,14 @@ class TranslationTask(FairseqTask):
            src_sizes = src_dataset.sizes
            tgt_sizes = tgt_dataset.sizes
        else:
            if self.args.upsample_primary > 1:
                src_datasets.extend([src_datasets[0]] * (self.args.upsample_primary - 1))
                tgt_datasets.extend([tgt_datasets[0]] * (self.args.upsample_primary - 1))
            src_dataset = ConcatDataset(src_datasets)
            tgt_dataset = ConcatDataset(tgt_datasets)
            src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
            tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])
        self.datasets[split] = LanguagePairDataset(
            src_dataset, src_sizes, self.src_dict,
            tgt_dataset, tgt_sizes, self.tgt_dict,
...
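The new flag upsamples the primary (first) dataset by simply appending extra copies of it to the dataset list before concatenation, so its examples appear `--upsample-primary` times as often as those of each secondary dataset. A minimal standalone sketch of that replication trick, using plain Python lists in place of fairseq's `ConcatDataset` (the function name and data below are illustrative, not fairseq APIs):

```python
# Sketch of the upsampling-by-replication idea from this commit.
# Plain lists stand in for fairseq datasets; names here are hypothetical.
import itertools

def build_combined(datasets, upsample_primary=1):
    """Replicate the primary (first) dataset so its examples occur
    `upsample_primary` times as often as each secondary dataset."""
    datasets = list(datasets)
    if upsample_primary > 1:
        # Same trick as the diff: append (upsample_primary - 1) extra copies.
        datasets.extend([datasets[0]] * (upsample_primary - 1))
    # Stand-in for ConcatDataset: chain everything into one flat sequence.
    return list(itertools.chain.from_iterable(datasets))

primary = ["p1", "p2"]          # e.g. in-domain parallel data
secondary = ["s1", "s2", "s3"]  # e.g. back-translated data
combined = build_combined([primary, secondary], upsample_primary=3)
print(combined)  # ['p1', 'p2', 's1', 's2', 's3', 'p1', 'p2', 'p1', 'p2']
```

Because the copies are appended to the list rather than interleaved, the size arrays concatenated from `ds.sizes` stay aligned with the concatenated examples.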