"driver/src/driver.cu" did not exist on "a414e3fdf83272cbb965c5846f677007c5391d6a"
Commit 25008cf0 authored by dongchy920's avatar dongchy920
Browse files

test

parents
Pipeline #1767 canceled with stages
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10
RUN source /opt/dtk/env.sh
\ No newline at end of file
import argparse
import pickle
import torch
from torch import nn
import numpy as np
from scipy import linalg
from tqdm import tqdm
from model import Generator
from calc_inception import load_patched_inception_v3
@torch.no_grad()
def extract_feature_from_samples(
generator, inception, truncation, truncation_latent, batch_size, n_sample, device
):
n_batch = n_sample // batch_size
resid = n_sample - (n_batch * batch_size)
batch_sizes = [batch_size] * n_batch + [resid]
features = []
for batch in tqdm(batch_sizes):
latent = torch.randn(batch, 512, device=device)
img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent)
feat = inception(img)[0].view(img.shape[0], -1)
features.append(feat.to("cpu"))
features = torch.cat(features, 0)
return features
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
if not np.isfinite(cov_sqrt).all():
print("product of cov matrices is singular")
offset = np.eye(sample_cov.shape[0]) * eps
cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))
if np.iscomplexobj(cov_sqrt):
if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
m = np.max(np.abs(cov_sqrt.imag))
raise ValueError(f"Imaginary component {m}")
cov_sqrt = cov_sqrt.real
mean_diff = sample_mean - real_mean
mean_norm = mean_diff @ mean_diff
trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)
fid = mean_norm + trace
return fid
if __name__ == "__main__":
device = "cuda"
parser = argparse.ArgumentParser(description="Calculate FID scores")
parser.add_argument("--truncation", type=float, default=1, help="truncation factor")
parser.add_argument(
"--truncation_mean",
type=int,
default=4096,
help="number of samples to calculate mean for truncation",
)
parser.add_argument(
"--batch", type=int, default=64, help="batch size for the generator"
)
parser.add_argument(
"--n_sample",
type=int,
default=50000,
help="number of the samples for calculating FID",
)
parser.add_argument(
"--size", type=int, default=256, help="image sizes for generator"
)
parser.add_argument(
"--inception",
type=str,
default=None,
required=True,
help="path to precomputed inception embedding",
)
parser.add_argument(
"ckpt", metavar="CHECKPOINT", help="path to generator checkpoint"
)
args = parser.parse_args()
ckpt = torch.load(args.ckpt)
g = Generator(args.size, 512, 8).to(device)
g.load_state_dict(ckpt["g_ema"])
g = nn.DataParallel(g)
g.eval()
if args.truncation < 1:
with torch.no_grad():
mean_latent = g.mean_latent(args.truncation_mean)
else:
mean_latent = None
inception = nn.DataParallel(load_patched_inception_v3()).to(device)
inception.eval()
features = extract_feature_from_samples(
g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
).numpy()
print(f"extracted {features.shape[0]} features")
sample_mean = np.mean(features, 0)
sample_cov = np.cov(features, rowvar=False)
with open(args.inception, "rb") as f:
embeds = pickle.load(f)
real_mean = embeds["mean"]
real_cov = embeds["cov"]
fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
print("fid:", fid)
import argparse
import torch
from torchvision import utils
from model import Generator
from tqdm import tqdm
def generate(args, g_ema, device, mean_latent):
with torch.no_grad():
g_ema.eval()
for i in tqdm(range(args.pics)):
sample_z = torch.randn(args.sample, args.latent, device=device)
sample, _ = g_ema(
[sample_z], truncation=args.truncation, truncation_latent=mean_latent
)
utils.save_image(
sample,
f"sample/{str(i).zfill(6)}.png",
nrow=1,
normalize=True,
value_range=(-1, 1),
)
if __name__ == "__main__":
device = "cuda"
parser = argparse.ArgumentParser(description="Generate samples from the generator")
parser.add_argument(
"--size", type=int, default=1024, help="output image size of the generator"
)
parser.add_argument(
"--sample",
type=int,
default=1,
help="number of samples to be generated for each image",
)
parser.add_argument(
"--pics", type=int, default=20, help="number of images to be generated"
)
parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
parser.add_argument(
"--truncation_mean",
type=int,
default=4096,
help="number of vectors to calculate mean for the truncation",
)
parser.add_argument(
"--ckpt",
type=str,
default="stylegan2-ffhq-config-f.pt",
help="path to the model checkpoint",
)
parser.add_argument(
"--channel_multiplier",
type=int,
default=2,
help="channel multiplier of the generator. config-f = 2, else = 1",
)
args = parser.parse_args()
args.latent = 512
args.n_mlp = 8
g_ema = Generator(
args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
).to(device)
checkpoint = torch.load(args.ckpt)
g_ema.load_state_dict(checkpoint["g_ema"])
if args.truncation < 1:
with torch.no_grad():
mean_latent = g_ema.mean_latent(args.truncation_mean)
else:
mean_latent = None
generate(args, g_ema, device, mean_latent)
icon.png

58.2 KB

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
try:
from torchvision.models.utils import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
class InceptionV3(nn.Module):
"""Pretrained InceptionV3 network returning feature maps"""
# Index of default block of inception to return,
# corresponds to output of final average pooling
DEFAULT_BLOCK_INDEX = 3
# Maps feature dimensionality to their output blocks indices
BLOCK_INDEX_BY_DIM = {
64: 0, # First max pooling features
192: 1, # Second max pooling featurs
768: 2, # Pre-aux classifier features
2048: 3 # Final average pooling features
}
def __init__(self,
output_blocks=[DEFAULT_BLOCK_INDEX],
resize_input=True,
normalize_input=True,
requires_grad=False,
use_fid_inception=True):
"""Build pretrained InceptionV3
Parameters
----------
output_blocks : list of int
Indices of blocks to return features of. Possible values are:
- 0: corresponds to output of first max pooling
- 1: corresponds to output of second max pooling
- 2: corresponds to output which is fed to aux classifier
- 3: corresponds to output of final average pooling
resize_input : bool
If true, bilinearly resizes input to width and height 299 before
feeding input to model. As the network without fully connected
layers is fully convolutional, it should be able to handle inputs
of arbitrary size, so resizing might not be strictly needed
normalize_input : bool
If true, scales the input from range (0, 1) to the range the
pretrained Inception network expects, namely (-1, 1)
requires_grad : bool
If true, parameters of the model require gradients. Possibly useful
for finetuning the network
use_fid_inception : bool
If true, uses the pretrained Inception model used in Tensorflow's
FID implementation. If false, uses the pretrained Inception model
available in torchvision. The FID Inception model has different
weights and a slightly different structure from torchvision's
Inception model. If you want to compute FID scores, you are
strongly advised to set this parameter to true to get comparable
results.
"""
super(InceptionV3, self).__init__()
self.resize_input = resize_input
self.normalize_input = normalize_input
self.output_blocks = sorted(output_blocks)
self.last_needed_block = max(output_blocks)
assert self.last_needed_block <= 3, \
'Last possible output block index is 3'
self.blocks = nn.ModuleList()
if use_fid_inception:
inception = fid_inception_v3()
else:
inception = models.inception_v3(pretrained=True)
# Block 0: input to maxpool1
block0 = [
inception.Conv2d_1a_3x3,
inception.Conv2d_2a_3x3,
inception.Conv2d_2b_3x3,
nn.MaxPool2d(kernel_size=3, stride=2)
]
self.blocks.append(nn.Sequential(*block0))
# Block 1: maxpool1 to maxpool2
if self.last_needed_block >= 1:
block1 = [
inception.Conv2d_3b_1x1,
inception.Conv2d_4a_3x3,
nn.MaxPool2d(kernel_size=3, stride=2)
]
self.blocks.append(nn.Sequential(*block1))
# Block 2: maxpool2 to aux classifier
if self.last_needed_block >= 2:
block2 = [
inception.Mixed_5b,
inception.Mixed_5c,
inception.Mixed_5d,
inception.Mixed_6a,
inception.Mixed_6b,
inception.Mixed_6c,
inception.Mixed_6d,
inception.Mixed_6e,
]
self.blocks.append(nn.Sequential(*block2))
# Block 3: aux classifier to final avgpool
if self.last_needed_block >= 3:
block3 = [
inception.Mixed_7a,
inception.Mixed_7b,
inception.Mixed_7c,
nn.AdaptiveAvgPool2d(output_size=(1, 1))
]
self.blocks.append(nn.Sequential(*block3))
for param in self.parameters():
param.requires_grad = requires_grad
def forward(self, inp):
"""Get Inception feature maps
Parameters
----------
inp : torch.autograd.Variable
Input tensor of shape Bx3xHxW. Values are expected to be in
range (0, 1)
Returns
-------
List of torch.autograd.Variable, corresponding to the selected output
block, sorted ascending by index
"""
outp = []
x = inp
if self.resize_input:
x = F.interpolate(x,
size=(299, 299),
mode='bilinear',
align_corners=False)
if self.normalize_input:
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
for idx, block in enumerate(self.blocks):
x = block(x)
if idx in self.output_blocks:
outp.append(x)
if idx == self.last_needed_block:
break
return outp
def fid_inception_v3():
"""Build pretrained Inception model for FID computation
The Inception model for FID computation uses a different set of weights
and has a slightly different structure than torchvision's Inception.
This method first constructs torchvision's Inception and then patches the
necessary parts that are different in the FID Inception model.
"""
inception = models.inception_v3(num_classes=1008,
aux_logits=False,
pretrained=False)
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
inception.Mixed_7b = FIDInceptionE_1(1280)
inception.Mixed_7c = FIDInceptionE_2(2048)
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
inception.load_state_dict(state_dict)
return inception
class FIDInceptionA(models.inception.InceptionA):
"""InceptionA block patched for FID computation"""
def __init__(self, in_channels, pool_features):
super(FIDInceptionA, self).__init__(in_channels, pool_features)
def forward(self, x):
branch1x1 = self.branch1x1(x)
branch5x5 = self.branch5x5_1(x)
branch5x5 = self.branch5x5_2(branch5x5)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
# Patch: Tensorflow's average pool does not use the padded zero's in
# its average calculation
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
count_include_pad=False)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
class FIDInceptionC(models.inception.InceptionC):
"""InceptionC block patched for FID computation"""
def __init__(self, in_channels, channels_7x7):
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
def forward(self, x):
branch1x1 = self.branch1x1(x)
branch7x7 = self.branch7x7_1(x)
branch7x7 = self.branch7x7_2(branch7x7)
branch7x7 = self.branch7x7_3(branch7x7)
branch7x7dbl = self.branch7x7dbl_1(x)
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
# Patch: Tensorflow's average pool does not use the padded zero's in
# its average calculation
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
count_include_pad=False)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
return torch.cat(outputs, 1)
class FIDInceptionE_1(models.inception.InceptionE):
"""First InceptionE block patched for FID computation"""
def __init__(self, in_channels):
super(FIDInceptionE_1, self).__init__(in_channels)
def forward(self, x):
branch1x1 = self.branch1x1(x)
branch3x3 = self.branch3x3_1(x)
branch3x3 = [
self.branch3x3_2a(branch3x3),
self.branch3x3_2b(branch3x3),
]
branch3x3 = torch.cat(branch3x3, 1)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = [
self.branch3x3dbl_3a(branch3x3dbl),
self.branch3x3dbl_3b(branch3x3dbl),
]
branch3x3dbl = torch.cat(branch3x3dbl, 1)
# Patch: Tensorflow's average pool does not use the padded zero's in
# its average calculation
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
count_include_pad=False)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
class FIDInceptionE_2(models.inception.InceptionE):
"""Second InceptionE block patched for FID computation"""
def __init__(self, in_channels):
super(FIDInceptionE_2, self).__init__(in_channels)
def forward(self, x):
branch1x1 = self.branch1x1(x)
branch3x3 = self.branch3x3_1(x)
branch3x3 = [
self.branch3x3_2a(branch3x3),
self.branch3x3_2b(branch3x3),
]
branch3x3 = torch.cat(branch3x3, 1)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = [
self.branch3x3dbl_3a(branch3x3dbl),
self.branch3x3dbl_3b(branch3x3dbl),
]
branch3x3dbl = torch.cat(branch3x3dbl, 1)
# Patch: The FID Inception model uses max pooling instead of average
# pooling. This is likely an error in this specific Inception
# implementation, as other Inception models use average pooling here
# (which matches the description in the paper).
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from skimage.measure import compare_ssim
import torch
from torch.autograd import Variable
from lpips import dist_model
class PerceptualLoss(torch.nn.Module):
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
super(PerceptualLoss, self).__init__()
print('Setting up Perceptual loss...')
self.use_gpu = use_gpu
self.spatial = spatial
self.gpu_ids = gpu_ids
self.model = dist_model.DistModel()
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
print('...[%s] initialized'%self.model.name())
print('...Done')
def forward(self, pred, target, normalize=False):
"""
Pred and target are Variables.
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
If normalize is False, assumes the images are already between [-1,+1]
Inputs pred and target are Nx3xHxW
Output pytorch Variable N long
"""
if normalize:
target = 2 * target - 1
pred = 2 * pred - 1
return self.model.forward(target, pred)
def normalize_tensor(in_feat,eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
return in_feat/(norm_factor+eps)
def l2(p0, p1, range=255.):
return .5*np.mean((p0 / range - p1 / range)**2)
def psnr(p0, p1, peak=255.):
return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
def dssim(p0, p1, range=255.):
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
def rgb2lab(in_img,mean_cent=False):
from skimage import color
img_lab = color.rgb2lab(in_img)
if(mean_cent):
img_lab[:,:,0] = img_lab[:,:,0]-50
return img_lab
def tensor2np(tensor_obj):
# change dimension of a tensor object into a numpy array
return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
def np2tensor(np_obj):
# change dimenion of np array into tensor array
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
# image tensor to lab tensor
from skimage import color
img = tensor2im(image_tensor)
img_lab = color.rgb2lab(img)
if(mc_only):
img_lab[:,:,0] = img_lab[:,:,0]-50
if(to_norm and not mc_only):
img_lab[:,:,0] = img_lab[:,:,0]-50
img_lab = img_lab/100.
return np2tensor(img_lab)
def tensorlab2tensor(lab_tensor,return_inbnd=False):
from skimage import color
import warnings
warnings.filterwarnings("ignore")
lab = tensor2np(lab_tensor)*100.
lab[:,:,0] = lab[:,:,0]+50
rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
if(return_inbnd):
# convert back to lab, see if we match
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
mask = 1.*np.isclose(lab_back,lab,atol=2.)
mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
return (im2tensor(rgb_back),mask)
else:
return im2tensor(rgb_back)
def rgb2lab(input):
from skimage import color
return color.rgb2lab(input / 255.)
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
return torch.Tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
def tensor2vec(vector_tensor):
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
def voc_ap(rec, prec, use_07_metric=False):
""" ap = voc_ap(rec, prec, [use_07_metric])
Compute VOC AP given precision and recall.
If use_07_metric is true, uses the
VOC 07 11 point method (default:False).
"""
if use_07_metric:
# 11 point metric
ap = 0.
for t in np.arange(0., 1.1, 0.1):
if np.sum(rec >= t) == 0:
p = 0
else:
p = np.max(prec[rec >= t])
ap = ap + p / 11.
else:
# correct AP calculation
# first append sentinel values at the end
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))
# compute the precision envelope
for i in range(mpre.size - 1, 0, -1):
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
# to calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]
# and sum (\Delta recall) * prec
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
return torch.Tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
import os
import numpy as np
import torch
from torch.autograd import Variable
from pdb import set_trace as st
from IPython import embed
class BaseModel():
def __init__(self):
pass;
def name(self):
return 'BaseModel'
def initialize(self, use_gpu=True, gpu_ids=[0]):
self.use_gpu = use_gpu
self.gpu_ids = gpu_ids
def forward(self):
pass
def get_image_paths(self):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
return self.input
def get_current_errors(self):
return {}
def save(self, label):
pass
# helper saving function that can be used by subclasses
def save_network(self, network, path, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(path, save_filename)
torch.save(network.state_dict(), save_path)
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
print('Loading network from %s'%save_path)
network.load_state_dict(torch.load(save_path))
def update_learning_rate():
pass
def get_image_paths(self):
return self.image_paths
def save_done(self, flag=False):
np.save(os.path.join(self.save_dir, 'done_flag'),flag)
np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
from __future__ import absolute_import
import sys
import numpy as np
import torch
from torch import nn
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
from .base_model import BaseModel
from scipy.ndimage import zoom
import fractions
import functools
import skimage.transform
from tqdm import tqdm
from IPython import embed
from . import networks_basic as networks
import lpips as util
class DistModel(BaseModel):
def name(self):
return self.model_name
def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
use_gpu=True, printNet=False, spatial=False,
is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
'''
INPUTS
model - ['net-lin'] for linearly calibrated network
['net'] for off-the-shelf network
['L2'] for L2 distance in Lab colorspace
['SSIM'] for ssim in RGB colorspace
net - ['squeeze','alex','vgg']
model_path - if None, will look in weights/[NET_NAME].pth
colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
use_gpu - bool - whether or not to use a GPU
printNet - bool - whether or not to print network architecture out
spatial - bool - whether to output an array containing varying distances across spatial dimensions
spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
is_train - bool - [True] for training mode
lr - float - initial learning rate
beta1 - float - initial momentum term for adam
version - 0.1 for latest, 0.0 was original (with a bug)
gpu_ids - int array - [0] by default, gpus to use
'''
BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
self.model = model
self.net = net
self.is_train = is_train
self.spatial = spatial
self.gpu_ids = gpu_ids
self.model_name = '%s [%s]'%(model,net)
if(self.model == 'net-lin'): # pretrained net + linear layer
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
use_dropout=True, spatial=spatial, version=version, lpips=True)
kw = {}
if not use_gpu:
kw['map_location'] = 'cpu'
if(model_path is None):
import inspect
model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
if(not is_train):
print('Loading model from: %s'%model_path)
self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
elif(self.model=='net'): # pretrained network
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
elif(self.model in ['L2','l2']):
self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
self.model_name = 'L2'
elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
self.model_name = 'SSIM'
else:
raise ValueError("Model [%s] not recognized." % self.model)
self.parameters = list(self.net.parameters())
if self.is_train: # training mode
# extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
self.rankLoss = networks.BCERankingLoss()
self.parameters += list(self.rankLoss.net.parameters())
self.lr = lr
self.old_lr = lr
self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
else: # test mode
self.net.eval()
if(use_gpu):
self.net.to(gpu_ids[0])
self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
if(self.is_train):
self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
if(printNet):
print('---------- Networks initialized -------------')
networks.print_network(self.net)
print('-----------------------------------------------')
def forward(self, in0, in1, retPerLayer=False):
''' Function computes the distance between image patches in0 and in1
INPUTS
in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
OUTPUT
computed distances between in0 and in1
'''
return self.net.forward(in0, in1, retPerLayer=retPerLayer)
# ***** TRAINING FUNCTIONS *****
def optimize_parameters(self):
self.forward_train()
self.optimizer_net.zero_grad()
self.backward_train()
self.optimizer_net.step()
self.clamp_weights()
def clamp_weights(self):
for module in self.net.modules():
if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
module.weight.data = torch.clamp(module.weight.data,min=0)
def set_input(self, data):
self.input_ref = data['ref']
self.input_p0 = data['p0']
self.input_p1 = data['p1']
self.input_judge = data['judge']
if(self.use_gpu):
self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
self.var_ref = Variable(self.input_ref,requires_grad=True)
self.var_p0 = Variable(self.input_p0,requires_grad=True)
self.var_p1 = Variable(self.input_p1,requires_grad=True)
def forward_train(self): # run forward pass
# print(self.net.module.scaling_layer.shift)
# print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
self.d0 = self.forward(self.var_ref, self.var_p0)
self.d1 = self.forward(self.var_ref, self.var_p1)
self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
return self.loss_total
def backward_train(self):
torch.mean(self.loss_total).backward()
def compute_accuracy(self,d0,d1,judge):
''' d0, d1 are Variables, judge is a Tensor '''
d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
judge_per = judge.cpu().numpy().flatten()
return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
def get_current_errors(self):
retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
('acc_r', self.acc_r)])
for key in retDict.keys():
retDict[key] = np.mean(retDict[key])
return retDict
def get_current_visuals(self):
zoom_factor = 256/self.var_ref.data.size()[2]
ref_img = util.tensor2im(self.var_ref.data)
p0_img = util.tensor2im(self.var_p0.data)
p1_img = util.tensor2im(self.var_p1.data)
ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
return OrderedDict([('ref', ref_img_vis),
('p0', p0_img_vis),
('p1', p1_img_vis)])
def save(self, path, label):
if(self.use_gpu):
self.save_network(self.net.module, path, '', label)
else:
self.save_network(self.net, path, '', label)
self.save_network(self.rankLoss.net, path, 'rank', label)
def update_learning_rate(self,nepoch_decay):
lrd = self.lr / nepoch_decay
lr = self.old_lr - lrd
for param_group in self.optimizer_net.param_groups:
param_group['lr'] = lr
print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
self.old_lr = lr
def score_2afc_dataset(data_loader, func, name=''):
''' Function computes Two Alternative Forced Choice (2AFC) score using
distance function 'func' in dataset 'data_loader'
INPUTS
data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
func - callable distance function - calling d=func(in0,in1) should take 2
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
OUTPUTS
[0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
[1] - dictionary with following elements
d0s,d1s - N arrays containing distances between reference patch to perturbed patches
gts - N array in [0,1], preferred patch selected by human evaluators
(closer to "0" for left patch p0, "1" for right patch p1,
"0.6" means 60pct people preferred right patch, 40pct preferred left)
scores - N array in [0,1], corresponding to what percentage function agreed with humans
CONSTS
N - number of test triplets in data_loader
'''
d0s = []
d1s = []
gts = []
for data in tqdm(data_loader.load_data(), desc=name):
d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
gts+=data['judge'].cpu().numpy().flatten().tolist()
d0s = np.array(d0s)
d1s = np.array(d1s)
gts = np.array(gts)
scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
def score_jnd_dataset(data_loader, func, name=''):
''' Function computes JND score using distance function 'func' in dataset 'data_loader'
INPUTS
data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
func - callable distance function - calling d=func(in0,in1) should take 2
pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
OUTPUTS
[0] - JND score in [0,1], mAP score (area under precision-recall curve)
[1] - dictionary with following elements
ds - N array containing distances between two patches shown to human evaluator
sames - N array containing fraction of people who thought the two patches were identical
CONSTS
N - number of test triplets in data_loader
'''
ds = []
gts = []
for data in tqdm(data_loader.load_data(), desc=name):
ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
gts+=data['same'].cpu().numpy().flatten().tolist()
sames = np.array(gts)
ds = np.array(ds)
sorted_inds = np.argsort(ds)
ds_sorted = ds[sorted_inds]
sames_sorted = sames[sorted_inds]
TPs = np.cumsum(sames_sorted)
FPs = np.cumsum(1-sames_sorted)
FNs = np.sum(sames_sorted)-TPs
precs = TPs/(TPs+FPs)
recs = TPs/(TPs+FNs)
score = util.voc_ap(recs,precs)
return(score, dict(ds=ds,sames=sames))
from __future__ import absolute_import
import sys
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from pdb import set_trace as st
from skimage import color
from IPython import embed
from . import pretrained_networks as pn
import lpips as util
def spatial_average(in_tens, keepdim=True):
return in_tens.mean([2,3],keepdim=keepdim)
def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
in_H = in_tens.shape[2]
scale_factor = 1.*out_H/in_H
return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
# Learned perceptual metric
class PNetLin(nn.Module):
def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
super(PNetLin, self).__init__()
self.pnet_type = pnet_type
self.pnet_tune = pnet_tune
self.pnet_rand = pnet_rand
self.spatial = spatial
self.lpips = lpips
self.version = version
self.scaling_layer = ScalingLayer()
if(self.pnet_type in ['vgg','vgg16']):
net_type = pn.vgg16
self.chns = [64,128,256,512,512]
elif(self.pnet_type=='alex'):
net_type = pn.alexnet
self.chns = [64,192,384,256,256]
elif(self.pnet_type=='squeeze'):
net_type = pn.squeezenet
self.chns = [64,128,256,384,384,512,512]
self.L = len(self.chns)
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
if(lpips):
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
self.lins+=[self.lin5,self.lin6]
def forward(self, in0, in1, retPerLayer=False):
# v0.0 - original release had a bug, where input was not scaled
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}
for kk in range(self.L):
feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk]-feats1[kk])**2
if(self.lpips):
if(self.spatial):
res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
else:
res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
else:
if(self.spatial):
res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
else:
res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
val = res[0]
for l in range(1,self.L):
val += res[l]
if(retPerLayer):
return (val, res)
else:
return val
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
''' A single linear layer which does a 1x1 conv '''
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = [nn.Dropout(),] if(use_dropout) else []
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
self.model = nn.Sequential(*layers)
class Dist2LogitLayer(nn.Module):
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
def __init__(self, chn_mid=32, use_sigmoid=True):
super(Dist2LogitLayer, self).__init__()
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
layers += [nn.LeakyReLU(0.2,True),]
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
layers += [nn.LeakyReLU(0.2,True),]
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
if(use_sigmoid):
layers += [nn.Sigmoid(),]
self.model = nn.Sequential(*layers)
def forward(self,d0,d1,eps=0.1):
return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
class BCERankingLoss(nn.Module):
def __init__(self, chn_mid=32):
super(BCERankingLoss, self).__init__()
self.net = Dist2LogitLayer(chn_mid=chn_mid)
# self.parameters = list(self.net.parameters())
self.loss = torch.nn.BCELoss()
def forward(self, d0, d1, judge):
per = (judge+1.)/2.
self.logit = self.net.forward(d0,d1)
return self.loss(self.logit, per)
# L2, DSSIM metrics
class FakeNet(nn.Module):
def __init__(self, use_gpu=True, colorspace='Lab'):
super(FakeNet, self).__init__()
self.use_gpu = use_gpu
self.colorspace=colorspace
class L2(FakeNet):
def forward(self, in0, in1, retPerLayer=None):
assert(in0.size()[0]==1) # currently only supports batchSize 1
if(self.colorspace=='RGB'):
(N,C,X,Y) = in0.size()
value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
return value
elif(self.colorspace=='Lab'):
value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
ret_var = Variable( torch.Tensor((value,) ) )
if(self.use_gpu):
ret_var = ret_var.cuda()
return ret_var
class DSSIM(FakeNet):
def forward(self, in0, in1, retPerLayer=None):
assert(in0.size()[0]==1) # currently only supports batchSize 1
if(self.colorspace=='RGB'):
value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
elif(self.colorspace=='Lab'):
value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
ret_var = Variable( torch.Tensor((value,) ) )
if(self.use_gpu):
ret_var = ret_var.cuda()
return ret_var
def print_network(net):
num_params = 0
for param in net.parameters():
num_params += param.numel()
print('Network',net)
print('Total number of parameters: %d' % num_params)
from collections import namedtuple
import torch
from torchvision import models as tv
from IPython import embed
class squeezenet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(squeezenet, self).__init__()
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.slice6 = torch.nn.Sequential()
self.slice7 = torch.nn.Sequential()
self.N_slices = 7
for x in range(2):
self.slice1.add_module(str(x), pretrained_features[x])
for x in range(2,5):
self.slice2.add_module(str(x), pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), pretrained_features[x])
for x in range(10, 11):
self.slice5.add_module(str(x), pretrained_features[x])
for x in range(11, 12):
self.slice6.add_module(str(x), pretrained_features[x])
for x in range(12, 13):
self.slice7.add_module(str(x), pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
h = self.slice6(h)
h_relu6 = h
h = self.slice7(h)
h_relu7 = h
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
return out
class alexnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(alexnet, self).__init__()
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(2):
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
for x in range(10, 12):
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
return out
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
class resnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True, num=18):
super(resnet, self).__init__()
if(num==18):
self.net = tv.resnet18(pretrained=pretrained)
elif(num==34):
self.net = tv.resnet34(pretrained=pretrained)
elif(num==50):
self.net = tv.resnet50(pretrained=pretrained)
elif(num==101):
self.net = tv.resnet101(pretrained=pretrained)
elif(num==152):
self.net = tv.resnet152(pretrained=pretrained)
self.N_slices = 5
self.conv1 = self.net.conv1
self.bn1 = self.net.bn1
self.relu = self.net.relu
self.maxpool = self.net.maxpool
self.layer1 = self.net.layer1
self.layer2 = self.net.layer2
self.layer3 = self.net.layer3
self.layer4 = self.net.layer4
def forward(self, X):
h = self.conv1(X)
h = self.bn1(h)
h = self.relu(h)
h_relu1 = h
h = self.maxpool(h)
h = self.layer1(h)
h_conv2 = h
h = self.layer2(h)
h_conv3 = h
h = self.layer3(h)
h_conv4 = h
h = self.layer4(h)
h_conv5 = h
outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
return out
import math
import random
import functools
import operator
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
class PixelNorm(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
def make_kernel(k):
k = torch.tensor(k, dtype=torch.float32)
if k.ndim == 1:
k = k[None, :] * k[:, None]
k /= k.sum()
return k
class Upsample(nn.Module):
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = make_kernel(kernel) * (factor ** 2)
self.register_buffer("kernel", kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, input):
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
return out
class Downsample(nn.Module):
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = make_kernel(kernel)
self.register_buffer("kernel", kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, input):
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
return out
class Blur(nn.Module):
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
kernel = make_kernel(kernel)
if upsample_factor > 1:
kernel = kernel * (upsample_factor ** 2)
self.register_buffer("kernel", kernel)
self.pad = pad
def forward(self, input):
out = upfirdn2d(input, self.kernel, pad=self.pad)
return out
class EqualConv2d(nn.Module):
def __init__(
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
):
super().__init__()
self.weight = nn.Parameter(
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
)
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
self.stride = stride
self.padding = padding
if bias:
self.bias = nn.Parameter(torch.zeros(out_channel))
else:
self.bias = None
def forward(self, input):
out = conv2d_gradfix.conv2d(
input,
self.weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)
return out
def __repr__(self):
return (
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
)
class EqualLinear(nn.Module):
def __init__(
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.activation:
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
else:
out = F.linear(
input, self.weight * self.scale, bias=self.bias * self.lr_mul
)
return out
def __repr__(self):
return (
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
)
class ModulatedConv2d(nn.Module):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
style_dim,
demodulate=True,
upsample=False,
downsample=False,
blur_kernel=[1, 3, 3, 1],
fused=True,
):
super().__init__()
self.eps = 1e-8
self.kernel_size = kernel_size
self.in_channel = in_channel
self.out_channel = out_channel
self.upsample = upsample
self.downsample = downsample
if upsample:
factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
fan_in = in_channel * kernel_size ** 2
self.scale = 1 / math.sqrt(fan_in)
self.padding = kernel_size // 2
self.weight = nn.Parameter(
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
)
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
self.demodulate = demodulate
self.fused = fused
def __repr__(self):
return (
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
f"upsample={self.upsample}, downsample={self.downsample})"
)
def forward(self, input, style):
batch, in_channel, height, width = input.shape
if not self.fused:
weight = self.scale * self.weight.squeeze(0)
style = self.modulation(style)
if self.demodulate:
w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
input = input * style.reshape(batch, in_channel, 1, 1)
if self.upsample:
weight = weight.transpose(0, 1)
out = conv2d_gradfix.conv_transpose2d(
input, weight, padding=0, stride=2
)
out = self.blur(out)
elif self.downsample:
input = self.blur(input)
out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
else:
out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
if self.demodulate:
out = out * dcoefs.view(batch, -1, 1, 1)
return out
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
weight = self.scale * self.weight * style
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
weight = weight.view(
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
)
if self.upsample:
input = input.view(1, batch * in_channel, height, width)
weight = weight.view(
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
)
weight = weight.transpose(1, 2).reshape(
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
)
out = conv2d_gradfix.conv_transpose2d(
input, weight, padding=0, stride=2, groups=batch
)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)
out = self.blur(out)
elif self.downsample:
input = self.blur(input)
_, _, height, width = input.shape
input = input.view(1, batch * in_channel, height, width)
out = conv2d_gradfix.conv2d(
input, weight, padding=0, stride=2, groups=batch
)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)
else:
input = input.view(1, batch * in_channel, height, width)
out = conv2d_gradfix.conv2d(
input, weight, padding=self.padding, groups=batch
)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)
return out
class NoiseInjection(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1))
def forward(self, image, noise=None):
if noise is None:
batch, _, height, width = image.shape
noise = image.new_empty(batch, 1, height, width).normal_()
return image + self.weight * noise
class ConstantInput(nn.Module):
def __init__(self, channel, size=4):
super().__init__()
self.input = nn.Parameter(torch.randn(1, channel, size, size))
def forward(self, input):
batch = input.shape[0]
out = self.input.repeat(batch, 1, 1, 1)
return out
class StyledConv(nn.Module):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
style_dim,
upsample=False,
blur_kernel=[1, 3, 3, 1],
demodulate=True,
):
super().__init__()
self.conv = ModulatedConv2d(
in_channel,
out_channel,
kernel_size,
style_dim,
upsample=upsample,
blur_kernel=blur_kernel,
demodulate=demodulate,
)
self.noise = NoiseInjection()
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
# self.activate = ScaledLeakyReLU(0.2)
self.activate = FusedLeakyReLU(out_channel)
def forward(self, input, style, noise=None):
out = self.conv(input, style)
out = self.noise(out, noise=noise)
# out = out + self.bias
out = self.activate(out)
return out
class ToRGB(nn.Module):
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
super().__init__()
if upsample:
self.upsample = Upsample(blur_kernel)
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
def forward(self, input, style, skip=None):
out = self.conv(input, style)
out = out + self.bias
if skip is not None:
skip = self.upsample(skip)
out = out + skip
return out
class Generator(nn.Module):
def __init__(
self,
size,
style_dim,
n_mlp,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
):
super().__init__()
self.size = size
self.style_dim = style_dim
layers = [PixelNorm()]
for i in range(n_mlp):
layers.append(
EqualLinear(
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
)
)
self.style = nn.Sequential(*layers)
self.channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
self.input = ConstantInput(self.channels[4])
self.conv1 = StyledConv(
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
)
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
self.log_size = int(math.log(size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
self.convs = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()
in_channel = self.channels[4]
for layer_idx in range(self.num_layers):
res = (layer_idx + 5) // 2
shape = [1, 1, 2 ** res, 2 ** res]
self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
for i in range(3, self.log_size + 1):
out_channel = self.channels[2 ** i]
self.convs.append(
StyledConv(
in_channel,
out_channel,
3,
style_dim,
upsample=True,
blur_kernel=blur_kernel,
)
)
self.convs.append(
StyledConv(
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
)
)
self.to_rgbs.append(ToRGB(out_channel, style_dim))
in_channel = out_channel
self.n_latent = self.log_size * 2 - 2
def make_noise(self):
device = self.input.input.device
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
return noises
def mean_latent(self, n_latent):
latent_in = torch.randn(
n_latent, self.style_dim, device=self.input.input.device
)
latent = self.style(latent_in).mean(0, keepdim=True)
return latent
def get_latent(self, input):
return self.style(input)
def forward(
self,
styles,
return_latents=False,
inject_index=None,
truncation=1,
truncation_latent=None,
input_is_latent=False,
noise=None,
randomize_noise=True,
):
if not input_is_latent:
styles = [self.style(s) for s in styles]
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers
else:
noise = [
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
]
if truncation < 1:
style_t = []
for style in styles:
style_t.append(
truncation_latent + truncation * (style - truncation_latent)
)
styles = style_t
if len(styles) < 2:
inject_index = self.n_latent
if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else:
latent = styles[0]
else:
if inject_index is None:
inject_index = random.randint(1, self.n_latent - 1)
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
latent = torch.cat([latent, latent2], 1)
out = self.input(latent)
out = self.conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
class ConvLayer(nn.Sequential):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
activate=True,
):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
layers.append(
EqualConv2d(
in_channel,
out_channel,
kernel_size,
padding=self.padding,
stride=stride,
bias=bias and not activate,
)
)
if activate:
layers.append(FusedLeakyReLU(out_channel, bias=bias))
super().__init__(*layers)
class ResBlock(nn.Module):
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
super().__init__()
self.conv1 = ConvLayer(in_channel, in_channel, 3)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
self.skip = ConvLayer(
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
)
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
skip = self.skip(input)
out = (out + skip) / math.sqrt(2)
return out
class Discriminator(nn.Module):
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
super().__init__()
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
convs = [ConvLayer(3, channels[size], 1)]
log_size = int(math.log(size, 2))
in_channel = channels[size]
for i in range(log_size, 2, -1):
out_channel = channels[2 ** (i - 1)]
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
in_channel = out_channel
self.convs = nn.Sequential(*convs)
self.stddev_group = 4
self.stddev_feat = 1
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
self.final_linear = nn.Sequential(
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
EqualLinear(channels[4], 1),
)
def forward(self, input):
out = self.convs(input)
batch, channel, height, width = out.shape
group = min(batch, self.stddev_group)
stddev = out.view(
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
)
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
stddev = stddev.repeat(group, 1, height, width)
out = torch.cat([out, stddev], 1)
out = self.final_conv(out)
out = out.view(batch, -1)
out = self.final_linear(out)
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment