Commit 76b9024b authored by yangzhong's avatar yangzhong
Browse files

git init

parents
Pipeline #3145 failed with stages
in 0 seconds
import os
import time
import torch
from collections import OrderedDict
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel
from basicsr.models import lr_scheduler as lr_scheduler
from basicsr.utils import get_root_logger
from basicsr.utils.dist_util import master_only
class BaseModel():
"""Base model."""
def __init__(self, opt):
self.opt = opt
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.is_train = opt['is_train']
self.schedulers = []
self.optimizers = []
def feed_data(self, data):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
pass
def save(self, epoch, current_iter):
"""Save networks and training state."""
pass
def validation(self, dataloader, current_iter, tb_logger, save_img=False):
"""Validation function.
Args:
dataloader (torch.utils.data.DataLoader): Validation dataloader.
current_iter (int): Current iteration.
tb_logger (tensorboard logger): Tensorboard logger.
save_img (bool): Whether to save images. Default: False.
"""
if self.opt['dist']:
self.dist_validation(dataloader, current_iter, tb_logger, save_img)
else:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
def _initialize_best_metric_results(self, dataset_name):
"""Initialize the best metric results dict for recording the best metric value and iteration."""
if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
return
elif not hasattr(self, 'best_metric_results'):
self.best_metric_results = dict()
# add a dataset record
record = dict()
for metric, content in self.opt['val']['metrics'].items():
better = content.get('better', 'higher')
init_val = float('-inf') if better == 'higher' else float('inf')
record[metric] = dict(better=better, val=init_val, iter=-1)
self.best_metric_results[dataset_name] = record
def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
if val >= self.best_metric_results[dataset_name][metric]['val']:
self.best_metric_results[dataset_name][metric]['val'] = val
self.best_metric_results[dataset_name][metric]['iter'] = current_iter
else:
if val <= self.best_metric_results[dataset_name][metric]['val']:
self.best_metric_results[dataset_name][metric]['val'] = val
self.best_metric_results[dataset_name][metric]['iter'] = current_iter
def model_ema(self, decay=0.999):
net_g = self.get_bare_model(self.net_g)
net_g_params = dict(net_g.named_parameters())
net_g_ema_params = dict(self.net_g_ema.named_parameters())
for k in net_g_ema_params.keys():
net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
def get_current_log(self):
return self.log_dict
def model_to_device(self, net):
"""Model to device. It also warps models with DistributedDataParallel
or DataParallel.
Args:
net (nn.Module)
"""
net = net.to(self.device)
if self.opt['dist']:
find_unused_parameters = self.opt.get('find_unused_parameters', False)
net = DistributedDataParallel(
net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
elif self.opt['num_gpu'] > 1:
net = DataParallel(net)
return net
def get_optimizer(self, optim_type, params, lr, **kwargs):
if optim_type == 'Adam':
optimizer = torch.optim.Adam(params, lr, **kwargs)
elif optim_type == 'AdamW':
optimizer = torch.optim.AdamW(params, lr, **kwargs)
elif optim_type == 'Adamax':
optimizer = torch.optim.Adamax(params, lr, **kwargs)
elif optim_type == 'SGD':
optimizer = torch.optim.SGD(params, lr, **kwargs)
elif optim_type == 'ASGD':
optimizer = torch.optim.ASGD(params, lr, **kwargs)
elif optim_type == 'RMSprop':
optimizer = torch.optim.RMSprop(params, lr, **kwargs)
elif optim_type == 'Rprop':
optimizer = torch.optim.Rprop(params, lr, **kwargs)
else:
raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
return optimizer
def setup_schedulers(self):
"""Set up schedulers."""
train_opt = self.opt['train']
scheduler_type = train_opt['scheduler'].pop('type')
if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
for optimizer in self.optimizers:
self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
elif scheduler_type == 'CosineAnnealingRestartLR':
for optimizer in self.optimizers:
self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
else:
raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
def get_bare_model(self, net):
"""Get bare model, especially under wrapping with
DistributedDataParallel or DataParallel.
"""
if isinstance(net, (DataParallel, DistributedDataParallel)):
net = net.module
return net
@master_only
def print_network(self, net):
"""Print the str and parameter number of a network.
Args:
net (nn.Module)
"""
if isinstance(net, (DataParallel, DistributedDataParallel)):
net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
else:
net_cls_str = f'{net.__class__.__name__}'
net = self.get_bare_model(net)
net_str = str(net)
net_params = sum(map(lambda x: x.numel(), net.parameters()))
logger = get_root_logger()
logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
logger.info(net_str)
def _set_lr(self, lr_groups_l):
"""Set learning rate for warm-up.
Args:
lr_groups_l (list): List for lr_groups, each for an optimizer.
"""
for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
for param_group, lr in zip(optimizer.param_groups, lr_groups):
param_group['lr'] = lr
def _get_init_lr(self):
"""Get the initial lr, which is set by the scheduler.
"""
init_lr_groups_l = []
for optimizer in self.optimizers:
init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
return init_lr_groups_l
def update_learning_rate(self, current_iter, warmup_iter=-1):
"""Update learning rate.
Args:
current_iter (int): Current iteration.
warmup_iter (int): Warm-up iter numbers. -1 for no warm-up.
Default: -1.
"""
if current_iter > 1:
for scheduler in self.schedulers:
scheduler.step()
# set up warm-up learning rate
if current_iter < warmup_iter:
# get initial lr for each group
init_lr_g_l = self._get_init_lr()
# modify warming-up learning rates
# currently only support linearly warm up
warm_up_lr_l = []
for init_lr_g in init_lr_g_l:
warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
# set learning rate
self._set_lr(warm_up_lr_l)
def get_current_learning_rate(self):
return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
@master_only
def save_network(self, net, net_label, current_iter, param_key='params'):
"""Save networks.
Args:
net (nn.Module | list[nn.Module]): Network(s) to be saved.
net_label (str): Network label.
current_iter (int): Current iter number.
param_key (str | list[str]): The parameter key(s) to save network.
Default: 'params'.
"""
if current_iter == -1:
current_iter = 'latest'
save_filename = f'{net_label}_{current_iter}.pth'
save_path = os.path.join(self.opt['path']['models'], save_filename)
net = net if isinstance(net, list) else [net]
param_key = param_key if isinstance(param_key, list) else [param_key]
assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
save_dict = {}
for net_, param_key_ in zip(net, param_key):
net_ = self.get_bare_model(net_)
state_dict = net_.state_dict()
for key, param in state_dict.items():
if key.startswith('module.'): # remove unnecessary 'module.'
key = key[7:]
state_dict[key] = param.cpu()
save_dict[param_key_] = state_dict
# avoid occasional writing errors
retry = 3
while retry > 0:
try:
torch.save(save_dict, save_path)
except Exception as e:
logger = get_root_logger()
logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
time.sleep(1)
else:
break
finally:
retry -= 1
if retry == 0:
logger.warning(f'Still cannot save {save_path}. Just ignore it.')
# raise IOError(f'Cannot save {save_path}.')
def _print_different_keys_loading(self, crt_net, load_net, strict=True):
"""Print keys with different name or different size when loading models.
1. Print keys with different names.
2. If strict=False, print the same key but with different tensor size.
It also ignore these keys with different sizes (not load).
Args:
crt_net (torch model): Current network.
load_net (dict): Loaded network.
strict (bool): Whether strictly loaded. Default: True.
"""
crt_net = self.get_bare_model(crt_net)
crt_net = crt_net.state_dict()
crt_net_keys = set(crt_net.keys())
load_net_keys = set(load_net.keys())
logger = get_root_logger()
if crt_net_keys != load_net_keys:
logger.warning('Current net - loaded net:')
for v in sorted(list(crt_net_keys - load_net_keys)):
logger.warning(f' {v}')
logger.warning('Loaded net - current net:')
for v in sorted(list(load_net_keys - crt_net_keys)):
logger.warning(f' {v}')
# check the size for the same keys
if not strict:
common_keys = crt_net_keys & load_net_keys
for k in common_keys:
if crt_net[k].size() != load_net[k].size():
logger.warning(f'Size different, ignore [{k}]: crt_net: '
f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
load_net[k + '.ignore'] = load_net.pop(k)
def load_network(self, net, load_path, strict=True, param_key='params'):
"""Load network.
Args:
load_path (str): The path of networks to be loaded.
net (nn.Module): Network.
strict (bool): Whether strictly loaded.
param_key (str): The parameter key of loaded network. If set to
None, use the root 'path'.
Default: 'params'.
"""
logger = get_root_logger()
net = self.get_bare_model(net)
load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
if param_key is not None:
if param_key not in load_net and 'params' in load_net:
param_key = 'params'
logger.info('Loading: params_ema does not exist, use params.')
load_net = load_net[param_key]
logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
self._print_different_keys_loading(net, load_net, strict)
net.load_state_dict(load_net, strict=strict)
@master_only
def save_training_state(self, epoch, current_iter):
"""Save training states during training, which will be used for
resuming.
Args:
epoch (int): Current epoch.
current_iter (int): Current iteration.
"""
if current_iter != -1:
state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
for o in self.optimizers:
state['optimizers'].append(o.state_dict())
for s in self.schedulers:
state['schedulers'].append(s.state_dict())
save_filename = f'{current_iter}.state'
save_path = os.path.join(self.opt['path']['training_states'], save_filename)
# avoid occasional writing errors
retry = 3
while retry > 0:
try:
torch.save(state, save_path)
except Exception as e:
logger = get_root_logger()
logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
time.sleep(1)
else:
break
finally:
retry -= 1
if retry == 0:
logger.warning(f'Still cannot save {save_path}. Just ignore it.')
# raise IOError(f'Cannot save {save_path}.')
def resume_training(self, resume_state):
"""Reload the optimizers and schedulers for resumed training.
Args:
resume_state (dict): Resume state.
"""
resume_optimizers = resume_state['optimizers']
resume_schedulers = resume_state['schedulers']
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
for i, o in enumerate(resume_optimizers):
self.optimizers[i].load_state_dict(o)
for i, s in enumerate(resume_schedulers):
self.schedulers[i].load_state_dict(s)
def reduce_loss_dict(self, loss_dict):
"""reduce loss dict.
In distributed training, it averages the losses among different GPUs .
Args:
loss_dict (OrderedDict): Loss dict.
"""
with torch.no_grad():
if self.opt['dist']:
keys = []
losses = []
for name, value in loss_dict.items():
keys.append(name)
losses.append(value)
losses = torch.stack(losses, 0)
torch.distributed.reduce(losses, dst=0)
if self.opt['rank'] == 0:
losses /= self.opt['world_size']
loss_dict = {key: loss for key, loss in zip(keys, losses)}
log_dict = OrderedDict()
for name, value in loss_dict.items():
log_dict[name] = value.mean().item()
return log_dict
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
from .video_base_model import VideoBaseModel
@MODEL_REGISTRY.register()
class EDVRModel(VideoBaseModel):
"""EDVR Model.
Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. # noqa: E501
"""
def __init__(self, opt):
super(EDVRModel, self).__init__(opt)
if self.is_train:
self.train_tsa_iter = opt['train'].get('tsa_iter')
def setup_optimizers(self):
train_opt = self.opt['train']
dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
logger = get_root_logger()
logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.')
if dcn_lr_mul == 1:
optim_params = self.net_g.parameters()
else: # separate dcn params and normal params for different lr
normal_params = []
dcn_params = []
for name, param in self.net_g.named_parameters():
if 'dcn' in name:
dcn_params.append(param)
else:
normal_params.append(param)
optim_params = [
{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_g']['lr']
},
{
'params': dcn_params,
'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
},
]
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
def optimize_parameters(self, current_iter):
if self.train_tsa_iter:
if current_iter == 1:
logger = get_root_logger()
logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
for name, param in self.net_g.named_parameters():
if 'fusion' not in name:
param.requires_grad = False
elif current_iter == self.train_tsa_iter:
logger = get_root_logger()
logger.warning('Train all the parameters.')
for param in self.net_g.parameters():
param.requires_grad = True
super(EDVRModel, self).optimize_parameters(current_iter)
import torch
from collections import OrderedDict
from basicsr.utils.registry import MODEL_REGISTRY
from .srgan_model import SRGANModel
@MODEL_REGISTRY.register()
class ESRGANModel(SRGANModel):
"""ESRGAN model for single image super-resolution."""
def optimize_parameters(self, current_iter):
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
l_g_total = 0
loss_dict = OrderedDict()
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
# perceptual loss
if self.cri_perceptual:
l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
if l_g_percep is not None:
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
if l_g_style is not None:
l_g_total += l_g_style
loss_dict['l_g_style'] = l_g_style
# gan loss (relativistic gan)
real_d_pred = self.net_d(self.gt).detach()
fake_g_pred = self.net_d(self.output)
l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False)
l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False)
l_g_gan = (l_g_real + l_g_fake) / 2
l_g_total += l_g_gan
loss_dict['l_g_gan'] = l_g_gan
l_g_total.backward()
self.optimizer_g.step()
# optimize net_d
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
# gan loss (relativistic gan)
# In order to avoid the error in distributed training:
# "Error detected in CudnnBatchNormBackward: RuntimeError: one of
# the variables needed for gradient computation has been modified by
# an inplace operation",
# we separate the backwards for real and fake, and also detach the
# tensor for calculating mean.
# real
fake_d_pred = self.net_d(self.output).detach()
real_d_pred = self.net_d(self.gt)
l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5
l_d_real.backward()
# fake
fake_d_pred = self.net_d(self.output.detach())
l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5
l_d_fake.backward()
self.optimizer_d.step()
loss_dict['l_d_real'] = l_d_real
loss_dict['l_d_fake'] = l_d_fake
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel
@MODEL_REGISTRY.register()
class HiFaceGANModel(SRModel):
"""HiFaceGAN model for generic-purpose face restoration.
No prior modeling required, works for any degradations.
Currently doesn't support EMA for inference.
"""
def init_training_settings(self):
train_opt = self.opt['train']
self.ema_decay = train_opt.get('ema_decay', 0)
if self.ema_decay > 0:
raise (NotImplementedError('HiFaceGAN does not support EMA now. Pass'))
self.net_g.train()
self.net_d = build_network(self.opt['network_d'])
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)
# define losses
# HiFaceGAN does not use pixel loss by default
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
if train_opt.get('feature_matching_opt'):
self.cri_feat = build_loss(train_opt['feature_matching_opt']).to(self.device)
else:
self.cri_feat = None
if self.cri_pix is None and self.cri_perceptual is None:
raise ValueError('Both pixel and perceptual losses are None.')
if train_opt.get('gan_opt'):
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
self.net_d_iters = train_opt.get('net_d_iters', 1)
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def setup_optimizers(self):
train_opt = self.opt['train']
# optimizer g
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
# optimizer d
optim_type = train_opt['optim_d'].pop('type')
self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
self.optimizers.append(self.optimizer_d)
def discriminate(self, input_lq, output, ground_truth):
"""
This is a conditional (on the input) discriminator
In Batch Normalization, the fake and real images are
recommended to be in the same batch to avoid disparate
statistics in fake and real images.
So both fake and real images are fed to D all at once.
"""
h, w = output.shape[-2:]
if output.shape[-2:] != input_lq.shape[-2:]:
lq = torch.nn.functional.interpolate(input_lq, (h, w))
real = torch.nn.functional.interpolate(ground_truth, (h, w))
fake_concat = torch.cat([lq, output], dim=1)
real_concat = torch.cat([lq, real], dim=1)
else:
fake_concat = torch.cat([input_lq, output], dim=1)
real_concat = torch.cat([input_lq, ground_truth], dim=1)
fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
discriminator_out = self.net_d(fake_and_real)
pred_fake, pred_real = self._divide_pred(discriminator_out)
return pred_fake, pred_real
@staticmethod
def _divide_pred(pred):
"""
Take the prediction of fake and real images from the combined batch.
The prediction contains the intermediate outputs of multiscale GAN,
so it's usually a list
"""
if type(pred) == list:
fake = []
real = []
for p in pred:
fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
real.append([tensor[tensor.size(0) // 2:] for tensor in p])
else:
fake = pred[:pred.size(0) // 2]
real = pred[pred.size(0) // 2:]
return fake, real
def optimize_parameters(self, current_iter):
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
l_g_total = 0
loss_dict = OrderedDict()
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
# perceptual loss
if self.cri_perceptual:
l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
if l_g_percep is not None:
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
if l_g_style is not None:
l_g_total += l_g_style
loss_dict['l_g_style'] = l_g_style
# Requires real prediction for feature matching loss
pred_fake, pred_real = self.discriminate(self.lq, self.output, self.gt)
l_g_gan = self.cri_gan(pred_fake, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan'] = l_g_gan
# feature matching loss
if self.cri_feat:
l_g_feat = self.cri_feat(pred_fake, pred_real)
l_g_total += l_g_feat
loss_dict['l_g_feat'] = l_g_feat
l_g_total.backward()
self.optimizer_g.step()
# optimize net_d
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
# TODO: Benchmark test between HiFaceGAN and SRGAN implementation:
# SRGAN use the same fake output for discriminator update
# while HiFaceGAN regenerate a new output using updated net_g
# This should not make too much difference though. Stick to SRGAN now.
# -------------------------------------------------------------------
# ---------- Below are original HiFaceGAN code snippet --------------
# -------------------------------------------------------------------
# with torch.no_grad():
# fake_image = self.net_g(self.lq)
# fake_image = fake_image.detach()
# fake_image.requires_grad_()
# pred_fake, pred_real = self.discriminate(self.lq, fake_image, self.gt)
# real
pred_fake, pred_real = self.discriminate(self.lq, self.output.detach(), self.gt)
l_d_real = self.cri_gan(pred_real, True, is_disc=True)
loss_dict['l_d_real'] = l_d_real
# fake
l_d_fake = self.cri_gan(pred_fake, False, is_disc=True)
loss_dict['l_d_fake'] = l_d_fake
l_d_total = (l_d_real + l_d_fake) / 2
l_d_total.backward()
self.optimizer_d.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
print('HiFaceGAN does not support EMA now. pass')
def validation(self, dataloader, current_iter, tb_logger, save_img=False):
"""
Warning: HiFaceGAN requires train() mode even for validation
For more info, see https://github.com/Lotayou/Face-Renovation/issues/31
Args:
dataloader (torch.utils.data.DataLoader): Validation dataloader.
current_iter (int): Current iteration.
tb_logger (tensorboard logger): Tensorboard logger.
save_img (bool): Whether to save images. Default: False.
"""
if self.opt['network_g']['type'] in ('HiFaceGAN', 'SPADEGenerator'):
self.net_g.train()
if self.opt['dist']:
self.dist_validation(dataloader, current_iter, tb_logger, save_img)
else:
print('In HiFaceGANModel: The new metrics package is under development.' +
'Using super method now (Only PSNR & SSIM are supported)')
super().nondist_validation(dataloader, current_iter, tb_logger, save_img)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
"""
TODO: Validation using updated metric system
The metrics are now evaluated after all images have been tested
This allows batch processing, and also allows evaluation of
distributional metrics, such as:
@ Frechet Inception Distance: FID
@ Maximum Mean Discrepancy: MMD
Warning:
Need careful batch management for different inference settings.
"""
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
if with_metrics:
self.metric_results = dict() # {metric: 0 for metric in self.opt['val']['metrics'].keys()}
sr_tensors = []
gt_tensors = []
pbar = tqdm(total=len(dataloader), unit='image')
for val_data in dataloader:
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
self.feed_data(val_data)
self.test()
visuals = self.get_current_visuals() # detached cpu tensor, non-squeeze
sr_tensors.append(visuals['result'])
if 'gt' in visuals:
gt_tensors.append(visuals['gt'])
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
if save_img:
if self.opt['is_train']:
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
f'{img_name}_{current_iter}.png')
else:
if self.opt['val']['suffix']:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["val"]["suffix"]}.png')
else:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}.png')
imwrite(tensor2img(visuals['result']), save_img_path)
pbar.update(1)
pbar.set_description(f'Test {img_name}')
pbar.close()
if with_metrics:
sr_pack = torch.cat(sr_tensors, dim=0)
gt_pack = torch.cat(gt_tensors, dim=0)
# calculate metrics
for name, opt_ in self.opt['val']['metrics'].items():
# The new metric caller automatically returns mean value
# FIXME: ERROR: calculate_metric only supports two arguments. Now the codes cannot be successfully run
self.metric_results[name] = calculate_metric(dict(sr_pack=sr_pack, gt_pack=gt_pack), opt_)
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def save(self, epoch, current_iter):
if hasattr(self, 'net_g_ema'):
print('HiFaceGAN does not support EMA now. Fallback to normal mode.')
self.save_network(self.net_g, 'net_g', current_iter)
self.save_network(self.net_d, 'net_d', current_iter)
self.save_training_state(epoch, current_iter)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment