Commit eddcdf08, authored by Myle Ott, committed by Facebook Github Bot
Browse files

Fix indexing in TokenBlockDataset

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/719

Differential Revision: D15258483

Pulled By: myleott

fbshipit-source-id: dd00daa6f1c87264c1196a77dfffc8c876ebde7f
parent 0cb45bcb
@@ -70,7 +70,7 @@ class TokenBlockDataset(FairseqDataset):
             if not torch.is_tensor(sizes):
                 sizes = torch.tensor(sizes)
             cumsum = torch.cumsum(sizes, dim=0)
-            self.slice_indices[0, 1] = sizes[0]
+            self.slice_indices[0] = [0, sizes[0]]
             self.slice_indices[1:] = cumsum.unfold(0, 2, 1)
         else:
             raise ValueError('Invalid break_mode: ' + break_mode)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment