# Copyright (c) OpenMMLab. All rights reserved.
from torch.utils.data import DataLoader, Dataset

from mmaction.datasets.samplers import (ClassSpecificDistributedSampler,
                                        DistributedSampler)


class MyDataset(Dataset):
    """Toy dataset of 100 samples spread evenly over 10 class labels.

    Sample ``idx`` is ``dict(data=idx, label=idx % 10)``, so every label in
    ``0..9`` occurs exactly 10 times.

    Args:
        class_prob (dict | None): Mapping from class label to a per-class
            weight, stored unchanged on ``self.class_prob``. When ``None``
            (the default) a fresh ``{0: 1, ..., 9: 1}`` mapping is created
            for this instance.
            NOTE(review): presumably read by
            ``ClassSpecificDistributedSampler`` -- confirm against the
            sampler implementation.
    """

    def __init__(self, class_prob=None):
        super().__init__()
        # Build the default per instance: a dict literal in the signature
        # would be a single object shared (and mutable) across instances.
        if class_prob is None:
            class_prob = {i: 1 for i in range(10)}
        self.class_prob = class_prob
        self.video_infos = [
            dict(data=idx, label=idx % 10) for idx in range(100)
        ]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.video_infos)

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.video_infos[idx]


def _collect_batches(dataset, sampler, batch_size=4):
    """Return every batch a ``DataLoader`` yields for ``sampler``."""
    data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return list(data_loader)


def _num_samples(batches):
    """Return the total number of samples contained in ``batches``."""
    return sum(len(batch['data']) for batch in batches)


def test_distributed_sampler():
    """Check batch and sample counts of ``DistributedSampler`` per rank."""
    dataset = MyDataset()

    # Single replica: all 100 samples, in 25 batches of 4.
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0)
    batches = _collect_batches(dataset, sampler)
    assert len(batches) == 25
    assert _num_samples(batches) == 100

    # 4 replicas: this rank is expected to receive 25 samples (7 batches,
    # the last one partial).
    sampler = DistributedSampler(dataset, num_replicas=4, rank=2)
    batches = _collect_batches(dataset, sampler)
    assert len(batches) == 7
    assert _num_samples(batches) == 25

    # 6 replicas: 100 is not divisible by 6; this rank is expected to
    # receive 17 samples (5 batches).
    sampler = DistributedSampler(dataset, num_replicas=6, rank=3)
    batches = _collect_batches(dataset, sampler)
    assert len(batches) == 5
    assert _num_samples(batches) == 17


def test_class_specific_distributed_sampler():
    """Check counts of ``ClassSpecificDistributedSampler`` per setting."""
    # Classes 0-4 get weight 1, classes 5-9 get weight 3.
    class_prob = dict(zip(range(10), [1] * 5 + [3] * 5))
    dataset = MyDataset(class_prob=class_prob)

    # dynamic_length=True, single replica: 200 samples are expected
    # (NOTE(review): consistent with 10 samples per class scaled by the
    # class weight, 50 * 1 + 50 * 3 -- confirm against the sampler).
    sampler = ClassSpecificDistributedSampler(
        dataset, num_replicas=1, rank=0, dynamic_length=True)
    batches = _collect_batches(dataset, sampler)
    assert len(batches) == 50
    assert _num_samples(batches) == 200

    # dynamic_length=False, single replica: the total stays at the dataset
    # length of 100.
    sampler = ClassSpecificDistributedSampler(
        dataset, num_replicas=1, rank=0, dynamic_length=False)
    batches = _collect_batches(dataset, sampler)
    assert len(batches) == 25
    assert _num_samples(batches) == 100

    # dynamic_length=True, 6 replicas: this rank is expected to receive 34
    # samples (9 batches).
    sampler = ClassSpecificDistributedSampler(
        dataset, num_replicas=6, rank=2, dynamic_length=True)
    batches = _collect_batches(dataset, sampler)
    assert len(batches) == 9
    assert _num_samples(batches) == 34

    # dynamic_length=False, 6 replicas: this rank is expected to receive 17
    # samples (5 batches).
    sampler = ClassSpecificDistributedSampler(
        dataset, num_replicas=6, rank=2, dynamic_length=False)
    batches = _collect_batches(dataset, sampler)
    assert len(batches) == 5
    assert _num_samples(batches) == 17