Unverified Commit 900c88c7 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Bugfix on GroupedBatchSampler for corner case where there are not enough...

Bugfix on GroupedBatchSampler for corner case where there are not enough examples in a category to form a batch (#1677)
parent 1d229b77
import bisect import bisect
from collections import defaultdict from collections import defaultdict
import copy import copy
from itertools import repeat, chain
import math
import numpy as np import numpy as np
import torch import torch
...@@ -12,6 +14,12 @@ import torchvision ...@@ -12,6 +14,12 @@ import torchvision
from PIL import Image from PIL import Image
def _repeat_to_at_least(iterable, n):
repeat_times = math.ceil(n / len(iterable))
repeated = chain.from_iterable(repeat(iterable, repeat_times))
return list(repeated)
class GroupedBatchSampler(BatchSampler): class GroupedBatchSampler(BatchSampler):
""" """
Wraps another sampler to yield a mini-batch of indices. Wraps another sampler to yield a mini-batch of indices.
...@@ -63,8 +71,8 @@ class GroupedBatchSampler(BatchSampler): ...@@ -63,8 +71,8 @@ class GroupedBatchSampler(BatchSampler):
for group_id, _ in sorted(buffer_per_group.items(), for group_id, _ in sorted(buffer_per_group.items(),
key=lambda x: len(x[1]), reverse=True): key=lambda x: len(x[1]), reverse=True):
remaining = self.batch_size - len(buffer_per_group[group_id]) remaining = self.batch_size - len(buffer_per_group[group_id])
buffer_per_group[group_id].extend( samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
samples_per_group[group_id][:remaining]) buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
assert len(buffer_per_group[group_id]) == self.batch_size assert len(buffer_per_group[group_id]) == self.batch_size
yield buffer_per_group[group_id] yield buffer_per_group[group_id]
num_remaining -= 1 num_remaining -= 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment