Commit 67ee6d1f authored by alexeib, committed by Myle Ott
Browse files

remove right-to-left lm support

parent d2e2a1d4
......@@ -29,13 +29,12 @@ class TokenBlockDataset(torch.utils.data.Dataset):
include_targets: return next tokens as targets
"""
def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False, reverse=False):
def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
super().__init__()
self.tokens = tokens
self.total_size = len(tokens)
self.include_targets = include_targets
self.reverse = reverse
self.slice_indices = []
if break_mode is None or break_mode == 'none':
......@@ -78,19 +77,9 @@ class TokenBlockDataset(torch.utils.data.Dataset):
def __getitem__(self, index):
s, e = self.slice_indices[index]
if self.reverse:
item = torch.LongTensor(np.flip(self.tokens[s:e], 0).copy())
else:
item = torch.LongTensor(self.tokens[s:e])
if self.include_targets:
if self.reverse:
if s == 0:
target = np.concatenate([self.tokens[-1:], item.numpy()[1:]])
else:
target = np.concatenate([self.tokens[s - 1:s], item.numpy()[:-1]])
return item, torch.LongTensor(target)
# target is the sentence, for source, rotate item one token to the left (would start with eos)
if s == 0:
source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
......
......@@ -36,13 +36,9 @@ class LanguageModelingTask(FairseqTask):
help='max number of tokens per sample for LM dataset')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
parser.add_argument('--right-to-left', default=False, action='store_true',
help='if set, trains a language model right-to-left (instead of left-to-right)')
def __init__(self, args, dictionary):
super().__init__(args)
args.right_to_left = getattr(args, 'right_to_left', False)
self.dictionary = dictionary
@classmethod
......@@ -75,7 +71,7 @@ class LanguageModelingTask(FairseqTask):
loaded_datasets.append(
TokenBlockDataset(
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True, reverse=self.args.right_to_left,
include_targets=True
))
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment