Commit a7d0bd0e authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

fix token block rotation

parent 0e9e7f7b
......@@ -76,14 +76,18 @@ class TokenBlockDataset(torch.utils.data.Dataset):
def __getitem__(self, index):
s, e = self.slice_indices[index]
item = torch.LongTensor(self.tokens[s:e])
if self.include_targets:
if e == self.total_size:
return item[:-1], item[1:]
# target is the sentence, for source, rotate item one token to the left (would start with eos)
if s == 0:
source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
else:
return item, torch.LongTensor(self.tokens[s + 1:e + 1])
else:
return item
source = self.tokens[s - 1:e - 1]
return torch.LongTensor(source), item
return item
def __len__(self):
return len(self.slice_indices)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment