Commit 25008cf0 authored by dongchy920's avatar dongchy920
Browse files

test

parents
Pipeline #1767 canceled with stages
import math
import torch
from torch import autograd
from torch.nn import functional as F
import numpy as np
from distributed import reduce_sum
from op import upfirdn2d
class AdaptiveAugment:
    """Tunes the augmentation probability so that the fraction of positive
    discriminator predictions on real images tracks `ada_aug_target`."""

    def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
        self.ada_aug_target = ada_aug_target
        self.ada_aug_len = ada_aug_len
        self.update_every = update_every

        self.ada_update = 0
        # Accumulator: [sum of sign(real_pred), number of predictions seen].
        self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
        self.r_t_stat = 0
        self.ada_aug_p = 0

    @torch.no_grad()
    def tune(self, real_pred):
        """Accumulate stats for one batch; every `update_every` calls, nudge
        the augmentation probability and return its current value."""
        batch_stats = torch.tensor(
            (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
            device=real_pred.device,
        )
        self.ada_aug_buf += batch_stats
        self.ada_update += 1

        if self.ada_update % self.update_every == 0:
            self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
            pred_signs, n_pred = self.ada_aug_buf.tolist()

            self.r_t_stat = pred_signs / n_pred
            step = 1 if self.r_t_stat > self.ada_aug_target else -1
            self.ada_aug_p = min(
                1, max(0, self.ada_aug_p + step * n_pred / self.ada_aug_len)
            )

            self.ada_aug_buf.mul_(0)
            self.ada_update = 0

        return self.ada_aug_p
# 12-tap FIR low-pass filter used for anti-aliasing in the affine augmentation
# pipeline below. The name suggests these are the Symlet-6 wavelet scaling
# coefficients — TODO(review): confirm, e.g. against pywt.Wavelet("sym6").
SYM6 = (
    0.015404109327027373,
    0.0034907120842174702,
    -0.11799011114819057,
    -0.048311742585633,
    0.4910559419267466,
    0.787641141030194,
    0.3379294217276218,
    -0.07263752278646252,
    -0.021060292512300564,
    0.04472490177066578,
    0.0017677118642428036,
    -0.007800708325034148,
)
def translate_mat(t_x, t_y, device="cpu"):
    """Batch of 3x3 homogeneous 2D translation matrices from per-sample offsets."""
    batch = t_x.shape[0]
    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
    mat[:, 0, 2] = t_x
    mat[:, 1, 2] = t_y
    return mat
def rotate_mat(theta, device="cpu"):
    """Batch of 3x3 homogeneous 2D rotation matrices for angles `theta`."""
    batch = theta.shape[0]
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
    mat[:, :2, :2] = rot
    return mat
def scale_mat(s_x, s_y, device="cpu"):
    """Batch of 3x3 homogeneous 2D scaling matrices diag(s_x, s_y, 1)."""
    batch = s_x.shape[0]
    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
    mat[:, (0, 1), (0, 1)] = torch.stack((s_x, s_y), 1)
    return mat
def translate3d_mat(t_x, t_y, t_z):
    """Batch of 4x4 homogeneous 3D translation matrices."""
    batch = t_x.shape[0]
    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    for row, offset in enumerate((t_x, t_y, t_z)):
        mat[:, row, 3] = offset
    return mat
def rotate3d_mat(axis, theta):
    """Batch of 4x4 matrices rotating by angles `theta` about the 3-vector
    `axis`, via the Rodrigues-style decomposition cos*I + sin*K + (1-cos)*aa^T."""
    batch = theta.shape[0]
    u_x, u_y, u_z = axis

    identity3 = torch.eye(3).unsqueeze(0)
    # Skew-symmetric cross-product matrix of the axis.
    skew = torch.tensor(
        [(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]
    ).unsqueeze(0)
    # Outer product axis axis^T.
    axis_vec = torch.tensor(axis)
    outer = (axis_vec.unsqueeze(1) * axis_vec).unsqueeze(0)

    cos_t = torch.cos(theta).view(-1, 1, 1)
    sin_t = torch.sin(theta).view(-1, 1, 1)
    rot3 = cos_t * identity3 + sin_t * skew + (1 - cos_t) * outer

    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    mat[:, :3, :3] = rot3
    return mat
def scale3d_mat(s_x, s_y, s_z):
    """Batch of 4x4 homogeneous 3D scaling matrices diag(s_x, s_y, s_z, 1)."""
    batch = s_x.shape[0]
    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    for i, factor in enumerate((s_x, s_y, s_z)):
        mat[:, i, i] = factor
    return mat
def luma_flip_mat(axis, i):
    """Batch of 4x4 color matrices flipping about `axis`, scaled per-sample by `i`
    (i=0 gives the identity)."""
    batch = i.shape[0]
    eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    axis4 = torch.tensor(axis + (0,))
    outer = torch.ger(axis4, axis4)
    return eye - 2 * outer * i.view(-1, 1, 1)
def saturation_mat(axis, i):
    """Batch of 4x4 color matrices interpolating between the projection onto
    `axis` (i=0) and the identity (i=1), per sample."""
    batch = i.shape[0]
    eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    axis4 = torch.tensor(axis + (0,))
    proj = torch.ger(axis4, axis4)
    return proj + (eye - proj) * i.view(-1, 1, 1)
def lognormal_sample(size, mean=0, std=1, device="cpu"):
    """Draw `size` log-normal samples (mean/std refer to the underlying normal)."""
    out = torch.empty(size, device=device)
    return out.log_normal_(mean=mean, std=std)
def category_sample(size, categories, device="cpu"):
    """Draw `size` values uniformly at random from the tuple `categories`."""
    choices = torch.tensor(categories, device=device)
    idx = torch.randint(high=len(categories), size=(size,), device=device)
    return choices[idx]
def uniform_sample(size, low, high, device="cpu"):
    """Draw samples uniformly from [low, high)."""
    out = torch.empty(size, device=device)
    out.uniform_(low, high)
    return out
def normal_sample(size, mean=0, std=1, device="cpu"):
    """Draw samples from a normal distribution with the given mean and std."""
    out = torch.empty(size, device=device)
    out.normal_(mean, std)
    return out
def bernoulli_sample(size, p, device="cpu"):
    """Draw `size` Bernoulli(p) samples as 0./1. floats."""
    out = torch.empty(size, device=device)
    out.bernoulli_(p)
    return out
def random_mat_apply(p, transform, prev, eye, device="cpu"):
    """Per sample, left-multiply `prev` by `transform` with probability p,
    otherwise by `eye` (i.e. leave it unchanged)."""
    size = transform.shape[0]
    keep = bernoulli_sample(size, p, device=device).view(size, 1, 1)
    chosen = keep * transform + (1 - keep) * eye
    return chosen @ prev
def sample_affine(p, size, height, width, device="cpu"):
    """Sample a batch of `size` random 3x3 affine matrices.

    Each sub-transform (flip, 90-degree rotation, translation, scaling,
    rotation) is applied independently with probability `p` (rotations use
    p_rot so that the combined pre+post rotation probability is p).

    NOTE(review): most sampling helpers below are called without
    `device=device`, so their parameters are created on the CPU while `G`
    lives on `device` — verify before running this on a non-CPU device.
    """
    G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
    eye = G
    # flip
    param = category_sample(size, (0, 1))
    Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
    # 90 rotate
    param = category_sample(size, (0, 3))
    Gc = rotate_mat(-math.pi / 2 * param, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
    # integer translate
    param = uniform_sample((2, size), -0.125, 0.125)
    param_height = torch.round(param[0] * height)
    param_width = torch.round(param[1] * width)
    Gc = translate_mat(param_width, param_height, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
    # isotropic scale
    param = lognormal_sample(size, std=0.2 * math.log(2))
    Gc = scale_mat(param, param, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('isotropic scale', G, scale_mat(param, param), sep='\n')
    # Probability so that applying pre- and post-rotation independently with
    # p_rot leaves overall rotation probability at p.
    p_rot = 1 - math.sqrt(1 - p)
    # pre-rotate
    param = uniform_sample(size, -math.pi, math.pi)
    Gc = rotate_mat(-param, device=device)
    G = random_mat_apply(p_rot, Gc, G, eye, device=device)
    # print('pre-rotate', G, rotate_mat(-param), sep='\n')
    # anisotropic scale
    param = lognormal_sample(size, std=0.2 * math.log(2))
    Gc = scale_mat(param, 1 / param, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
    # post-rotate
    param = uniform_sample(size, -math.pi, math.pi)
    Gc = rotate_mat(-param, device=device)
    G = random_mat_apply(p_rot, Gc, G, eye, device=device)
    # print('post-rotate', G, rotate_mat(-param), sep='\n')
    # fractional translate
    param = normal_sample((2, size), std=0.125)
    Gc = translate_mat(param[1] * width, param[0] * height, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('fractional translate', G, translate_mat(param, param), sep='\n')
    return G
def sample_color(p, size):
    """Sample a batch of `size` random 4x4 color-transform matrices
    (homogeneous RGB), each sub-transform applied with probability `p`."""
    C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
    eye = C
    axis_val = 1 / math.sqrt(3)
    # Unit vector along the gray (R=G=B) diagonal.
    axis = (axis_val, axis_val, axis_val)
    # brightness
    param = normal_sample(size, std=0.2)
    Cc = translate3d_mat(param, param, param)
    C = random_mat_apply(p, Cc, C, eye)
    # contrast
    param = lognormal_sample(size, std=0.5 * math.log(2))
    Cc = scale3d_mat(param, param, param)
    C = random_mat_apply(p, Cc, C, eye)
    # luma flip
    param = category_sample(size, (0, 1))
    Cc = luma_flip_mat(axis, param)
    C = random_mat_apply(p, Cc, C, eye)
    # hue rotation
    param = uniform_sample(size, -math.pi, math.pi)
    Cc = rotate3d_mat(axis, param)
    C = random_mat_apply(p, Cc, C, eye)
    # saturation
    param = lognormal_sample(size, std=1 * math.log(2))
    Cc = saturation_mat(axis, param)
    C = random_mat_apply(p, Cc, C, eye)
    return C
def make_grid(shape, x0, x1, y0, y1, device):
    """Homogeneous (x, y, 1) sampling grid of shape (n, h, w, 3), with x
    spanning [x0, x1] across the width and y spanning [y0, y1] down the height."""
    n, c, h, w = shape
    xs = torch.linspace(x0, x1, w, device=device)
    ys = torch.linspace(y0, y1, h, device=device)
    grid = torch.empty(n, h, w, 3, device=device)
    grid[..., 0] = xs
    grid[..., 1] = ys.unsqueeze(-1)
    grid[..., 2] = 1
    return grid
def affine_grid(grid, mat):
    """Apply batched affine matrices `mat` (n, 2, 3) to a homogeneous grid
    (n, h, w, 3), returning transformed (x, y) coordinates (n, h, w, 2)."""
    n, h, w, _ = grid.shape
    flat = grid.view(n, h * w, 3)
    mapped = flat @ mat.transpose(1, 2)
    return mapped.view(n, h, w, 2)
def get_padding(G, height, width, kernel_size):
    """Per-side padding (x1, x2, y1, y2) so the image corners mapped through G
    remain covered, with extra margin for the antialiasing FIR kernel."""
    device = G.device

    cx = (width - 1) / 2
    cy = (height - 1) / 2
    # Homogeneous coordinates of the four image corners, centered at the origin.
    cp = torch.tensor(
        [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
    )
    cp = G @ cp.T

    pad_k = kernel_size // 4

    pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
    # Largest excursion of any transformed corner on each side.
    pad = torch.cat((-pad, pad)).max(1).values
    pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
    # Clamp each side to [0, size - 1].
    pad = pad.max(torch.tensor([0, 0] * 2, device=device))
    pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))
    pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)

    return pad_x1, pad_x2, pad_y1, pad_y2
def try_sample_affine_and_pad(img, p, kernel_size, G=None):
    """Sample an inverse affine matrix (unless `G` is given) and reflect-pad
    `img` enough to cover the warp; returns (padded image, G, padding)."""
    batch, _, height, width = img.shape

    if G is None:
        G_try = torch.inverse(sample_affine(p, batch, height, width))
    else:
        G_try = G

    pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
    img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")

    return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
class GridSampleForward(autograd.Function):
    """Bilinear grid_sample wrapped as a Function so its backward pass is
    itself differentiable (via GridSampleBackward)."""

    @staticmethod
    def forward(ctx, input, grid):
        ctx.save_for_backward(input, grid)
        return F.grid_sample(
            input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
        )

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        # Returns (grad_input, grad_grid).
        return GridSampleBackward.apply(grad_output, input, grid)
class GridSampleBackward(autograd.Function):
    """Backward of grid_sample, exposed as a Function so it is itself
    differentiable (needed for e.g. gradient-penalty terms)."""

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        # NOTE(review): relies on an internal torch op; its signature has
        # changed across torch releases (an output_mask argument was added
        # later) — verify when upgrading torch.
        op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)

        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad_grad_input, grad_grad_grid):
        (grid,) = ctx.saved_tensors
        grad_grad_output = None

        if ctx.needs_input_grad[0]:
            grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)

        # No higher-order gradients for input or grid.
        return grad_grad_output, None, None
# Convenience alias: grid_sample with a double-differentiable backward.
grid_sample = GridSampleForward.apply
def scale_mat_single(s_x, s_y):
    """Single (unbatched) 3x3 homogeneous scale matrix as float32."""
    mat = torch.eye(3, dtype=torch.float32)
    mat[0, 0] = s_x
    mat[1, 1] = s_y
    return mat
def translate_mat_single(t_x, t_y):
    """Single (unbatched) 3x3 homogeneous translation matrix as float32."""
    mat = torch.eye(3, dtype=torch.float32)
    mat[0, 2] = t_x
    mat[1, 2] = t_y
    return mat
def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
    """Warp `img` by a random affine transform (sampled per-sub-transform with
    probability `p` unless `G` is given), antialiased by 2x up/downsampling
    with a separable FIR kernel. Returns (augmented image, transform G)."""
    kernel = antialiasing_kernel
    len_k = len(kernel)

    kernel = torch.as_tensor(kernel).to(img)
    # kernel = torch.ger(kernel, kernel).to(img)
    kernel_flip = torch.flip(kernel, (0,))

    img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
        img, p, len_k, G
    )

    # Re-center the transform to account for the asymmetric padding.
    G_inv = (
        translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
        @ G
    )
    up_pad = (
        (len_k + 2 - 1) // 2,
        (len_k - 2) // 2,
        (len_k + 2 - 1) // 2,
        (len_k - 2) // 2,
    )
    # Separable 2x upsampling: rows first, then columns.
    img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
    img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))

    # Express the transform in the coordinates of the upsampled, padded image.
    G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
    G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
    batch_size, channel, height, width = img.shape
    pad_k = len_k // 4
    shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
    # Rescale into the normalized [-1, 1] coordinates used by affine_grid.
    G_inv = (
        scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
        @ G_inv
        @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
    )
    grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
    img_affine = grid_sample(img_2x, grid)

    # Separable 2x downsampling back to the original resolution (negative
    # padding crops the earlier margin away).
    d_p = -pad_k * 2
    down_pad = (
        d_p + (len_k - 2 + 1) // 2,
        d_p + (len_k - 2) // 2,
        d_p + (len_k - 2 + 1) // 2,
        d_p + (len_k - 2) // 2,
    )
    img_down = upfirdn2d(
        img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
    )
    img_down = upfirdn2d(
        img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
    )

    return img_down, G
def apply_color(img, mat):
    """Apply a batch of 4x4 affine color transforms to an NCHW image: each
    pixel's channel vector becomes M[:3, :3] @ rgb + M[:3, 3]."""
    batch = img.shape[0]
    nhwc = img.permute(0, 2, 3, 1)
    linear = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
    offset = mat[:, :3, 3].view(batch, 1, 1, 3)
    nhwc = nhwc @ linear + offset
    return nhwc.permute(0, 3, 1, 2)
def random_apply_color(img, p, C=None):
    """Apply a color transform `C` to `img`, sampling one (with per-transform
    probability `p`) when none is given; returns (image, C)."""
    if C is None:
        C = sample_color(p, img.shape[0])

    return apply_color(img, C.to(img)), C
def augment(img, p, transform_matrix=(None, None)):
    """Geometric then color augmentation with probability `p`; pass
    `transform_matrix=(G, C)` to reuse fixed transforms. Returns (img, (G, C))."""
    G_fixed, C_fixed = transform_matrix
    img, G = random_apply_affine(img, p, G_fixed)
    img, C = random_apply_color(img, p, C_fixed)
    return img, (G, C)
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d
import contextlib
import warnings
import torch
from torch import autograd
from torch.nn import functional as F
# Global switch for the custom gradfix convolution path; when False every call
# falls back to the stock torch.nn.functional convolutions.
enabled = True
# When True (see no_weight_gradients), Conv2d.backward skips the weight grad.
weight_gradients_disabled = False
@contextlib.contextmanager
def no_weight_gradients():
    """Context manager that temporarily disables weight-gradient computation.

    Sets the module-level `weight_gradients_disabled` flag for the duration of
    the `with` block and always restores the previous value.
    """
    global weight_gradients_disabled

    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        # Bug fix: restore on the exception path too — without try/finally the
        # flag stayed True forever if the with-body raised.
        weight_gradients_disabled = old
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Drop-in replacement for F.conv2d that routes through the gradfix
    implementation whenever could_use_op allows it."""
    if not could_use_op(input):
        return F.conv2d(
            input=input,
            weight=weight,
            bias=bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )

    op = conv2d_gradfix(
        transpose=False,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=0,
        dilation=dilation,
        groups=groups,
    )
    return op.apply(input, weight, bias)
def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    """Drop-in replacement for F.conv_transpose2d that routes through the
    gradfix implementation whenever could_use_op allows it."""
    if not could_use_op(input):
        return F.conv_transpose2d(
            input=input,
            weight=weight,
            bias=bias,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
        )

    op = conv2d_gradfix(
        transpose=True,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )
    return op.apply(input, weight, bias)
def could_use_op(input):
    """Whether the custom gradfix op can handle `input`: requires the global
    switch, cuDNN, a CUDA tensor, and a supported torch version (1.7/1.8)."""
    if not enabled or not torch.backends.cudnn.enabled:
        return False

    if input.device.type != "cuda":
        return False

    if torch.__version__.startswith(("1.7.", "1.8.")):
        return True

    warnings.warn(
        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
    )

    return False
def ensure_tuple(xs, ndim):
    """Return `xs` as a tuple, repeating a scalar value `ndim` times."""
    if isinstance(xs, (tuple, list)):
        return tuple(xs)
    return (xs,) * ndim
# Memoizes the autograd.Function classes built by conv2d_gradfix, keyed by the
# full convolution configuration.
conv2d_gradfix_cache = dict()
def conv2d_gradfix(
    transpose, weight_shape, stride, padding, output_padding, dilation, groups
):
    """Build (and cache) a custom autograd.Function for one conv configuration.

    The returned Conv2d class performs a (transposed) convolution whose
    backward pass is expressed with the same machinery, which makes second
    derivatives available and lets `weight_gradients_disabled` skip the
    weight gradient entirely.
    """
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = ensure_tuple(stride, ndim)
    padding = ensure_tuple(padding, ndim)
    output_padding = ensure_tuple(output_padding, ndim)
    dilation = ensure_tuple(dilation, ndim)

    # One Function class per unique configuration.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in conv2d_gradfix_cache:
        return conv2d_gradfix_cache[key]

    common_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups
    )

    def calc_output_padding(input_shape, output_shape):
        # Output padding needed so the opposite-direction convolution used in
        # backward reproduces the exact input spatial size.
        if transpose:
            return [0, 0]

        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    class Conv2d(autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            if not transpose:
                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )

            ctx.save_for_backward(input, weight)

            return out

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input, grad_weight, grad_bias = None, None, None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                # Gradient w.r.t. input is the opposite-direction convolution.
                grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, weight, None)

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)

            if ctx.needs_input_grad[2]:
                # Bias gradient: sum over batch and spatial dims.
                grad_bias = grad_output.sum((0, 2, 3))

            return grad_input, grad_weight, grad_bias

    class Conv2dGradWeight(autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            # NOTE(review): internal cuDNN op names; consistent with
            # could_use_op restricting this path to torch 1.7/1.8 — these ops
            # were removed in later torch releases.
            op = torch._C._jit_get_operation(
                "aten::cudnn_convolution_backward_weight"
                if not transpose
                else "aten::cudnn_convolution_transpose_backward_weight"
            )
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            grad_weight = op(
                weight_shape,
                grad_output,
                input,
                padding,
                stride,
                dilation,
                groups,
                *flags,
            )
            ctx.save_for_backward(grad_output, input)

            return grad_weight

        @staticmethod
        def backward(ctx, grad_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_grad_output, grad_grad_input = None, None

            if ctx.needs_input_grad[0]:
                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, grad_grad_weight, None)

            return grad_grad_output, grad_grad_input

    conv2d_gradfix_cache[key] = Conv2d

    return Conv2d
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
# JIT-compile the fused bias + activation CUDA extension at import time.
fused = load(
    "fused",
    sources=[
        os.path.join(module_path, "fused_bias_act.cpp"),
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
    ],
)
class FusedLeakyReLUFunctionBackward(Function):
    """Backward pass of the fused bias + leaky-ReLU op, itself a Function so
    second-order gradients are available. act=3/grad=1 selects the kernel's
    leaky-ReLU derivative branch; `out` is the saved forward output."""

    @staticmethod
    def forward(ctx, grad_output, out, bias, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        # An empty tensor doubles as a "no bias" placeholder for the kernel.
        empty = grad_output.new_empty(0)

        grad_input = fused.fused_bias_act(
            grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
        )

        # Reduce over every dim except channel (dim 1) to get the bias grad.
        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        if bias:
            grad_bias = grad_input.sum(dim).detach()

        else:
            grad_bias = empty

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input.contiguous(),
            gradgrad_bias,
            out,
            3,
            1,
            ctx.negative_slope,
            ctx.scale,
        )

        return gradgrad_out, None, None, None, None
class FusedLeakyReLUFunction(Function):
    """Fused (input + bias) -> leaky ReLU -> * scale via the compiled kernel."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)

        # Remember whether a real bias was supplied so backward can return
        # None for its gradient slot.
        ctx.bias = bias is not None

        if bias is None:
            bias = empty

        # act=3 (leaky ReLU), grad=0 (forward) in the kernel's encoding.
        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
        )

        if not ctx.bias:
            grad_bias = None

        return grad_input, grad_bias, None, None
class FusedLeakyReLU(nn.Module):
    """Leaky ReLU with an optional learned per-channel bias and output scaling,
    dispatched through fused_leaky_relu."""

    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        # Zero-initialized per-channel bias, or no bias at all.
        self.bias = nn.Parameter(torch.zeros(channel)) if bias else None
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    """Compute leaky_relu(input + bias) * scale.

    On CPU this uses stock torch ops (bias broadcast over the channel dim);
    on other devices it dispatches to the fused CUDA kernel.
    """
    if input.device.type == "cpu":
        if bias is not None:
            # Broadcast bias over dim 1 (channels).
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            input = input + bias.view(1, bias.shape[0], *rest_dim)

        # Bug fix: the CPU path previously hard-coded negative_slope=0.2,
        # silently ignoring the caller's value.
        return F.leaky_relu(input, negative_slope=negative_slope) * scale

    return FusedLeakyReLUFunction.apply(
        input.contiguous(), bias, negative_slope, scale
    )
#include <ATen/ATen.h>
#include <torch/extension.h>

// Implemented in fused_bias_act_kernel.cu.
torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale);

// Argument validation: tensors must live on CUDA and be contiguous.
#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

// Python-facing entry point: validates input/bias, pins the device, and
// forwards to the CUDA implementation. `refer` may be an empty tensor
// (it is not validated here and is treated as "unused" by the kernel).
torch::Tensor fused_bias_act(const torch::Tensor &input,
                             const torch::Tensor &bias,
                             const torch::Tensor &refer, int act, int grad,
                             float alpha, float scale) {
  CHECK_INPUT(input);
  CHECK_INPUT(bias);

  at::DeviceGuard guard(input.device());

  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
\ No newline at end of file
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Elementwise fused bias + activation kernel.
// The switch key is act * 10 + grad: act=1 is identity, act=3 is leaky ReLU;
// grad=0 computes the forward value, grad=1 the first derivative factor,
// grad=2 the second derivative (zero for these piecewise-linear activations).
// For derivative passes, p_ref supplies the saved forward output whose sign
// selects the active branch.
template <typename scalar_t>
static __global__ void
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
                      const scalar_t *p_ref, int act, int grad, scalar_t alpha,
                      scalar_t scale, int loop_x, int size_x, int step_b,
                      int size_b, int use_bias, int use_ref) {
  // Each thread handles up to loop_x elements, strided by blockDim.x.
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      // (xi / step_b) % size_b maps a flat index to its channel.
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    switch (act * 10 + grad) {
    default:
    case 10:
      y = x;
      break;
    case 11:
      y = x;
      break;
    case 12:
      y = 0.0;
      break;

    // Leaky ReLU forward.
    case 30:
      y = (x > 0.0) ? x : x * alpha;
      break;
    // First derivative: slope chosen by the sign of the saved forward output.
    case 31:
      y = (ref > 0.0) ? x : x * alpha;
      break;
    // Second derivative of a piecewise-linear activation is zero.
    case 32:
      y = 0.0;
      break;
    }

    out[xi] = y * scale;
  }
}
// Host launcher: prepares contiguous buffers and runs fused_bias_act_kernel
// on the current CUDA stream. Empty bias/refer tensors mean "not used".
torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  // step_b = product of dims after dim 1, so that (flat_index / step_b)
  // % size_b recovers the channel index in the kernel.
  int step_b = 1;

  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  // Launch geometry: 128-thread blocks, 4 elements per thread.
  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
      });

  return y;
}
\ No newline at end of file
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

// NOTE(review): auto-generated ROCm/HIP port of fused_bias_act_kernel.cu.
// Make logic changes in the CUDA source and re-run hipify rather than editing
// this file; the duplicated hip_runtime include below is a hipify artifact.

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/hip/HIPApplyUtils.cuh>
#include <ATen/hip/HIPContext.h>

#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>

// Elementwise fused bias + activation kernel; see the CUDA source for the
// act/grad encoding of the switch below.
template <typename scalar_t>
static __global__ void
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
                      const scalar_t *p_ref, int act, int grad, scalar_t alpha,
                      scalar_t scale, int loop_x, int size_x, int step_b,
                      int size_b, int use_bias, int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    switch (act * 10 + grad) {
    default:
    case 10:
      y = x;
      break;
    case 11:
      y = x;
      break;
    case 12:
      y = 0.0;
      break;

    case 30:
      y = (x > 0.0) ? x : x * alpha;
      break;
    case 31:
      y = (ref > 0.0) ? x : x * alpha;
      break;
    case 32:
      y = 0.0;
      break;
    }

    out[xi] = y * scale;
  }
}

// Host launcher on the current HIP stream; mirrors the CUDA version.
torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale) {
  int curDevice = -1;
  hipGetDevice(&curDevice);
  hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        hipLaunchKernelGGL(( fused_bias_act_kernel<scalar_t>), dim3(grid_size), dim3(block_size), 0, stream,
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
      });

  return y;
}
\ No newline at end of file
#include <ATen/ATen.h>
#include <torch/extension.h>

// Implemented in upfirdn2d_kernel.cu.
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1);

// Argument validation: tensors must live on CUDA and be contiguous.
#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

// Python-facing entry point: validates inputs, pins the device, and forwards
// to the CUDA implementation.
torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
                        int up_x, int up_y, int down_x, int down_y, int pad_x0,
                        int pad_x1, int pad_y0, int pad_y1) {
  CHECK_INPUT(input);
  CHECK_INPUT(kernel);

  at::DeviceGuard guard(input.device());

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                      pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
\ No newline at end of file
from collections import abc
import os
import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
# JIT-compile the upfirdn2d CUDA extension at import time.
upfirdn2d_op = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)
class UpFirDn2dBackward(Function):
    """Backward pass of UpFirDn2d: the gradient of upfirdn is another upfirdn
    with up/down factors swapped, the flipped kernel, and the complementary
    padding g_pad."""

    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):
        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # Kernel layout: (N*C, H, W, 1).
        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        # grad_kernel is the spatially flipped kernel; up/down are swapped.
        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        # Stash the original forward configuration for the double backward.
        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        # Double backward simply re-applies the original forward upfirdn.
        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
    """Autograd wrapper around the compiled upfirdn2d CUDA op
    (upsample -> FIR filter -> downsample on an NCHW tensor)."""

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        # The kernel expects layout (N*C, H, W, 1).
        input = input.reshape(-1, in_h, in_w, 1)

        # Save the kernel and its spatial flip for the backward pass.
        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        # Complementary padding so the gradient upfirdn reproduces the input
        # spatial size exactly.
        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = UpFirDn2dBackward.apply(
                grad_output,
                kernel,
                grad_kernel,
                ctx.up,
                ctx.down,
                ctx.pad,
                ctx.g_pad,
                ctx.in_size,
                ctx.out_size,
            )

        return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, FIR-filter, and downsample an NCHW tensor.

    Scalar `up`/`down` apply to both axes; a 2-tuple `pad` is mirrored to all
    four sides. Dispatches to the native implementation on CPU and to the
    compiled CUDA op otherwise.
    """
    if not isinstance(up, abc.Iterable):
        up = (up, up)
    if not isinstance(down, abc.Iterable):
        down = (down, down)
    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    if input.device.type == "cpu":
        return upfirdn2d_native(input, kernel, *up, *down, *pad)
    return UpFirDn2d.apply(input, kernel, up, down, pad)
def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Pure-PyTorch reference implementation: upsample by zero-insertion, pad,
    filter with the kernel, then downsample by striding."""
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Zero-insertion upsampling: interleave (up - 1) zeros after each sample.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Apply non-negative padding; negative padding becomes cropping below.
    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    # conv2d computes correlation, so correlate with the flipped kernel to get
    # a true convolution with `kernel`.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Downsample by striding.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Integer division rounding toward negative infinity (C's `/` truncates
// toward zero), needed for correct index mapping with negative offsets.
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}
// Launch-time configuration shared by the upfirdn2d kernels.
struct UpFirDn2DKernelParams {
  int up_x;   // upsampling factors per axis
  int up_y;
  int down_x; // downsampling factors per axis
  int down_y;
  int pad_x0; // padding before/after on each axis
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim; // flattened leading dims (batch etc.) of the (major, h, w, minor) layout
  int in_h;
  int in_w;
  int minor_dim; // flattened trailing dims
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major; // elements processed per thread along major
  int loop_x;     // output columns processed per thread
};
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
int out_y = minor_idx / p.minor_dim;
minor_idx -= out_y * p.minor_dim;
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
int major_idx_base = blockIdx.z * p.loop_major;
if (out_x_base >= p.out_w || out_y >= p.out_h ||
major_idx_base >= p.major_dim) {
return;
}
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major && major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, out_x = out_x_base;
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
const scalar_t *x_p =
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
minor_idx];
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
int x_px = p.minor_dim;
int k_px = -p.up_x;
int x_py = p.in_w * p.minor_dim;
int k_py = -p.up_y * p.kernel_w;
scalar_t v = 0.0f;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
x_p += x_px;
k_p += k_px;
}
x_p += x_py - w * x_px;
k_p += k_py - w * k_px;
}
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
// Tiled upfirdn2d kernel, specialized at compile time for small up/down
// factors and kernel sizes. Each block stages the (flipped) kernel and a
// tile of the input in shared memory, then threads cooperatively compute
// one tile_out_h x tile_out_w output tile. Bitwise `|` / `&` are used on
// comparison results (0/1 values) in place of `||` / `&&`, which is
// branch-free but otherwise equivalent here.
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  // Input tile extent required to produce one output tile.
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  // blockIdx.x encodes both the tile row and the minor index.
  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  // Load the kernel flipped in both axes, zero-padding up to the
  // compile-time kernel_h x kernel_w when the runtime kernel is smaller.
  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      // Top-left corner of the needed input tile.
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      // Stage the input tile in shared memory, zero for out-of-bounds.
      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();

      // Each thread computes one or more output pixels of the tile.
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        // Starting tap in the flipped kernel for this output phase.
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

// Only every up-th kernel tap hits a real sample, hence the up-strided walk.
#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}
// Host entry point: upsample-pad-FIR-downsample over a 4-D
// [major, in_h, in_w, minor] tensor. Selects a compile-time-specialized
// tiled kernel for the common configurations, otherwise falls back to
// the generic upfirdn2d_kernel_large.
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  UpFirDn2DKernelParams p;

  // Kernels assume dense layouts.
  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  // Output size of the up -> pad -> filter -> down pipeline.
  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  // Mode selection: later (more specific) matches overwrite earlier ones;
  // mode stays -1 for configurations with no specialized template.
  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    // Tiled path: 256 threads cooperate on one output tile per block.
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    // Generic path: one thread per output element, looping over x.
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      // NOTE(review): mode 6 reuses the 4x4-tap template even though the
      // runtime kernel is <= 2x2; the shared-memory loader zero-pads the
      // extra taps, so the result is unchanged.
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
\ No newline at end of file
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/hip/HIPApplyUtils.cuh>
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
// Integer division rounding toward negative infinity (C++ `/` truncates
// toward zero, which is wrong for negative numerators in the index math).
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int q = a / b;
  return (q * b > a) ? q - 1 : q;
}
// Parameter bundle passed by value to the upfirdn2d kernels (HIP port).
// The input tensor is viewed as [major_dim, in_h, in_w, minor_dim].
struct UpFirDn2DKernelParams {
  int up_x;       // upsampling factor along x
  int up_y;       // upsampling factor along y
  int down_x;     // downsampling factor along x
  int down_y;     // downsampling factor along y
  int pad_x0;     // leading pad along x
  int pad_x1;     // trailing pad along x
  int pad_y0;     // leading pad along y
  int pad_y1;     // trailing pad along y

  int major_dim;  // flattened leading dimensions
  int in_h;       // input height
  int in_w;       // input width
  int minor_dim;  // flattened trailing dimensions
  int kernel_h;   // FIR kernel height
  int kernel_w;   // FIR kernel width
  int out_h;      // output height
  int out_w;      // output width

  int loop_major; // major slices processed per block (grid-z loop count)
  int loop_x;     // x tiles/columns processed per block (grid-y loop count)
};
// Generic (non-templated) upfirdn2d kernel (HIP port): conceptually
// upsample by (up_x, up_y), pad, convolve with `kernel`, then downsample
// by (down_x, down_y). One thread computes one output element; it loops
// over p.loop_major major slices and p.loop_x column strides.
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  // blockIdx.x/threadIdx.x jointly encode (out_y, minor_idx).
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  // Map the output row to the contributing input-row window
  // [in_y, in_y + h) and the starting tap within the kernel.
  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      // Same row mapping, applied to the output column.
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];

      // Pointer strides: the input advances by minor_dim per column; the
      // kernel pointer walks backwards with stride up_x / up_y because only
      // every up-th tap lines up with a real (non-inserted-zero) sample.
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      // Accumulate the h x w window of input samples against the kernel.
      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }
        // Rewind the row advance, then step to the next row.
        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}
// Tiled upfirdn2d kernel (HIP port), specialized at compile time for
// small up/down factors and kernel sizes. Each block stages the (flipped)
// kernel and an input tile in shared memory, then threads cooperatively
// compute one tile_out_h x tile_out_w output tile. Bitwise `|` / `&` are
// used on comparison results (0/1 values) in place of `||` / `&&` —
// branch-free but equivalent here.
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  // Input tile extent required to produce one output tile.
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  // blockIdx.x encodes both the tile row and the minor index.
  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  // Load the kernel flipped in both axes, zero-padding up to the
  // compile-time kernel_h x kernel_w when the runtime kernel is smaller.
  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      // Top-left corner of the needed input tile.
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      // Stage the input tile in shared memory, zero for out-of-bounds.
      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();

      // Each thread computes one or more output pixels of the tile.
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        // Starting tap in the flipped kernel for this output phase.
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

// Only every up-th kernel tap hits a real sample, hence the up-strided walk.
#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}
// Host entry point (HIP port): upsample-pad-FIR-downsample over a 4-D
// [major, in_h, in_w, minor] tensor. Selects a compile-time-specialized
// tiled kernel for the common configurations, otherwise falls back to
// the generic upfirdn2d_kernel_large.
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  hipGetDevice(&curDevice);
  hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();

  UpFirDn2DKernelParams p;

  // Kernels assume dense layouts.
  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  // Output size of the up -> pad -> filter -> down pipeline.
  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  // Mode selection: later (more specific) matches overwrite earlier ones;
  // mode stays -1 for configurations with no specialized template.
  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    // Tiled path: 256 threads cooperate on one output tile per block.
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    // Generic path: one thread per output element, looping over x.
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      hipLaunchKernelGGL(( upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>)
          , dim3(grid_size), dim3(block_size), 0, stream, out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      hipLaunchKernelGGL(( upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>)
          , dim3(grid_size), dim3(block_size), 0, stream, out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      hipLaunchKernelGGL(( upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>)
          , dim3(grid_size), dim3(block_size), 0, stream, out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      hipLaunchKernelGGL(( upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>)
          , dim3(grid_size), dim3(block_size), 0, stream, out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      hipLaunchKernelGGL(( upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>)
          , dim3(grid_size), dim3(block_size), 0, stream, out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      // NOTE(review): mode 6 reuses the 4x4-tap template even though the
      // runtime kernel is <= 2x2; the shared-memory loader zero-pads the
      // extra taps, so the result is unchanged.
      hipLaunchKernelGGL(( upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>)
          , dim3(grid_size), dim3(block_size), 0, stream, out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      hipLaunchKernelGGL(( upfirdn2d_kernel_large<scalar_t>), dim3(grid_size), dim3(block_size), 0, stream,
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
\ No newline at end of file
import argparse
import torch
from torch.nn import functional as F
import numpy as np
from tqdm import tqdm
import lpips
from model import Generator
def normalize(x):
    """L2-normalize *x* along its last dimension."""
    norm = torch.sqrt(x.pow(2).sum(-1, keepdim=True))
    return x / norm
def slerp(a, b, t):
    """Spherical linear interpolation between *a* and *b* at ratio *t*.

    Both endpoints are projected onto the unit sphere first; the result
    is re-normalized, so the output always has unit L2 norm.
    """

    def _unit(v):
        # Normalize along the last dimension.
        return v / torch.sqrt(v.pow(2).sum(-1, keepdim=True))

    a = _unit(a)
    b = _unit(b)
    cos_ab = (a * b).sum(-1, keepdim=True)
    angle = t * torch.acos(cos_ab)
    # Component of b orthogonal to a, normalized.
    ortho = _unit(b - cos_ab * a)
    out = a * torch.cos(angle) + ortho * torch.sin(angle)
    return _unit(out)
def lerp(a, b, t):
    """Linearly interpolate from *a* toward *b* by fraction *t*."""
    delta = b - a
    return a + delta * t
if __name__ == "__main__":
    # Perceptual Path Length: average LPIPS distance between images
    # generated from latent pairs eps apart, scaled by 1/eps^2.
    device = "cuda"

    parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")

    parser.add_argument(
        "--space", choices=["z", "w"], help="space that PPL calculated with"
    )
    parser.add_argument(
        "--batch", type=int, default=64, help="batch size for the models"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=5000,
        help="number of the samples for calculating PPL",
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image sizes of the generator"
    )
    parser.add_argument(
        "--eps", type=float, default=1e-4, help="epsilon for numerical stability"
    )
    parser.add_argument(
        "--crop", action="store_true", help="apply center crop to the images"
    )
    parser.add_argument(
        "--sampling",
        default="end",
        choices=["end", "full"],
        help="set endpoint sampling method",
    )
    parser.add_argument(
        "ckpt", metavar="CHECKPOINT", help="path to the model checkpoints"
    )

    args = parser.parse_args()

    latent_dim = 512

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, latent_dim, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g.eval()

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    distances = []

    n_batch = args.n_sample // args.batch
    resid = args.n_sample - (n_batch * args.batch)
    # Fix: only schedule the remainder batch when it is non-empty. The old
    # code unconditionally appended `resid`, so when n_sample divided evenly
    # a zero-sized batch was pushed through the generator and LPIPS.
    batch_sizes = [args.batch] * n_batch
    if resid > 0:
        batch_sizes.append(resid)

    with torch.no_grad():
        for batch in tqdm(batch_sizes):
            noise = g.make_noise()

            # Two latents per sample: an interpolation point and the same
            # point nudged by eps; PPL ~ d(img(t), img(t + eps)) / eps^2.
            inputs = torch.randn([batch * 2, latent_dim], device=device)
            if args.sampling == "full":
                lerp_t = torch.rand(batch, device=device)
            else:
                lerp_t = torch.zeros(batch, device=device)

            # NOTE(review): only --space w is handled; with --space z,
            # `latent_e` is never assigned and the call below raises
            # NameError. Left as-is pending a decision on z-space sampling.
            if args.space == "w":
                latent = g.get_latent(inputs)
                latent_t0, latent_t1 = latent[::2], latent[1::2]
                latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)

            image, _ = g([latent_e], input_is_latent=True, noise=noise)

            if args.crop:
                c = image.shape[2] // 8
                image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]

            # LPIPS is evaluated at 256px.
            factor = image.shape[2] // 256

            if factor > 1:
                image = F.interpolate(
                    image, size=(256, 256), mode="bilinear", align_corners=False
                )

            dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
                args.eps ** 2
            )
            distances.append(dist.to("cpu").numpy())

    distances = np.concatenate(distances, 0)

    # Discard the extreme 1% tails before averaging, per the StyleGAN
    # PPL protocol.
    lo = np.percentile(distances, 1, interpolation="lower")
    hi = np.percentile(distances, 99, interpolation="higher")

    filtered_dist = np.extract(
        np.logical_and(lo <= distances, distances <= hi), distances
    )

    print("ppl:", filtered_dist.mean())
import argparse
from io import BytesIO
import multiprocessing
from functools import partial
from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn
def resize_and_convert(img, size, resample, quality=100):
    """Resize then center-crop *img* to *size* and return JPEG bytes."""
    resized = trans_fn.resize(img, size, resample)
    cropped = trans_fn.center_crop(resized, size)

    buf = BytesIO()
    cropped.save(buf, format="jpeg", quality=quality)
    return buf.getvalue()
def resize_multiple(
    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
    """Encode *img* once per resolution in *sizes*; returns JPEG byte strings."""
    return [resize_and_convert(img, size, resample, quality) for size in sizes]
def resize_worker(img_file, sizes, resample):
    """Pool worker: load one (index, path) pair and resize to every size."""
    idx, path = img_file
    img = Image.open(path).convert("RGB")
    return idx, resize_multiple(img, sizes=sizes, resample=resample)
def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    """Resize every image in *dataset* to each resolution in *sizes* and
    store the JPEG bytes in the lmdb *env* under keys
    ``"<size>-<zero-padded index>"``; a final ``"length"`` key records the
    number of source images processed.

    env: an open lmdb environment.
    dataset: a torchvision ImageFolder (``.imgs`` is a (path, label) list).
    n_worker: size of the multiprocessing pool doing the resizing.
    """
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)

    # Sort by path so the indices assigned below are deterministic even
    # though pool results arrive out of order.
    files = sorted(dataset.imgs, key=lambda x: x[0])
    files = [(i, file) for i, (file, label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        # total= gives tqdm a real progress bar for the unordered stream.
        for i, imgs in tqdm(pool.imap_unordered(resize_fn, files), total=len(files)):
            # One write transaction per source image (the old code opened
            # a transaction per resized copy, multiplying commit overhead).
            with env.begin(write=True) as txn:
                for size, img in zip(sizes, imgs):
                    key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
                    txn.put(key, img)

            total += 1

    with env.begin(write=True) as txn:
        txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
if __name__ == "__main__":
    # CLI: resize an ImageFolder dataset to several resolutions and pack
    # the JPEGs into a single lmdb database.
    parser = argparse.ArgumentParser(description="Preprocess images for model training")
    parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
    parser.add_argument(
        "--size",
        type=str,
        default="128,256,512,1024",
        help="resolutions of images for the dataset",
    )
    parser.add_argument(
        "--n_worker",
        type=int,
        default=8,
        help="number of workers for preparing dataset",
    )
    parser.add_argument(
        "--resample",
        type=str,
        default="lanczos",
        help="resampling methods for resizing images",
    )
    parser.add_argument("path", type=str, help="path to the image dataset")

    args = parser.parse_args()

    # Map the CLI name to the PIL resampling filter.
    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
    resample = resample_map[args.resample]

    # "--size 128,256" -> [128, 256]
    sizes = [int(s.strip()) for s in args.size.split(",")]

    print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    imgset = datasets.ImageFolder(args.path)

    # map_size is an upper bound on the database size (1 TiB of address
    # space), not an upfront allocation.
    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
import argparse
import math
import os
import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import lpips
from model import Generator
def noise_regularize(noises):
    """Penalize spatial autocorrelation in each noise map.

    For every noise tensor the squared mean of its product with a
    one-pixel shift (along both spatial axes) is accumulated over a
    pyramid of 2x average-pooled versions, stopping once the map is
    8x8 or smaller. Returns the scalar sum over all maps and levels.
    """
    loss = 0

    for noise in noises:
        size = noise.shape[2]

        while True:
            shifted_w = torch.roll(noise, shifts=1, dims=3)
            shifted_h = torch.roll(noise, shifts=1, dims=2)
            loss = (
                loss
                + (noise * shifted_w).mean().pow(2)
                + (noise * shifted_h).mean().pow(2)
            )

            if size <= 8:
                break

            # 2x average-pool by folding each 2x2 patch and averaging.
            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2]).mean([3, 5])
            size //= 2

    return loss
def noise_normalize_(noises):
    """Standardize each noise tensor in place to zero mean and unit std."""
    for noise in noises:
        center = noise.mean()
        scale = noise.std()
        noise.data.add_(-center).div_(scale)
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Learning-rate schedule: linear warmup then half-cosine decay.

    t: training progress in [0, 1]. The rate warms up linearly over the
    first *rampup* fraction and decays with a half-cosine over the final
    *rampdown* fraction; in between it equals *initial_lr*.
    """
    decay = min(1, (1 - t) / rampdown)
    decay = 0.5 - 0.5 * math.cos(decay * math.pi)
    warmup = min(1, t / rampup)
    return initial_lr * decay * warmup
def latent_noise(latent, strength):
    """Return *latent* perturbed by Gaussian noise scaled by *strength*."""
    perturbation = torch.randn_like(latent) * strength
    return latent + perturbation
def make_image(tensor):
    """Convert a batch of [-1, 1] NCHW float tensors to NHWC uint8 numpy.

    Note: uses in-place clamp_ on the detached view, so the caller's
    tensor data is clamped as a side effect (matches existing usage).
    """
    out = tensor.detach().clamp_(min=-1, max=1)
    out = out.add(1).div_(2).mul(255)
    out = out.type(torch.uint8).permute(0, 2, 3, 1)
    return out.to("cpu").numpy()
if __name__ == "__main__":
    # Project real images into the generator's latent (W / W+) space by
    # optimizing a latent code and per-layer noise maps against LPIPS
    # (and optionally MSE) reconstruction losses.
    device = "cuda"

    parser = argparse.ArgumentParser(
        description="Image projector to the generator latent spaces"
    )
    parser.add_argument(
        "--ckpt", type=str, required=True, help="path to the model checkpoint"
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image sizes of the generator"
    )
    parser.add_argument(
        "--lr_rampup",
        type=float,
        default=0.05,
        help="duration of the learning rate warmup",
    )
    parser.add_argument(
        "--lr_rampdown",
        type=float,
        default=0.25,
        help="duration of the learning rate decay",
    )
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    parser.add_argument(
        "--noise", type=float, default=0.05, help="strength of the noise level"
    )
    parser.add_argument(
        "--noise_ramp",
        type=float,
        default=0.75,
        help="duration of the noise level decay",
    )
    parser.add_argument("--step", type=int, default=1000, help="optimize iterations")
    parser.add_argument(
        "--noise_regularize",
        type=float,
        default=1e5,
        help="weight of the noise regularization",
    )
    parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss")
    parser.add_argument(
        "--w_plus",
        action="store_true",
        help="allow to use distinct latent codes to each layers",
    )
    parser.add_argument(
        "files", metavar="FILES", nargs="+", help="path to image files to be projected"
    )

    args = parser.parse_args()

    n_mean_latent = 10000

    # Load the targets, normalized to [-1, 1], at no more than 256px.
    resize = min(args.size, 256)

    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    imgs = []

    for imgfile in args.files:
        img = transform(Image.open(imgfile).convert("RGB"))
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    # Estimate the W-space mean and spread from random z samples; the
    # spread scales the exploration noise injected during optimization.
    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    # Start every image at the mean latent; with --w_plus, one code per layer.
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        # Ramped-down exploration noise added to the latent.
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        # Average-pool the render down to 256px before the LPIPS loss.
        if height > 256:
            factor = height // 256

            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )

    # Fix: with --step < 100 no snapshot was ever taken and the
    # `latent_path[-1]` below raised IndexError; always keep the final
    # latent so the run can complete.
    if not latent_path:
        latent_path.append(latent_in.detach().clone())

    img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)

    filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"

    img_ar = make_image(img_gen)

    result_file = {}
    for i, input_name in enumerate(args.files):
        noise_single = []
        for noise in noises:
            noise_single.append(noise[i : i + 1])

        result_file[input_name] = {
            "img": img_gen[i],
            "latent": latent_in[i],
            "noise": noise_single,
        }

        img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
        pil_img = Image.fromarray(img_ar[i])
        pil_img.save(img_name)

    torch.save(result_file, filename)
lmdb
\ No newline at end of file
import math
import random
import functools
import operator
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
from model import (
ModulatedConv2d,
StyledConv,
ConstantInput,
PixelNorm,
Upsample,
Downsample,
Blur,
EqualLinear,
ConvLayer,
)
def get_haar_wavelet(in_channels):
    """Return the four 2x2 Haar filters (LL, LH, HL, HH) as 2-D tensors.

    *in_channels* is accepted for interface compatibility but unused.
    """
    lo = 1 / (2 ** 0.5) * torch.ones(1, 2)
    hi = lo.clone()
    hi[0, 0] = -hi[0, 0]

    ll = lo.T * lo
    lh = hi.T * lo
    hl = lo.T * hi
    hh = hi.T * hi
    return ll, lh, hl, hh
def dwt_init(x):
    """One level of the 2-D Haar DWT on an NCHW tensor.

    Returns the four half-resolution sub-bands concatenated along the
    channel axis in (LL, HL, LH, HH) order: channels grow 4x while
    height and width halve.
    """
    even_rows = x[:, :, 0::2, :] / 2
    odd_rows = x[:, :, 1::2, :] / 2

    a = even_rows[:, :, :, 0::2]  # even row, even col
    b = odd_rows[:, :, :, 0::2]  # odd row, even col
    c = even_rows[:, :, :, 1::2]  # even row, odd col
    d = odd_rows[:, :, :, 1::2]  # odd row, odd col

    ll = a + b + c + d
    hl = -a - b + c + d
    lh = -a + b - c + d
    hh = a - b - c + d
    return torch.cat((ll, hl, lh, hh), 1)
def iwt_init(x):
    """Inverse of dwt_init: reassemble an NCHW tensor from four Haar
    sub-bands stored as consecutive channel groups (LL, HL, LH, HH).

    Channels shrink 4x while height and width double.
    """
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    out_batch, out_channel, out_height, out_width = (
        in_batch,
        int(in_channel / (r ** 2)),
        r * in_height,
        r * in_width,
    )
    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel : out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2 : out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3 : out_channel * 4, :, :] / 2

    # Fix: allocate on the input's device/dtype. The previous hard-coded
    # `.float().cuda()` crashed for CPU tensors and mixed devices when the
    # input already lived on CUDA non-default devices.
    h = torch.zeros(
        [out_batch, out_channel, out_height, out_width],
        dtype=x.dtype,
        device=x.device,
    )

    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

    return h
class HaarTransform(nn.Module):
    """Forward 2-D Haar wavelet transform implemented with upfirdn2d.

    Splits an image into four half-resolution sub-bands concatenated
    along the channel axis in (LL, LH, HL, HH) order.
    """

    def __init__(self, in_channels):
        super().__init__()

        filters = get_haar_wavelet(in_channels)
        # Buffers so the filters follow the module across devices/dtypes.
        for name, filt in zip(("ll", "lh", "hl", "hh"), filters):
            self.register_buffer(name, filt)

    def forward(self, input):
        bands = [
            upfirdn2d(input, filt, down=2)
            for filt in (self.ll, self.lh, self.hl, self.hh)
        ]
        return torch.cat(bands, 1)
class InverseHaarTransform(nn.Module):
    """Inverse 2-D Haar transform implemented with upfirdn2d.

    The LH/HL filters are stored negated relative to get_haar_wavelet,
    mirroring the sign convention of the forward transform's output.
    """

    def __init__(self, in_channels):
        super().__init__()

        ll, lh, hl, hh = get_haar_wavelet(in_channels)

        self.register_buffer("ll", ll)
        self.register_buffer("lh", -lh)
        self.register_buffer("hl", -hl)
        self.register_buffer("hh", hh)

    def forward(self, input):
        # Split the four sub-bands back out of the channel axis.
        ll, lh, hl, hh = input.chunk(4, 1)
        upsampled = (
            upfirdn2d(band, filt, up=2, pad=(1, 0, 1, 0))
            for band, filt in zip(
                (ll, lh, hl, hh), (self.ll, self.lh, self.hl, self.hh)
            )
        )
        return sum(upsampled)
class ToRGB(nn.Module):
    """Project features to the 12-channel wavelet image (4 Haar bands x RGB).

    With *upsample* set, an incoming skip is converted back to pixel
    space, upsampled, re-transformed to the wavelet domain, and added.
    """

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.iwt = InverseHaarTransform(3)
            self.upsample = Upsample(blur_kernel)
            self.dwt = HaarTransform(3)

        # 1x1 modulated conv without demodulation, plus a learned bias.
        self.conv = ModulatedConv2d(in_channel, 3 * 4, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3 * 4, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style) + self.bias

        if skip is not None:
            skip = self.dwt(self.upsample(self.iwt(skip)))
            out = out + skip

        return out
class Generator(nn.Module):
    """SWAGAN generator: a StyleGAN2-style synthesis network that operates in
    the Haar wavelet domain (each ToRGB emits 3 * 4 = 12 wavelet channels) and
    converts the final output back to pixel space with an inverse transform.

    Args:
        size: target image resolution (a power of two).
        style_dim: dimensionality of the style / latent vectors.
        n_mlp: number of EqualLinear layers in the mapping network.
        channel_multiplier: width factor for resolutions >= 64.
        blur_kernel: FIR kernel for the resampling filters.
        lr_mlp: learning-rate multiplier for the mapping network layers.
    """

    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()
        self.size = size
        self.style_dim = style_dim
        # Mapping network: PixelNorm followed by n_mlp equalized linear layers.
        layers = [PixelNorm()]
        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
                )
            )
        self.style = nn.Sequential(*layers)
        # Feature channel count per resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }
        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
        # One less than log2(size): the wavelet transform halves the spatial
        # resolution, so synthesis runs at half the pixel resolution.
        self.log_size = int(math.log(size, 2)) - 1
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()
        in_channel = self.channels[4]
        # Fixed per-layer noise buffers, used when randomize_noise is False.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
        # Per resolution: one upsampling StyledConv, one plain StyledConv,
        # and one ToRGB contributing to the skip image.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]
            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )
            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )
            self.to_rgbs.append(ToRGB(out_channel, style_dim))
            in_channel = out_channel
        # Final conversion from wavelet bands back to RGB pixels.
        self.iwt = InverseHaarTransform(3)
        self.n_latent = self.log_size * 2 - 2
    def make_noise(self):
        """Return a fresh list of per-layer noise tensors on the model's device."""
        device = self.input.input.device
        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
        return noises
    def mean_latent(self, n_latent):
        """Estimate the mean W-space latent from n_latent random Z samples
        (used as the truncation-trick target)."""
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)
        return latent
    def get_latent(self, input):
        """Map a Z-space code through the mapping network to W space."""
        return self.style(input)
    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Synthesize images.

        Args:
            styles: list of latent codes (Z space, or W space when
                input_is_latent is True).
            return_latents: also return the per-layer latents used.
            inject_index: layer at which style mixing switches codes.
            truncation / truncation_latent: truncation-trick interpolation
                toward the mean latent (applied when truncation < 1).
            noise: explicit per-layer noise list; when None, fresh noise is
                sampled (randomize_noise) or the stored buffers are used.

        Returns:
            (image, latent) where latent is None unless return_latents.
        """
        # Map Z -> W unless latents were supplied directly.
        if not input_is_latent:
            styles = [self.style(s) for s in styles]
        if noise is None:
            if randomize_noise:
                # None entries make each StyledConv sample its own noise.
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]
        # Truncation trick: interpolate each style toward the mean latent.
        if truncation < 1:
            style_t = []
            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )
            styles = style_t
        # Broadcast a single style to all layers, or mix two at inject_index.
        if len(styles) < 2:
            inject_index = self.n_latent
            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
            latent = torch.cat([latent, latent2], 1)
        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])
        # Walk the resolutions: the two convs consume two latents and two
        # noise inputs; ToRGB consumes a third latent and updates the skip.
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2
        # Convert the accumulated wavelet-domain skip back to pixel space.
        image = self.iwt(skip)
        if return_latents:
            return image, latent
        else:
            return image, None
class ConvBlock(nn.Module):
    """Discriminator block: a 3x3 conv at the current resolution followed by
    a downsampling 3x3 conv.

    NOTE: `blur_kernel` is accepted for signature compatibility but is not
    forwarded to the ConvLayer instances here.
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

    def forward(self, input):
        return self.conv2(self.conv1(input))
class FromRGB(nn.Module):
    """Converts a wavelet-domain RGB image (3 * 4 channels) into feature maps
    and optionally fuses them with the features from the previous scale.

    Returns both the (possibly downsampled) image and the feature maps so the
    caller can feed the image into the next FromRGB.
    """

    def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()
        # NOTE: `self.downsample` is first bound to the bool flag and then,
        # when the flag is True, re-bound to a Downsample module.  The
        # truthiness test in forward() works for both: nn.Module instances
        # are truthy and False is falsy.
        self.downsample = downsample
        if downsample:
            self.iwt = InverseHaarTransform(3)
            self.downsample = Downsample(blur_kernel)
            self.dwt = HaarTransform(3)
        self.conv = ConvLayer(3 * 4, out_channel, 3)
    def forward(self, input, skip=None):
        if self.downsample:
            # Downsample in pixel space, then return to wavelet space.
            input = self.iwt(input)
            input = self.downsample(input)
            input = self.dwt(input)
        out = self.conv(input)
        if skip is not None:
            out = out + skip
        return input, out
class Discriminator(nn.Module):
    """SWAGAN discriminator: consumes wavelet-domain images through a chain
    of FromRGB skip inputs and ConvBlocks, appends a minibatch standard
    deviation channel, and produces one realism score per sample.
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()
        # Feature channel count per resolution.
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }
        self.dwt = HaarTransform(3)
        self.from_rgbs = nn.ModuleList()
        self.convs = nn.ModuleList()
        # One less than log2(size): the wavelet transform halves resolution.
        log_size = int(math.log(size, 2)) - 1
        in_channel = channels[size]
        # One FromRGB + ConvBlock per scale, from high to low resolution.
        # Only the first FromRGB (i == log_size) skips its own downsampling.
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            self.from_rgbs.append(FromRGB(in_channel, downsample=i != log_size))
            self.convs.append(ConvBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel
        self.from_rgbs.append(FromRGB(channels[4]))
        # Minibatch standard-deviation settings (see forward()).
        self.stddev_group = 4
        self.stddev_feat = 1
        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )
    def forward(self, input):
        input = self.dwt(input)
        out = None
        # zip stops at len(self.convs); the extra final FromRGB is applied
        # separately below.
        for from_rgb, conv in zip(self.from_rgbs, self.convs):
            input, out = from_rgb(input, out)
            out = conv(out)
        _, out = self.from_rgbs[-1](input, out)
        batch, channel, height, width = out.shape
        # Minibatch stddev: append one statistic channel so the network can
        # detect low diversity across the batch.
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        # NOTE(review): `keepdims` is the NumPy-style alias; older PyTorch
        # releases accept only `keepdim` — confirm the targeted torch version.
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)
        out = self.final_conv(out)
        out = out.view(batch, -1)
        out = self.final_linear(out)
        return out
import argparse
import math
import random
import os
import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm
try:
import wandb
except ImportError:
wandb = None
from dataset import MultiResolutionDataset
from distributed import (
get_rank,
synchronize,
reduce_loss_dict,
reduce_sum,
get_world_size,
)
from op import conv2d_gradfix
from non_leaking import augment, AdaptiveAugment
def data_sampler(dataset, shuffle, distributed):
    """Return the torch sampler matching the run configuration:
    DistributedSampler for multi-process runs, otherwise a random or
    sequential sampler depending on `shuffle`."""
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    return data.SequentialSampler(dataset)
def requires_grad(model, flag=True):
    """Enable or disable gradient tracking on every parameter of `model`."""
    for parameter in model.parameters():
        parameter.requires_grad_(flag)
def accumulate(model1, model2, decay=0.999):
    """Exponential moving average of parameters:
    model1 <- decay * model1 + (1 - decay) * model2.

    Only named parameters are updated; buffers are left untouched.
    """
    target = dict(model1.named_parameters())
    source = dict(model2.named_parameters())

    for name, param in target.items():
        param.data.mul_(decay).add_(source[name].data, alpha=1 - decay)
def sample_data(loader):
    """Yield batches from `loader` forever, restarting it at each epoch.

    Fix: the original looped forever even when `loader` yielded nothing,
    which degenerates into an infinite busy loop (e.g. an empty dataset
    combined with drop_last=True).  Now an empty pass raises instead.

    Raises:
        ValueError: if a full pass over `loader` produces no batches.
    """
    while True:
        produced = False

        for batch in loader:
            produced = True
            yield batch

        if not produced:
            raise ValueError("sample_data: loader produced no batches")
def d_logistic_loss(real_pred, fake_pred):
    """Logistic (non-saturating) discriminator loss: push real scores up
    and fake scores down, averaged over the batch."""
    loss_real = F.softplus(-real_pred).mean()
    loss_fake = F.softplus(fake_pred).mean()

    return loss_real + loss_fake
def d_r1_loss(real_pred, real_img):
    """R1 gradient penalty: mean squared gradient norm of the discriminator
    output with respect to the real images (weight gradients disabled for
    speed via conv2d_gradfix)."""
    with conv2d_gradfix.no_weight_gradients():
        (grad_real,) = autograd.grad(
            outputs=real_pred.sum(), inputs=real_img, create_graph=True
        )

    batch = grad_real.shape[0]
    return grad_real.pow(2).reshape(batch, -1).sum(1).mean()
def g_nonsaturating_loss(fake_pred):
    """Non-saturating generator loss: softplus(-D(G(z))), batch-averaged."""
    return F.softplus(-fake_pred).mean()
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length regularization (StyleGAN2).

    Perturbs the image with resolution-scaled noise, measures the Jacobian
    norm with respect to the latents, and penalizes its deviation from an
    exponential moving average of observed path lengths.

    Returns:
        (path_penalty, updated_mean_path_length.detach(), path_lengths)
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)

    (grad,) = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    # Exponential moving average of the observed path length.
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths
def make_noise(batch, latent_dim, n_noise, device):
    """Draw latent noise.

    Returns a single (batch, latent_dim) tensor when n_noise == 1, otherwise
    a tuple of n_noise such tensors.
    """
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)

    return torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
def mixing_noise(batch, latent_dim, prob, device):
    """With probability `prob`, return two latent codes for style mixing;
    otherwise return a single latent code wrapped in a list."""
    use_mixing = prob > 0 and random.random() < prob

    if use_mixing:
        return make_noise(batch, latent_dim, 2, device)

    return [make_noise(batch, latent_dim, 1, device)]
def set_grad_none(model, targets):
    """Clear the .grad of every named parameter whose name is in `targets`."""
    for name, param in model.named_parameters():
        if name in targets:
            param.grad = None
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    """Main adversarial training loop.

    Alternates discriminator and generator updates, applies lazy R1 and
    path-length regularization every d_reg_every / g_reg_every steps,
    optionally drives adaptive discriminator augmentation (ADA), keeps an
    EMA copy of the generator in g_ema, and (on rank 0 only) logs to the
    progress bar / wandb, saves sample grids every 100 steps, and writes
    checkpoints every 10000 steps.
    """
    loader = sample_data(loader)
    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    # Unwrap DDP so state_dict/EMA operate on the bare modules.
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator
    # EMA decay per step (~0.9978); constant from the reference implementation.
    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0
    # Adaptive ADA controller only when no fixed augment probability is given.
    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)
    # Fixed latents so the periodic sample grids are comparable across steps.
    sample_z = torch.randn(args.n_sample, args.latent, device=device)
    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break
        real_img = next(loader)
        real_img = real_img.to(device)
        # ---- Discriminator step ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img
        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)
        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()
        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()
        # Adapt the augmentation probability from the real-score sign stats.
        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat
        # ---- Lazy R1 regularization (every d_reg_every steps) ----
        d_regularize = i % args.d_reg_every == 0
        if d_regularize:
            real_img.requires_grad = True
            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)
            else:
                real_img_aug = real_img
            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)
            discriminator.zero_grad()
            # "+ 0 * real_pred[0]" presumably ties D's output into the graph
            # so DDP sees all parameters used (standard StyleGAN2 trick).
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
            d_optim.step()
        loss_dict["r1"] = r1_loss
        # ---- Generator step ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = g_loss
        generator.zero_grad()
        g_loss.backward()
        g_optim.step()
        # ---- Lazy path-length regularization (every g_reg_every steps) ----
        g_regularize = i % args.g_reg_every == 0
        if g_regularize:
            # Smaller batch for the path penalty to save memory.
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )
            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
            # Dummy zero-weight pixel term; same graph-connection trick as above.
            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
            weighted_path_loss.backward()
            g_optim.step()
            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )
        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()
        # EMA update of the generator weights.
        accumulate(g_ema, g_module, accum)
        # ---- Logging / sampling / checkpointing (rank 0 only) ----
        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()
        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                    f"augment: {ada_aug_p:.4f}"
                )
            )
            if wandb and args.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )
            if i % 100 == 0:
                # Sample a fixed-latent grid from the EMA generator.
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        value_range=(-1, 1),
                    )
            if i % 10000 == 0:
                # Full checkpoint: models, EMA, both optimizers, args, ADA state.
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
if __name__ == "__main__":
    # Entry point: parse arguments, build models/optimizers/data pipeline
    # (optionally distributed), resume from a checkpoint, and train.
    device = "cuda"
    # ---- Command-line arguments ----
    parser = argparse.ArgumentParser(description="StyleGAN2 trainer")
    parser.add_argument("path", type=str, help="path to the lmdb dataset")
    parser.add_argument('--arch', type=str, default='stylegan2', help='model architectures (stylegan2 | swagan)')
    parser.add_argument(
        "--iter", type=int, default=800000, help="total training iterations"
    )
    parser.add_argument(
        "--batch", type=int, default=16, help="batch sizes for each gpus"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=4,
        help="number of the samples generated during training",
    )
    parser.add_argument(
        "--size", type=int, default=256, help="image sizes for the model"
    )
    parser.add_argument(
        "--r1", type=float, default=10, help="weight of the r1 regularization"
    )
    parser.add_argument(
        "--path_regularize",
        type=float,
        default=2,
        help="weight of the path length regularization",
    )
    parser.add_argument(
        "--path_batch_shrink",
        type=int,
        default=2,
        help="batch size reducing factor for the path length regularization (reduce memory consumption)",
    )
    parser.add_argument(
        "--d_reg_every",
        type=int,
        default=16,
        help="interval of the applying r1 regularization",
    )
    parser.add_argument(
        "--g_reg_every",
        type=int,
        default=4,
        help="interval of the applying path length regularization",
    )
    parser.add_argument(
        "--mixing", type=float, default=0.9, help="probability of latent code mixing"
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default=None,
        help="path to the checkpoints to resume training",
    )
    parser.add_argument("--lr", type=float, default=0.002, help="learning rate")
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier factor for the model. config-f = 2, else = 1",
    )
    parser.add_argument(
        "--wandb", action="store_true", help="use weights and biases logging"
    )
    parser.add_argument(
        "--local-rank", type=int, default=0, help="local rank for distributed training"
    )
    parser.add_argument(
        "--augment", action="store_true", help="apply non leaking augmentation"
    )
    parser.add_argument(
        "--augment_p",
        type=float,
        default=0,
        help="probability of applying augmentation. 0 = use adaptive augmentation",
    )
    parser.add_argument(
        "--ada_target",
        type=float,
        default=0.6,
        help="target augmentation probability for adaptive augmentation",
    )
    parser.add_argument(
        "--ada_length",
        type=int,
        default=500 * 1000,
        help="target duraing to reach augmentation probability for adaptive augmentation",
    )
    parser.add_argument(
        "--ada_every",
        type=int,
        default=256,
        help="probability update interval of the adaptive augmentation",
    )
    args = parser.parse_args()
    # ---- Distributed setup: WORLD_SIZE is set by the torch launcher ----
    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()
    args.latent = 512
    args.n_mlp = 8
    args.start_iter = 0
    # ---- Select architecture; both modules expose Generator/Discriminator ----
    if args.arch == 'stylegan2':
        from model import Generator, Discriminator
    elif args.arch == 'swagan':
        from swagan import Generator, Discriminator
    generator = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier
    ).to(device)
    # EMA copy of the generator, initialized to the live weights (decay=0).
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.eval()
    accumulate(g_ema, generator, 0)
    # Lazy regularization: scale lr and Adam betas so the effective update
    # matches regularizing every step (StyleGAN2 appendix B).
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )
    # ---- Optional checkpoint resume ----
    if args.ckpt is not None:
        print("load model:", args.ckpt)
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
        try:
            # Checkpoints are named by iteration (e.g. 010000.pt); other
            # names fall back to start_iter = 0 via the swallowed ValueError.
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])
        except ValueError:
            pass
        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])
        g_optim.load_state_dict(ckpt["g_optim"])
        d_optim.load_state_dict(ckpt["d_optim"])
    # ---- Wrap in DDP after loading weights ----
    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
    # ---- Data pipeline: images normalized to [-1, 1] ----
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )
    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )
    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project="stylegan 2")
    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
# NOTE: removed non-code web-page residue ("Markdown is supported", comment-form
# prompts) accidentally captured when this file was copied from a repository
# browser; those lines were not valid Python.