Commit 76b9024b authored by yangzhong's avatar yangzhong
Browse files

git init

parents
Pipeline #3145 failed with stages
in 0 seconds
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import flow_warp
class BasicModule(nn.Module):
    """Convolutional stack applied at each pyramid level of SpyNet.

    Five 7x7 convolutions (stride 1, padding 3) map an 8-channel input to a
    2-channel output without changing the spatial size. Every convolution
    except the last one is followed by a ReLU.
    """

    def __init__(self):
        super(BasicModule, self).__init__()
        # Channel widths at the boundaries of the five convolutions.
        widths = (8, 32, 64, 32, 16, 2)
        stack = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stack.append(nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=7, stride=1, padding=3))
            stack.append(nn.ReLU(inplace=False))
        # The last convolution has no activation, so drop the trailing ReLU.
        stack.pop()
        self.basic_module = nn.Sequential(*stack)

    def forward(self, tensor_input):
        """Apply the convolutional stack to ``tensor_input`` and return the result."""
        return self.basic_module(tensor_input)
@ARCH_REGISTRY.register()
class SpyNet(nn.Module):
    """SpyNet architecture.

    Flow is estimated coarse-to-fine over a six-level image pyramid, with one
    :class:`BasicModule` per level refining the upsampled flow of the previous
    level.

    Args:
        load_path (str): Path of a pretrained SpyNet checkpoint; the weights
            are read from its ``'params'`` entry. Default: None.
    """

    def __init__(self, load_path=None):
        super(SpyNet, self).__init__()
        self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
        if load_path:
            checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
            self.load_state_dict(checkpoint['params'])

        # NOTE: the normalization buffers are registered only after the
        # checkpoint has been loaded, so they are not part of the module's
        # state dict at load time.
        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def preprocess(self, tensor_input):
        """Normalize ``tensor_input`` with the per-channel ``mean`` and ``std`` buffers."""
        return (tensor_input - self.mean) / self.std

    def process(self, ref, supp):
        """Estimate the flow between ``ref`` and ``supp`` from coarse to fine.

        Args:
            ref (Tensor): Reference images, 3 channels in dim 1.
            supp (Tensor): Supporting images of the same size as ``ref``.

        Returns:
            Tensor: 2-channel flow with the spatial size of the inputs.
        """
        # Pyramids are ordered coarsest first; five 2x poolings give six levels.
        ref_pyramid = [self.preprocess(ref)]
        supp_pyramid = [self.preprocess(supp)]
        for _ in range(5):
            ref_pyramid.insert(
                0, F.avg_pool2d(input=ref_pyramid[0], kernel_size=2, stride=2, count_include_pad=False))
            supp_pyramid.insert(
                0, F.avg_pool2d(input=supp_pyramid[0], kernel_size=2, stride=2, count_include_pad=False))

        # Start from a zero flow at half the size of the coarsest level.
        coarsest = ref_pyramid[0]
        flow = coarsest.new_zeros([coarsest.size(0), 2, coarsest.size(2) // 2, coarsest.size(3) // 2])

        for level, (ref_level, supp_level) in enumerate(zip(ref_pyramid, supp_pyramid)):
            # Doubling the resolution also doubles the displacement magnitudes.
            upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0

            # Pad by one row / column when the upsampled size does not match this level.
            if upsampled_flow.size(2) != ref_level.size(2):
                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
            if upsampled_flow.size(3) != ref_level.size(3):
                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')

            warped_supp = flow_warp(
                supp_level, upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border')
            level_input = torch.cat([ref_level, warped_supp, upsampled_flow], 1)
            # The basic module predicts a residual on top of the upsampled flow.
            flow = self.basic_module[level](level_input) + upsampled_flow

        return flow

    def forward(self, ref, supp):
        """Estimate the flow between two image batches of identical size.

        Args:
            ref (Tensor): Reference images, indexed as (b, c, h, w).
            supp (Tensor): Supporting images with the same size as ``ref``.

        Returns:
            Tensor: Flow resized to the input (h, w).
        """
        assert ref.size() == supp.size()

        height, width = ref.size(2), ref.size(3)
        # Round both sides up to the next multiple of 32 (= 2 ** 5, matching
        # the five 2x poolings performed in ``process``).
        padded_w = math.floor(math.ceil(width / 32.0) * 32.0)
        padded_h = math.floor(math.ceil(height / 32.0) * 32.0)
        work_size = (padded_h, padded_w)

        ref = F.interpolate(input=ref, size=work_size, mode='bilinear', align_corners=False)
        supp = F.interpolate(input=supp, size=work_size, mode='bilinear', align_corners=False)

        flow = F.interpolate(
            input=self.process(ref, supp), size=(height, width), mode='bilinear', align_corners=False)

        # Rescale the displacements from the working resolution to the input resolution.
        flow[:, 0, :, :] *= float(width) / float(padded_w)
        flow[:, 1, :, :] *= float(height) / float(padded_h)

        return flow
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
@ARCH_REGISTRY.register()
class MSRResNet(nn.Module):
    """Modified SRResNet.

    A compacted version modified from SRResNet in
    "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
    It uses residual blocks without BN, similar to EDSR.
    Currently, it supports x2, x3 and x4 upsampling scale factor.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_block (int): Block number in the body network. Default: 16.
        upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.

    Raises:
        ValueError: If ``upscale`` is not 2, 3 or 4.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
        super(MSRResNet, self).__init__()
        # Fail early with a clear message: without this check an unsupported
        # factor skips both upsampling branches below and the constructor
        # later fails on the missing ``self.upconv1`` attribute.
        if upscale not in (2, 3, 4):
            raise ValueError(f'Unsupported upscale: {upscale}. Supported ones are 2, 3 and 4.')
        self.upscale = upscale

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)

        # upsampling
        if self.upscale in [2, 3]:
            # a single pixel-shuffle step by the full factor
            self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
            self.pixel_shuffle = nn.PixelShuffle(self.upscale)
        elif self.upscale == 4:
            # two consecutive x2 pixel-shuffle steps
            self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
            self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialization
        default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
        if self.upscale == 4:
            default_init_weights(self.upconv2, 0.1)

    def forward(self, x):
        """Super-resolve ``x``.

        Args:
            x (Tensor): Input images, indexed as (b, num_in_ch, h, w).

        Returns:
            Tensor: Network output added to the bilinearly upsampled input.
        """
        feat = self.lrelu(self.conv_first(x))
        out = self.body(feat)

        if self.upscale == 4:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        elif self.upscale in [2, 3]:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))

        out = self.conv_last(self.lrelu(self.conv_hr(out)))
        # The convolutional branch is added on top of a bilinear upsampling of the input.
        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        out += base
        return out
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register(suffix='basicsr')
class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network structure, which performs upsampling in the last layer and no convolution is
    conducted on the HR feature space.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.

    Raises:
        ValueError: If ``act_type`` is not one of the supported options.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        # Reject unknown activations explicitly; previously they surfaced as
        # an unbound local variable while the body was being assembled.
        if act_type not in ('relu', 'prelu', 'leakyrelu'):
            raise ValueError(
                f"Unsupported act_type: {act_type}. Supported ones are 'relu', 'prelu' and 'leakyrelu'.")

        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv and its activation
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        self.body.append(self._make_activation())

        # the body structure: (conv, activation) pairs
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            self.body.append(self._make_activation())

        # the last conv produces upscale**2 sub-pixel channels per output channel
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def _make_activation(self):
        """Create a fresh activation layer of the configured ``act_type``."""
        if self.act_type == 'relu':
            return nn.ReLU(inplace=True)
        if self.act_type == 'prelu':
            return nn.PReLU(num_parameters=self.num_feat)
        # ``act_type`` was validated in ``__init__``, so only 'leakyrelu' remains.
        return nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Super-resolve ``x``.

        Args:
            x (Tensor): Input images, indexed as (b, num_in_ch, h, w).

        Returns:
            Tensor: Pixel-shuffled network output plus the nearest-neighbour upsampled input.
        """
        out = x
        for layer in self.body:
            out = layer(out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import flow_warp
class BasicModule(nn.Module):
    """Basic module of SPyNet, with batch normalization.

    Unlike the basic module in spynet_arch.py, each of the first four
    convolutions is bias-free and followed by BatchNorm and ReLU. The final
    convolution keeps its bias and has no normalization or activation.
    """

    def __init__(self):
        super(BasicModule, self).__init__()
        layers = []
        # (in_channels, out_channels) of the four conv-bn-relu groups
        for in_ch, out_ch in ((8, 32), (32, 64), (64, 32), (32, 16)):
            layers += [
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=7, stride=1, padding=3, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # output head: two channels
        layers.append(nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
        self.basic_module = nn.Sequential(*layers)

    def forward(self, tensor_input):
        """Estimate the flow for one pyramid level.

        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].

        Returns:
            Tensor: Estimated flow with shape (b, 2, h, w)
        """
        return self.basic_module(tensor_input)
class SPyNetTOF(nn.Module):
    """SPyNet architecture for TOF.

    Note that this implementation is specifically for TOFlow. Please use :file:`spynet_arch.py` for general use.
    They differ in the following aspects:

    1. The basic modules here contain BatchNorm.
    2. Normalization and denormalization are not done here, as they are done in TOFlow.

    ``Paper: Optical Flow Estimation using a Spatial Pyramid Network``

    Reference: https://github.com/Coldog2333/pytoflow

    Args:
        load_path (str): Path for pretrained SPyNet. Default: None.
    """

    def __init__(self, load_path=None):
        super(SPyNetTOF, self).__init__()

        # one basic module per pyramid level
        self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])
        if load_path:
            checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
            self.load_state_dict(checkpoint['params'])

    def forward(self, ref, supp):
        """Estimate the flow between ``ref`` and ``supp``.

        Args:
            ref (Tensor): Reference image with shape of (b, 3, h, w).
            supp: The supporting image to be warped: (b, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (b, 2, h, w).
        """
        num_batches, _, h, w = ref.size()

        # Build four-level pyramids, coarsest level first.
        ref_pyramid = [ref]
        supp_pyramid = [supp]
        for _ in range(3):
            ref_pyramid.insert(
                0, F.avg_pool2d(input=ref_pyramid[0], kernel_size=2, stride=2, count_include_pad=False))
            supp_pyramid.insert(
                0, F.avg_pool2d(input=supp_pyramid[0], kernel_size=2, stride=2, count_include_pad=False))

        # Zero flow at 1/16 resolution; the first upsampling brings it to the
        # 1/8 resolution of the coarsest pyramid level.
        flow = ref_pyramid[0].new_zeros(num_batches, 2, h // 16, w // 16)
        for module, ref_level, supp_level in zip(self.basic_module, ref_pyramid, supp_pyramid):
            # Doubling the resolution also doubles the displacement magnitudes.
            flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
            warped_supp = flow_warp(supp_level, flow_up.permute(0, 2, 3, 1))
            # The module predicts a residual that refines the upsampled flow.
            flow = flow_up + module(torch.cat([ref_level, warped_supp, flow_up], 1))
        return flow
@ARCH_REGISTRY.register()
class TOFlow(nn.Module):
    """PyTorch implementation of TOFlow.

    In TOFlow, the LR frames are pre-upsampled and have the same size with the GT frames.

    ``Paper: Video Enhancement with Task-Oriented Flow``

    Reference: https://github.com/anchen1011/toflow

    Reference: https://github.com/Coldog2333/pytoflow

    Args:
        adapt_official_weights (bool): Whether to adapt the weights translated
            from the official implementation. Set to false if you want to
            train from scratch. Default: False
    """

    def __init__(self, adapt_official_weights=False):
        super(TOFlow, self).__init__()
        self.adapt_official_weights = adapt_official_weights
        # With official weights the frames are reordered in ``forward`` so the
        # reference frame comes first; otherwise it is the middle (index 3) frame.
        if adapt_official_weights:
            self.ref_idx = 0
        else:
            self.ref_idx = 3

        # per-channel statistics used by ``normalize`` / ``denormalize``
        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # flow estimation module
        self.spynet = SPyNetTOF()

        # reconstruction module (takes the 7 aligned RGB frames stacked on the channel axis)
        self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
        self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4)
        self.conv_3 = nn.Conv2d(64, 64, 1)
        self.conv_4 = nn.Conv2d(64, 3, 1)

        # activation function
        self.relu = nn.ReLU(inplace=True)

    def normalize(self, img):
        """Subtract ``mean`` and divide by ``std`` per channel."""
        return (img - self.mean) / self.std

    def denormalize(self, img):
        """Invert :meth:`normalize`."""
        return img * self.std + self.mean

    def forward(self, lrs):
        """Reconstruct the reference frame from seven input frames.

        Args:
            lrs: Input lr frames: (b, 7, 3, h, w).

        Returns:
            Tensor: SR frame: (b, 3, h, w).
        """
        # In the official implementation, the 0-th frame is the reference frame
        if self.adapt_official_weights:
            lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]

        num_batches, num_lrs, _, h, w = lrs.size()

        # Normalize all frames at once, then restore the frame axis.
        normalized = self.normalize(lrs.view(-1, 3, h, w))
        lrs = normalized.view(num_batches, num_lrs, 3, h, w)

        lr_ref = lrs[:, self.ref_idx, :, :, :]
        aligned_frames = []
        for idx in range(7):  # 7 frames
            if idx == self.ref_idx:
                # the reference frame needs no alignment
                aligned_frames.append(lr_ref)
                continue
            lr_supp = lrs[:, idx, :, :, :]
            flow = self.spynet(lr_ref, lr_supp)
            aligned_frames.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))

        # reconstruction
        hr = torch.stack(aligned_frames, dim=1).view(num_batches, -1, h, w)
        for conv in (self.conv_1, self.conv_2, self.conv_3):
            hr = self.relu(conv(hr))
        # The last conv predicts a residual on top of the reference frame.
        hr = self.conv_4(hr) + lr_ref

        return self.denormalize(hr)
import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg
from basicsr.utils.registry import ARCH_REGISTRY
# Local checkpoint location that is checked before downloading pretrained weights.
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'


def _vgg_layer_names(convs_per_stage):
    """Generate the ordered layer names of a VGG feature stack.

    Each stage contributes ``conv{stage}_{i}`` / ``relu{stage}_{i}`` pairs
    followed by one ``pool{stage}`` entry.

    Args:
        convs_per_stage (tuple[int]): Number of conv layers in each stage.

    Returns:
        list[str]: Layer names in order.
    """
    layer_names = []
    for stage, num_convs in enumerate(convs_per_stage, start=1):
        for conv_idx in range(1, num_convs + 1):
            layer_names.append(f'conv{stage}_{conv_idx}')
            layer_names.append(f'relu{stage}_{conv_idx}')
        layer_names.append(f'pool{stage}')
    return layer_names


# Layer names (without batch norm) of each supported VGG variant, in order.
NAMES = {
    'vgg11': _vgg_layer_names((1, 1, 2, 2, 2)),
    'vgg13': _vgg_layer_names((2, 2, 2, 2, 2)),
    'vgg16': _vgg_layer_names((2, 2, 3, 3, 3)),
    'vgg19': _vgg_layer_names((2, 2, 4, 4, 4)),
}
def insert_bn(names):
    """Add a batch-norm entry directly after every conv entry.

    The bn name reuses the position suffix of its conv, e.g. ``conv1_1`` is
    followed by ``bn1_1``. Other entries are copied unchanged.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    with_bn = []
    for layer_name in names:
        if 'conv' in layer_name:
            with_bn.extend((layer_name, 'bn' + layer_name.replace('conv', '')))
        else:
            with_bn.append(layer_name)
    return with_bn
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether use normalization
    in the input feature and the type of vgg network. Note that the pretrained
    path must fit the vgg type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        # The name table is keyed by the plain variant; bn names are derived from it.
        plain_names = NAMES[vgg_type.replace('_bn', '')]
        self.names = insert_bn(plain_names) if 'bn' in vgg_type else plain_names

        # only borrow layers that will be used to avoid unused params
        max_idx = max((self.names.index(name) for name in layer_name_list), default=0)

        # Use the local checkpoint when it exists, otherwise let torchvision
        # provide the pretrained weights.
        build_vgg = getattr(vgg, vgg_type)
        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = build_vgg(pretrained=False)
            vgg_net.load_state_dict(torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage))
        else:
            vgg_net = build_vgg(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for name, layer in zip(self.names, features):
            if 'pool' not in name:
                modified_net[name] = layer
            elif not remove_pooling:
                # in some cases, we may want to change the default stride
                modified_net[name] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            # otherwise remove_pooling is set and the pooling layer is dropped

        self.vgg_net = nn.Sequential(modified_net)

        trainable = bool(requires_grad)
        if trainable:
            self.vgg_net.train()
        else:
            self.vgg_net.eval()
        for param in self.parameters():
            param.requires_grad = trainable

        if self.use_input_norm:
            # the mean and the std are for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict: Cloned feature tensors of the layers listed in
                ``layer_name_list``, keyed by layer name.
        """
        if self.range_norm:
            # map [-1, 1] to [0, 1]
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        output = {}
        for name, layer in self.vgg_net.named_children():
            x = layer(x)
            if name in self.layer_name_list:
                output[name] = x.clone()

        return output
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY
__all__ = ['build_dataset', 'build_dataloader']

# Dataset classes register themselves when their module is imported, so every
# '*_dataset.py' file found under this package's folder is imported here.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
    osp.splitext(osp.basename(path))[0]  # module name = file name without extension
    for path in scandir(data_folder)
    if path.endswith('_dataset.py')
]
# import all the dataset modules
_dataset_modules = list(map(importlib.import_module, (f'basicsr.data.{stem}' for stem in dataset_filenames)))
def build_dataset(dataset_opt):
    """Instantiate the dataset described by ``dataset_opt``.

    The options are deep-copied first, so the dataset receives its own copy
    and the caller's dict is not shared with it.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name.
            type (str): Dataset type.

    Returns:
        The dataset built by the class registered under ``dataset_opt['type']``.
    """
    opt = deepcopy(dataset_opt)
    dataset_cls = DATASET_REGISTRY.get(opt['type'])
    dataset = dataset_cls(opt)
    get_root_logger().info(f'Dataset [{dataset.__class__.__name__}] - {opt["name"]} is built.')
    return dataset
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train' or 'val'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None

    Returns:
        A ``PrefetchDataLoader`` when ``prefetch_mode`` is 'cpu', otherwise a
        ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If the phase is not 'train', 'val' or 'test'.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()

    if phase == 'train':
        per_gpu_batch = dataset_opt['batch_size_per_gpu']
        per_gpu_workers = dataset_opt['num_worker_per_gpu']
        if dist:
            # distributed training: the per-GPU values are used as they are
            batch_size, num_workers = per_gpu_batch, per_gpu_workers
        else:
            # non-distributed training: scale by the GPU count (0 GPUs counts as 1)
            multiplier = num_gpu if num_gpu != 0 else 1
            batch_size = per_gpu_batch * multiplier
            num_workers = per_gpu_workers * multiplier

        dataloader_args = {
            'dataset': dataset,
            'batch_size': batch_size,
            # shuffle only when no sampler is given
            'shuffle': sampler is None,
            'num_workers': num_workers,
            'sampler': sampler,
            'drop_last': True,
        }
        if seed is None:
            dataloader_args['worker_init_fn'] = None
        else:
            dataloader_args['worker_init_fn'] = partial(
                worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
    elif phase in ['val', 'test']:
        # validation / testing: one sample at a time, loaded in the main process
        dataloader_args = {'dataset': dataset, 'batch_size': 1, 'shuffle': False, 'num_workers': 0}
    else:
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")

    for flag in ('pin_memory', 'persistent_workers'):
        dataloader_args[flag] = dataset_opt.get(flag, False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode != 'cpu':
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)

    # CPUPrefetcher
    num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
    logger = get_root_logger()
    logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
    return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed the NumPy and ``random`` generators for one worker.

    The worker seed is ``num_workers * rank + worker_id + seed``.

    Args:
        worker_id (int): Index of the worker.
        num_workers (int): Number of workers used to offset the seed by rank.
        rank (int): Process rank.
        seed (int): Base seed.
    """
    derived_seed = seed + worker_id + rank * num_workers
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(derived_seed)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment