Commit 355e9d2f authored by Zhicheng Yan, committed by Francisco Massa

extend DistributedSampler to support group_size (#1512)

* extend DistributedSampler to support group_size

* Fix lint
parent b60cb726
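
As a quick orientation (not part of the commit), here is a minimal usage sketch of the new group_size argument: it keeps blocks of group_size consecutive indices on the same replica, which is what lets all clips of one video land on one rank. The file names and clip parameters below are assumed and simply mirror the test added in this commit.

    from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler
    from torchvision.datasets.video_utils import VideoClips

    video_paths = ["a.mp4", "b.mp4", "c.mp4"]          # hypothetical video files on disk
    video_clips = VideoClips(video_paths, 5, 5)        # 5-frame clips, 5 frames between clip starts
    clip_sampler = UniformClipSampler(video_clips, 3)  # 3 uniformly spaced clips per video

    # group_size=3 keeps each video's 3 clips together on a single replica
    sampler = DistributedSampler(clip_sampler, num_replicas=2, rank=0, group_size=3)
    rank0_clip_indices = list(iter(sampler))           # see the test below for concrete expected values
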
@@ -5,7 +5,11 @@ import torch
 import unittest
 from torchvision import io
-from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler
+from torchvision.datasets.samplers import (
+    DistributedSampler,
+    RandomClipSampler,
+    UniformClipSampler,
+)
 from torchvision.datasets.video_utils import VideoClips, unfold
 from torchvision import get_video_backend
@@ -83,6 +87,31 @@ class Tester(unittest.TestCase):
         indices = torch.tensor(list(iter(sampler)))
         self.assertTrue(indices.equal(torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])))

+    def test_distributed_sampler_and_uniform_clip_sampler(self):
+        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
+            video_clips = VideoClips(video_list, 5, 5)
+            clip_sampler = UniformClipSampler(video_clips, 3)
+
+            distributed_sampler_rank0 = DistributedSampler(
+                clip_sampler,
+                num_replicas=2,
+                rank=0,
+                group_size=3,
+            )
+            indices = torch.tensor(list(iter(distributed_sampler_rank0)))
+            self.assertEqual(len(distributed_sampler_rank0), 6)
+            self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 10, 12, 14])))
+
+            distributed_sampler_rank1 = DistributedSampler(
+                clip_sampler,
+                num_replicas=2,
+                rank=1,
+                group_size=3,
+            )
+            indices = torch.tensor(list(iter(distributed_sampler_rank1)))
+            self.assertEqual(len(distributed_sampler_rank1), 6)
+            self.assertTrue(indices.equal(torch.tensor([5, 7, 9, 0, 2, 4])))
+

 if __name__ == '__main__':
     unittest.main()
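
The expected tensors in the new test can be derived by hand (a sketch, not part of the commit; it assumes the clip layout produced by the fixtures above): each 25-frame video split into 5-frame clips at step 5 yields 5 clips, UniformClipSampler keeps 3 uniformly spaced clips per video, and the resulting 9 clip indices form 3 groups of group_size=3 that are padded to 4 groups and dealt alternately to the two ranks.

    clip_indices = [0, 2, 4, 5, 7, 9, 10, 12, 14]           # UniformClipSampler output: 3 clips per video
    groups = [clip_indices[i:i + 3] for i in range(0, 9, 3)]  # one group per video
    groups += groups[:1]                                      # pad 3 groups up to 4 so 2 ranks split evenly
    rank0 = sum(groups[0::2], [])                             # [0, 2, 4, 10, 12, 14]
    rank1 = sum(groups[1::2], [])                             # [5, 7, 9, 0, 2, 4]
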
@@ -9,9 +9,32 @@ class DistributedSampler(Sampler):
     """
     Extension of DistributedSampler, as discussed in
     https://github.com/pytorch/pytorch/issues/23430
+
+    Example:
+        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+        num_replicas: 4
+        shuffle: False
+
+        when group_size = 1
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 4, 8, 12]
+            rank_1  |  [1, 5, 9, 13]
+            rank_2  |  [2, 6, 10, 0]
+            rank_3  |  [3, 7, 11, 1]
+
+        when group_size = 2
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 1, 8, 9]
+            rank_1  |  [2, 3, 10, 11]
+            rank_2  |  [4, 5, 12, 13]
+            rank_3  |  [6, 7, 0, 1]
     """

-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_size=1):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -20,11 +43,20 @@ class DistributedSampler(Sampler):
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
             rank = dist.get_rank()

+        assert len(dataset) % group_size == 0, (
+            "dataset length must be a multiple of group size; "
+            "dataset length: %d, group size: %d" % (len(dataset), group_size)
+        )
         self.dataset = dataset
+        self.group_size = group_size
         self.num_replicas = num_replicas
         self.rank = rank
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        dataset_group_length = len(dataset) // group_size
+        self.num_group_samples = int(
+            math.ceil(dataset_group_length * 1.0 / self.num_replicas)
+        )
+        self.num_samples = self.num_group_samples * group_size
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
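
Worked numbers for this hunk, using the docstring example above (14 items, 4 replicas, group_size = 2); an illustrative sketch, not part of the diff:

    import math

    dataset_len, group_size, num_replicas = 14, 2, 4
    dataset_group_length = dataset_len // group_size                               # 7 groups of 2
    num_group_samples = int(math.ceil(dataset_group_length * 1.0 / num_replicas))  # ceil(7 / 4) = 2
    num_samples = num_group_samples * group_size                                   # 4 items per rank
    total_size = num_samples * num_replicas                                        # 16 > 14, so 2 indices wrap around
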
@@ -41,8 +73,14 @@ class DistributedSampler(Sampler):
         indices += indices[:(self.total_size - len(indices))]
         assert len(indices) == self.total_size

+        total_group_size = self.total_size // self.group_size
+        indices = torch.reshape(
+            torch.LongTensor(indices), (total_group_size, self.group_size)
+        )
+
         # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
+        indices = indices[self.rank:total_group_size:self.num_replicas, :]
+        indices = torch.reshape(indices, (-1,)).tolist()
         assert len(indices) == self.num_samples

         if isinstance(self.dataset, Sampler):
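
To see how the reshape-and-stride subsampling in __iter__ produces the docstring tables, here is a standalone sketch (not from the commit) that reproduces the group_size = 2 case with 14 items and 4 replicas:

    import torch

    indices = list(range(14))
    group_size, num_replicas, total_size = 2, 4, 16
    indices += indices[:total_size - len(indices)]   # pad to 16 by wrapping: ..., 12, 13, 0, 1
    total_group_size = total_size // group_size      # 8 groups of 2
    grouped = torch.reshape(torch.LongTensor(indices), (total_group_size, group_size))

    for rank in range(num_replicas):
        shard = torch.reshape(grouped[rank:total_group_size:num_replicas, :], (-1,)).tolist()
        print(rank, shard)
    # 0 [0, 1, 8, 9]
    # 1 [2, 3, 10, 11]
    # 2 [4, 5, 12, 13]
    # 3 [6, 7, 0, 1]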