"git@developer.sourcefind.cn:change/sglang.git" did not exist on "21514ff5bdad2ae598bf3070d3a583aa4fa35ae7"
Commit a7d0bd0e authored by alexeib, committed by Myle Ott
Browse files

fix token block rotation

parent 0e9e7f7b
...@@ -76,13 +76,17 @@ class TokenBlockDataset(torch.utils.data.Dataset): ...@@ -76,13 +76,17 @@ class TokenBlockDataset(torch.utils.data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
s, e = self.slice_indices[index] s, e = self.slice_indices[index]
item = torch.LongTensor(self.tokens[s:e]) item = torch.LongTensor(self.tokens[s:e])
if self.include_targets: if self.include_targets:
if e == self.total_size: # target is the sentence, for source, rotate item one token to the left (would start with eos)
return item[:-1], item[1:] if s == 0:
else: source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
return item, torch.LongTensor(self.tokens[s + 1:e + 1])
else: else:
source = self.tokens[s - 1:e - 1]
return torch.LongTensor(source), item
return item return item
def __len__(self): def __len__(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment