#!/usr/bin/python
# -*- coding: UTF-8 -*-
# Created by: algohunt
# Microsoft Research & Peking University
# lilingzhi@pku.edu.cn
# Copyright (c) 2019
"""Building blocks for a style-based generator.

Provides weight-initialisation helpers, the ``EqualLR`` runtime weight
scaling hook, and small ``nn.Module`` layers built on top of it.
"""

import torch
from torch.nn import init
import torch.nn.functional as F
from torch import nn
from math import sqrt


def init_linear(linear):
    """Initialise *linear* in place: Xavier-normal weight, zero bias.

    Args:
        linear: A module exposing ``weight`` and ``bias`` attributes
            (e.g. ``nn.Linear``). ``bias`` may be ``None``.
    """
    # The underscore-suffixed initialisers are the non-deprecated spelling.
    init.xavier_normal_(linear.weight)
    if linear.bias is not None:
        linear.bias.data.zero_()


def init_conv(conv, glu=True):
    """Initialise *conv* in place: Kaiming-normal weight, zero bias.

    Args:
        conv: A module exposing ``weight`` and ``bias`` attributes
            (e.g. ``nn.Conv2d``). ``bias`` may be ``None``.
        glu: Unused; kept so existing call sites keep working.
    """
    init.kaiming_normal_(conv.weight)
    if conv.bias is not None:
        conv.bias.data.zero_()


class EqualLR:
    """Forward pre-hook that rescales a weight by ``sqrt(2 / fan_in)``.

    The raw parameter is stored on the module as ``<name>_orig``; before
    every forward pass the scaled tensor is written to ``<name>``.
    """

    def __init__(self, name):
        """Remember which attribute (*name*) this hook manages."""
        self.name = name

    def compute_weight(self, module):
        """Return the stored ``<name>_orig`` tensor times ``sqrt(2 / fan_in)``.

        ``fan_in`` is the size of dimension 1 multiplied by the number of
        elements in one ``weight[0][0]`` slice.
        """
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.size(1) * weight[0][0].numel()

        return weight * sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        """Install the hook on *module* for the parameter called *name*.

        The parameter is removed from the module, re-registered under
        ``name + '_orig'``, and a forward pre-hook is registered that
        recreates ``name`` on each call.

        Returns:
            The ``EqualLR`` instance that was registered as the hook.
        """
        fn = EqualLR(name)

        weight = getattr(module, name)
        # Drop the original entry so the attribute can later be set to a
        # plain (non-Parameter) tensor by the hook.
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        """Forward pre-hook entry point: refresh the scaled weight."""
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)


def equal_lr(module, name='weight'):
    """Attach an ``EqualLR`` hook to *module* and return the module.

    Note: ``module.<name>`` only exists again after the first forward call,
    because the hook is what recreates it.
    """
    EqualLR.apply(module, name)

    return module


class Blur(nn.Module):
    """Depthwise 3x3 blur with the normalised kernel [[1,2,1],[2,4,2],[1,2,1]]."""

    def __init__(self):
        super().__init__()

        weight = torch.tensor(
            [[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32
        )
        weight = weight.view(1, 1, 3, 3)
        # Normalise so the kernel sums to one.
        weight = weight / weight.sum()
        self.register_buffer('weight', weight)

    def forward(self, input):
        """Blur every channel of *input* independently.

        The single kernel is repeated once per channel (``input.shape[1]``)
        and applied with ``groups`` equal to the channel count and a padding
        of 1. Assumes *input* is laid out as (batch, channel, height, width)
        -- TODO confirm against callers.
        """
        channels = input.shape[1]
        return F.conv2d(
            input,
            self.weight.repeat(channels, 1, 1, 1),
            padding=1,
            groups=channels,
        )


class EqualConv2d(nn.Module):
    """``nn.Conv2d`` with normal-initialised weight and ``EqualLR`` scaling."""

    def __init__(self, *args, **kwargs):
        """Build the wrapped convolution.

        All arguments are forwarded unchanged to ``nn.Conv2d``.
        """
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        # ``bias=False`` leaves ``conv.bias`` as None; guard so that case
        # does not raise AttributeError.
        if conv.bias is not None:
            conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        """Apply the wrapped convolution to *input*."""
        return self.conv(input)


class EqualLinear(nn.Module):
    """``nn.Linear`` with normal-initialised weight and ``EqualLR`` scaling."""

    def __init__(self, in_dim, out_dim):
        """Build the wrapped linear layer mapping *in_dim* to *out_dim*."""
        super().__init__()

        linear = nn.Linear(in_dim, out_dim)
        linear.weight.data.normal_()
        linear.bias.data.zero_()

        self.linear = equal_lr(linear)

    def forward(self, input):
        """Apply the wrapped linear layer to *input*."""
        return self.linear(input)


class AdaptiveInstanceNorm(nn.Module):
    """Instance normalisation followed by a style-dependent scale and shift."""

    def __init__(self, in_channel, style_dim):
        """Create the norm layer and the style projection.

        Args:
            in_channel: Number of channels passed to ``nn.InstanceNorm2d``.
            style_dim: Input size of the style projection; its output size
                is ``in_channel * 2``.
        """
        super().__init__()

        self.norm = nn.InstanceNorm2d(in_channel)
        self.style = EqualLinear(style_dim, in_channel * 2)

        # First half of the projection is used as gamma (starts at 1),
        # second half as beta (starts at 0).
        self.style.linear.bias.data[:in_channel] = 1
        self.style.linear.bias.data[in_channel:] = 0

    def forward(self, input, style):
        """Normalise *input*, then apply ``gamma * out + beta``.

        *style* is projected, given two trailing singleton dimensions, and
        split in two along dimension 1 to obtain ``gamma`` and ``beta``.
        Presumably *style* is (batch, style_dim) -- verify against callers.
        """
        style = self.style(style).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, 1)

        out = self.norm(input)
        out = gamma * out + beta

        return out


class NoiseInjection(nn.Module):
    """Add noise scaled by a learned per-channel weight (initialised to zero)."""

    def __init__(self, channel):
        """Create a zero weight of shape ``(1, channel, 1, 1)``."""
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        """Return ``image + weight * noise``."""
        return image + self.weight * noise


class ConstantInput(nn.Module):
    """Learned constant tensor repeated along the batch dimension."""

    def __init__(self, channel, size=4):
        """Create a random-normal parameter of shape ``(1, channel, size, size)``."""
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        """Return the constant repeated ``input.shape[0]`` times.

        Only the first dimension of *input* is read; its values are ignored.
        """
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out