"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "01bb50cef7f2b321e2258f97278763eae91d6b6e"
Unverified Commit 4cacf5a1 authored by Hu Ye's avatar Hu Ye Committed by GitHub
Browse files

support random seed for RA sampler (#5053)



* support random seed used to shuffle the sampler

* fix bug for RA samper
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 849d02bc
...@@ -15,7 +15,7 @@ class RASampler(torch.utils.data.Sampler): ...@@ -15,7 +15,7 @@ class RASampler(torch.utils.data.Sampler):
https://github.com/facebookresearch/deit/blob/main/samplers.py https://github.com/facebookresearch/deit/blob/main/samplers.py
""" """
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
if num_replicas is None: if num_replicas is None:
if not dist.is_available(): if not dist.is_available():
raise RuntimeError("Requires distributed package to be available!") raise RuntimeError("Requires distributed package to be available!")
...@@ -32,11 +32,12 @@ class RASampler(torch.utils.data.Sampler): ...@@ -32,11 +32,12 @@ class RASampler(torch.utils.data.Sampler):
self.total_size = self.num_samples * self.num_replicas self.total_size = self.num_samples * self.num_replicas
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
self.shuffle = shuffle self.shuffle = shuffle
self.seed = seed
def __iter__(self): def __iter__(self):
# Deterministically shuffle based on epoch # Deterministically shuffle based on epoch
g = torch.Generator() g = torch.Generator()
g.manual_seed(self.epoch) g.manual_seed(self.seed + self.epoch)
if self.shuffle: if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist() indices = torch.randperm(len(self.dataset), generator=g).tolist()
else: else:
......
...@@ -9,7 +9,7 @@ import torch.utils.data ...@@ -9,7 +9,7 @@ import torch.utils.data
import torchvision import torchvision
import transforms import transforms
import utils import utils
from references.classification.sampler import RASampler from sampler import RASampler
from torch import nn from torch import nn
from torch.utils.data.dataloader import default_collate from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment