"...resnet50_tensorflow.git" did not exist on "c5763f442973d4852d34d9999f1cf9cf11499dea"
Unverified Commit e238e3d5 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[seq2seq] Don't copy self.source in sortishsampler (#5818)

parent 2e4624b4
......@@ -144,16 +144,9 @@ class SummarizationDataset(Dataset):
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
return batch
@property
def src_lens(self): # Can delete?
return lmap(len, self.source)
@property
def tgt_lens(self):
return lmap(len, self.target)
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.source, batch_size)
lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source]
return SortishSampler(lens, batch_size)
class SortishSampler(Sampler):
......@@ -163,7 +156,7 @@ class SortishSampler(Sampler):
self.data, self.bs = data, batch_size
def key(self, i):
return len(self.data[i])
return self.data[i]
def __len__(self) -> int:
return len(self.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment