# Sourced from https://github.com/myungsub/CAIN/blob/master/loss.py, who sourced
# from https://github.com/thstkdgus35/EDSR-PyTorch/tree/master/src/loss
# Added Huber loss in addition.
#
# NOTE(review): this copy of the file arrived with all newlines stripped and
# with spans between '<' and '>' characters removed (HTML-tag stripping), so
# some code was lost. The file has been reformatted; anything that had to be
# reconstructed rather than recovered verbatim is marked FIXME(review) below
# and must be confirmed against the upstream CAIN / EDSR sources.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import pytorch_msssim


class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that normalizes (sign=-1, default) or
    denormalizes (sign=+1) an RGB image by per-channel mean and std."""

    def __init__(self, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Identity kernel divided by std -> pure per-channel scaling.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # NOTE(review): setting .requires_grad on the Module object does not
        # freeze the parameters (that needs p.requires_grad_(False) on each
        # parameter); kept exactly as in the recovered text.
        self.requires_grad = False


class HuberLoss(nn.Module):
    """Mean Huber loss: quadratic where |sr - hr| < delta, linear beyond."""

    def __init__(self, delta=1):
        super().__init__()
        self.delta = delta

    def forward(self, sr, hr):
        l1 = torch.abs(sr - hr)
        # FIXME(review): everything after 'mask = l1' was lost to the tag
        # stripping. The lines below are the standard Huber formulation and
        # must be confirmed against the original file.
        mask = (l1 < self.delta).float()
        sq_loss = 0.5 * l1 ** 2
        abs_loss = self.delta * (l1 - 0.5 * self.delta)
        return torch.mean(mask * sq_loss + (1 - mask) * abs_loss)


# FIXME(review): a large span of the original file was lost here. From the
# surviving references it contained at least:
#   * class VGG(nn.Module) -- a perceptual loss on torchvision VGG features
#     (instantiated below as VGG(loss_type[3:]), e.g. 'VGG22'/'VGG54'); it is
#     NOT reconstructed here, so 'VGG' terms will raise NameError until the
#     class is restored from upstream;
#   * the class header, __init__, and the opening of forward() of the
#     Adversarial class whose tail survived and is wrapped below.


class Adversarial(nn.Module):
    """Adversarial (GAN / WGAN / WGAN_GP) loss: trains an internal
    discriminator for gan_k steps per call, then returns the generator loss.

    FIXME(review): __init__ and the opening of forward() were lost to the
    corruption. The scaffolding here (attribute names gan_type / gan_k /
    discriminator / optimizer, the per-call discriminator loop, and the
    GAN-branch discriminator loss) is inferred from the surviving tail and
    from the upstream EDSR implementation this file cites -- confirm before
    relying on it.
    """

    def forward(self, fake, real, *extra):
        # FIXME(review): reconstructed preamble. The surviving code uses
        # fake_detach, d_fake, d_real, label_real, and divides the
        # accumulated loss by self.gan_k, which implies this loop structure.
        fake_detach = fake.detach()
        self.loss = 0
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.discriminator(fake_detach)
            d_real = self.discriminator(real)
            if self.gan_type == 'GAN':
                label_fake = torch.zeros_like(d_fake)
                label_real = torch.ones_like(d_real)
                loss_d = F.binary_cross_entropy_with_logits(d_fake, label_fake) \
                    + F.binary_cross_entropy_with_logits(d_real, label_real)
            elif self.gan_type.find('WGAN') >= 0:
                # ---- recovered verbatim from here on ----
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # WGAN-GP: gradient penalty on samples interpolated
                    # between real and fake (Gulrajani et al., 2017).
                    epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.discriminator(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty

            # Discriminator update
            self.loss += loss_d.item()
            if self.training:
                loss_d.backward()
                self.optimizer.step()

            if self.gan_type == 'WGAN':
                # Vanilla WGAN enforces the Lipschitz constraint by clamping
                # discriminator weights.
                for p in self.discriminator.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # Generator loss
        d_fake_for_g = self.discriminator(fake)
        if self.gan_type == 'GAN':
            loss_g = F.binary_cross_entropy_with_logits(
                d_fake_for_g, label_real
            )
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_for_g.mean()

        return loss_g

    def state_dict(self, *args, **kwargs):
        # Persist both the discriminator and its optimizer so adversarial
        # training can resume from a checkpoint.
        state_discriminator = self.discriminator.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()
        return dict(**state_discriminator, **state_optimizer)

# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py


# Wrapper of loss functions
class Loss(nn.modules.loss._Loss):
    """Parses args.loss (e.g. '1*L1+0.1*GAN') into a weighted list of loss
    terms and evaluates them all in forward().

    forward() returns (total_loss, {type: effective_loss}); a synthetic
    'DIS' entry reports the discriminator loss of the preceding GAN term.
    """

    def __init__(self, args):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'Huber':
                loss_function = HuberLoss(delta=.5)
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                loss_function = VGG(loss_type[3:])
            elif loss_type == 'SSIM':
                loss_function = pytorch_msssim.SSIM(val_range=1.)
            elif loss_type.find('GAN') >= 0:
                loss_function = Adversarial(args, loss_type)

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function}
            )
            # The recovered text read "find('GAN') >= 0 >= 0" -- a chained
            # comparison whose trailing "0 >= 0" is always True, so dropping
            # it is behavior-preserving cleanup of an obvious typo.
            if loss_type.find('GAN') >= 0:
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        device = torch.device('cuda' if args.cuda else 'cpu')
        self.loss_module.to(device)
        #if args.precision == 'half': self.loss_module.half()
        if args.cuda:  # and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(self.loss_module)

    def forward(self, sr, hr, fake_imgs=None):
        """Evaluate every configured loss term on (sr, hr).

        GAN terms additionally receive the three fake_imgs entries
        (padded with None when not supplied). Returns (loss, losses):
        the weighted sum and a per-type dict of effective losses.
        """
        loss = 0
        losses = {}
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                if l['type'] == 'GAN':
                    if fake_imgs is None:
                        fake_imgs = [None, None, None]
                    _loss = l['function'](sr, hr, fake_imgs[0],
                                          fake_imgs[1], fake_imgs[2])
                else:
                    _loss = l['function'](sr, hr)
                effective_loss = l['weight'] * _loss
                losses[l['type']] = effective_loss
                loss += effective_loss
            elif l['type'] == 'DIS':
                # Report (not add) the discriminator loss accumulated by the
                # immediately preceding Adversarial term.
                losses[l['type']] = self.loss[i - 1]['function'].loss
        return loss, losses