Unverified Commit 2882ef12 authored by HanZhipeng's avatar HanZhipeng Committed by GitHub
Browse files

bug fixes

parent 8a3e232c
......@@ -36,7 +36,7 @@ def collate(batch, samples_per_gpu=1):
assert isinstance(batch[i].data, torch.Tensor)
# TODO: handle tensors other than 3d
assert batch[i].dim() == 3
c, h, w = batch[0].size()
c, h, w = batch[i].size()
for sample in batch[i:i + samples_per_gpu]:
assert c == sample.size(0)
h = max(h, sample.size(1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment