Commit d494485f authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

fix raw text for language modeling

parent 7358296b
...@@ -47,7 +47,7 @@ class TokenBlockDataset(torch.utils.data.Dataset): ...@@ -47,7 +47,7 @@ class TokenBlockDataset(torch.utils.data.Dataset):
self.slice_indices = [block_at(i) for i in range(length)] self.slice_indices = [block_at(i) for i in range(length)]
elif break_mode == 'complete': elif break_mode == 'complete':
assert sizes is not None and sum(sizes) == len(tokens) assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
tok_idx = 0 tok_idx = 0
sz_idx = 0 sz_idx = 0
curr_size = 0 curr_size = 0
...@@ -62,7 +62,7 @@ class TokenBlockDataset(torch.utils.data.Dataset): ...@@ -62,7 +62,7 @@ class TokenBlockDataset(torch.utils.data.Dataset):
if curr_size > 0: if curr_size > 0:
self.slice_indices.append((tok_idx, tok_idx + curr_size)) self.slice_indices.append((tok_idx, tok_idx + curr_size))
elif break_mode == 'eos': elif break_mode == 'eos':
assert sizes is not None and sum(sizes) == len(tokens) assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
curr = 0 curr = 0
for sz in sizes: for sz in sizes:
# skip samples with just 1 example (which would be just the eos token) # skip samples with just 1 example (which would be just the eos token)
......
...@@ -48,7 +48,7 @@ class LanguageModelingTask(FairseqTask): ...@@ -48,7 +48,7 @@ class LanguageModelingTask(FairseqTask):
path = os.path.join(self.args.data, split) path = os.path.join(self.args.data, split)
if self.args.raw_text and IndexedRawTextDataset.exists(path): if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary) ds = IndexedRawTextDataset(path, self.dictionary)
tokens = ds.tokens_list tokens = [t for l in ds.tokens_list for t in l]
elif not self.args.raw_text and IndexedInMemoryDataset.exists(path): elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
ds = IndexedInMemoryDataset(path, fix_lua_indexing=True) ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
tokens = ds.buffer tokens = ds.buffer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment