Unverified Commit cb2bc69b authored by Guo Qipeng's avatar Guo Qipeng Committed by GitHub
Browse files

Fix a bug in the data loading of GraphWriter (#2343)

In the bucket sampler, the length array should be recovered after random picking samples.
parent 501b2b75
......@@ -227,6 +227,8 @@ class BucketSampler(torch.utils.data.Sampler):
random.shuffle(datas)
idxs = sum(datas, [])
batch = []
lens = torch.Tensor([len(x) for x in self.data_source])
for idx in idxs:
batch.append(idx)
mlen = max([0]+[lens[x] for x in batch])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment