# sampler.py
import math

import numpy as np


class DistributedSampler:
    """Split a dataset into disjoint, equal-sized shards and sample from one of them.

    The dataset is truncated to the largest multiple of ``num_replicas`` so that
    every replica owns exactly the same number of items. The trailing
    ``len(dataset) % num_replicas`` items are never sampled by any replica.
    Replica ``rank`` owns the dataset indices
    ``rank, rank + num_replicas, rank + 2 * num_replicas, ...`` below
    ``total_size``.
    """

    def __init__(self, dataset, num_replicas: int, rank: int) -> None:
        """Compute the shard of ``dataset`` owned by replica ``rank``.

        Args:
            dataset: Any object supporting ``len()`` and integer ``__getitem__``.
            num_replicas: Total number of replicas sharing the dataset (>= 1).
            rank: Index of this replica, in ``[0, num_replicas)``.

        Raises:
            ValueError: If ``num_replicas`` < 1 or ``rank`` is out of range.
        """
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if not 0 <= rank < num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank

        # Number of items per replica after dropping the uneven tail. Exact
        # integer division; equivalent to ceil((n - r) / r) when n % r != 0
        # and to n / r when the dataset divides evenly.
        self.num_samples = len(self.dataset) // self.num_replicas
        # Number of dataset items actually covered across all replicas
        # (always <= len(dataset)).
        self.total_size = self.num_samples * self.num_replicas

        # Strided shard: every num_replicas-th index starting at rank. Its
        # length is exactly num_samples because 0 <= rank < num_replicas.
        self.indices = list(range(self.rank, self.total_size, self.num_replicas))

    def sample(self, batch_size: int, rng: "np.random.Generator | None" = None) -> list:
        """Return ``batch_size`` distinct items drawn at random from this shard.

        Args:
            batch_size: Number of items to draw, without replacement.
            rng: Optional numpy ``Generator`` for reproducible sampling. When
                omitted, the global ``np.random`` state is used, so
                ``np.random.seed`` still controls the result.

        Returns:
            A list of ``dataset[i]`` for the randomly chosen shard indices.

        Raises:
            ValueError: If ``batch_size`` exceeds the shard size.
        """
        if batch_size > len(self.indices):
            raise ValueError(
                f"batch_size ({batch_size}) exceeds shard size ({len(self.indices)})"
            )
        choose = np.random.choice if rng is None else rng.choice
        picked = choose(self.indices, batch_size, replace=False)
        # Convert numpy integers to plain ints for datasets that require them.
        return [self.dataset[int(idx)] for idx in picked]