Commit 34e6a5e8 authored by Louis MARTIN's avatar Louis MARTIN Committed by Facebook Github Bot
Browse files

Fix load_dataset signature (#1281)

Summary:
Fix for https://github.com/pytorch/fairseq/issues/1240
Tested with MaskedLMTask.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1281

Differential Revision: D18051472

fbshipit-source-id: 0aeff60c71489655f5e621349f780ba9cd8c027a
parent 66d24dc2
...@@ -71,7 +71,7 @@ class MaskedLMTask(FairseqTask): ...@@ -71,7 +71,7 @@ class MaskedLMTask(FairseqTask):
print('| dictionary: {} types'.format(len(dictionary))) print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary) return cls(args, dictionary)
def load_dataset(self, split, epoch=0, combine=False): def load_dataset(self, split, epoch=0, combine=False, **kwargs):
"""Load a given dataset split. """Load a given dataset split.
Args: Args:
......
...@@ -110,7 +110,7 @@ class MultiLingualMaskedLMTask(FairseqTask): ...@@ -110,7 +110,7 @@ class MultiLingualMaskedLMTask(FairseqTask):
smoothed_prob = smoothed_prob / smoothed_prob.sum() smoothed_prob = smoothed_prob / smoothed_prob.sum()
return smoothed_prob return smoothed_prob
def load_dataset(self, split, epoch=0, combine=False): def load_dataset(self, split, epoch=0, combine=False, **kwargs):
"""Load a given dataset split. """Load a given dataset split.
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment