Unverified commit 900c88c7, authored by Francisco Massa, committed by GitHub
Browse files

Bugfix on GroupedBatchSampler for corner case where there are not enough...

Bugfix on GroupedBatchSampler for corner case where there are not enough examples in a category to form a batch (#1677)
parent 1d229b77
import bisect
from collections import defaultdict
import copy
from itertools import repeat, chain
import math
import numpy as np
import torch
......@@ -12,6 +14,12 @@ import torchvision
from PIL import Image
def _repeat_to_at_least(iterable, n):
repeat_times = math.ceil(n / len(iterable))
repeated = chain.from_iterable(repeat(iterable, repeat_times))
return list(repeated)
class GroupedBatchSampler(BatchSampler):
"""
Wraps another sampler to yield a mini-batch of indices.
......@@ -63,8 +71,8 @@ class GroupedBatchSampler(BatchSampler):
for group_id, _ in sorted(buffer_per_group.items(),
key=lambda x: len(x[1]), reverse=True):
remaining = self.batch_size - len(buffer_per_group[group_id])
buffer_per_group[group_id].extend(
samples_per_group[group_id][:remaining])
samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
assert len(buffer_per_group[group_id]) == self.batch_size
yield buffer_per_group[group_id]
num_remaining -= 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment