import torch


class Prefetcher:
    """Wrap a data loader and copy each sample's tensors to the GPU ahead of use.

    Modified from the data_prefetcher in
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py

    While the caller works on the current sample, the next sample's tensors are
    copied to the current CUDA device on a dedicated side stream. The copy is
    issued with ``non_blocking=True``. It only truly overlaps with compute when
    the host memory is pinned (e.g. ``pin_memory=True`` on a ``DataLoader``).

    Usage::

        for sample in Prefetcher(loader):
            ...

    Each sample yielded by ``loader`` is assumed to be a dict. Values that are
    ``torch.Tensor`` are replaced, in place in that dict, by their ``.cuda()``
    copies. All other values are passed through untouched.
    """

    def __init__(self, loader):
        """Store ``loader`` and create the side stream used for copies.

        Args:
            loader: An iterable of dict samples (e.g. a ``DataLoader``). It is
                re-iterated from the start every time ``iter()`` is called on
                this prefetcher.
        """
        self.orig_loader = loader
        # Active iterator over ``orig_loader``; (re)created by ``__iter__``.
        self.loader = None
        self.stream = torch.cuda.Stream()
        self.next_sample = None

    def __len__(self):
        """Return the number of samples per pass, delegating to the wrapped loader.

        Raises:
            TypeError: If the wrapped loader does not define ``__len__``.
        """
        return len(self.orig_loader)

    def preload(self):
        """Fetch the next sample and enqueue copies of its tensors to the GPU.

        Sets ``self.next_sample`` to ``None`` once the iterator is exhausted.
        The copies are only enqueued on ``self.stream``. ``__next__`` makes the
        consumer stream wait for them before handing the sample out.
        """
        try:
            self.next_sample = next(self.loader)
        except StopIteration:
            self.next_sample = None
            return
        with torch.cuda.stream(self.stream):
            for key, value in self.next_sample.items():
                if isinstance(value, torch.Tensor):
                    # Reassigning existing keys does not resize the dict, so
                    # updating it while iterating over ``items()`` is safe.
                    self.next_sample[key] = value.cuda(non_blocking=True)

    def __next__(self):
        """Return the prefetched sample and start prefetching the one after it.

        Raises:
            StopIteration: When no prefetched sample is available. This happens
                when the loader is exhausted, or when ``iter()`` was never
                called on this prefetcher.
        """
        # Make the consumer's stream wait until the side-stream copies finish.
        torch.cuda.current_stream().wait_stream(self.stream)
        sample = self.next_sample
        if sample is not None:
            for key, value in sample.items():
                if isinstance(value, torch.Tensor):
                    # The tensor was allocated on ``self.stream`` but will be
                    # used on the current stream. Recording it stops the caching
                    # allocator from reusing its memory too early.
                    sample[key].record_stream(torch.cuda.current_stream())
            self.preload()
        else:
            # throw stop exception if there is no more data to perform as a default dataloader
            raise StopIteration("No samples in loader. example: `iterator = iter(Prefetcher(loader)); "
                                "data = next(iterator)`")
        return sample

    def __iter__(self):
        """Restart iteration over the wrapped loader and prefetch the first sample."""
        self.loader = iter(self.orig_loader)
        self.preload()
        return self