Commit 25008cf0 authored by dongchy920's avatar dongchy920
Browse files

test

parents
Pipeline #1767 canceled with stages
import math
import torch
from torch import autograd
from torch.nn import functional as F
import numpy as np
from distributed import reduce_sum
from op import upfirdn2d
class AdaptiveAugment:
def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
self.ada_aug_target = ada_aug_target
self.ada_aug_len = ada_aug_len
self.update_every = update_every
self.ada_update = 0
self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
self.r_t_stat = 0
self.ada_aug_p = 0
@torch.no_grad()
def tune(self, real_pred):
self.ada_aug_buf += torch.tensor(
(torch.sign(real_pred).sum().item(), real_pred.shape[0]),
device=real_pred.device,
)
self.ada_update += 1
if self.ada_update % self.update_every == 0:
self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
pred_signs, n_pred = self.ada_aug_buf.tolist()
self.r_t_stat = pred_signs / n_pred
if self.r_t_stat > self.ada_aug_target:
sign = 1
else:
sign = -1
self.ada_aug_p += sign * n_pred / self.ada_aug_len
self.ada_aug_p = min(1, max(0, self.ada_aug_p))
self.ada_aug_buf.mul_(0)
self.ada_update = 0
return self.ada_aug_p
SYM6 = (
0.015404109327027373,
0.0034907120842174702,
-0.11799011114819057,
-0.048311742585633,
0.4910559419267466,
0.787641141030194,
0.3379294217276218,
-0.07263752278646252,
-0.021060292512300564,
0.04472490177066578,
0.0017677118642428036,
-0.007800708325034148,
)
def translate_mat(t_x, t_y, device="cpu"):
batch = t_x.shape[0]
mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
translate = torch.stack((t_x, t_y), 1)
mat[:, :2, 2] = translate
return mat
def rotate_mat(theta, device="cpu"):
batch = theta.shape[0]
mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
sin_t = torch.sin(theta)
cos_t = torch.cos(theta)
rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
mat[:, :2, :2] = rot
return mat
def scale_mat(s_x, s_y, device="cpu"):
batch = s_x.shape[0]
mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
mat[:, 0, 0] = s_x
mat[:, 1, 1] = s_y
return mat
def translate3d_mat(t_x, t_y, t_z):
batch = t_x.shape[0]
mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
translate = torch.stack((t_x, t_y, t_z), 1)
mat[:, :3, 3] = translate
return mat
def rotate3d_mat(axis, theta):
batch = theta.shape[0]
u_x, u_y, u_z = axis
eye = torch.eye(3).unsqueeze(0)
cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
outer = torch.tensor(axis)
outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
sin_t = torch.sin(theta).view(-1, 1, 1)
cos_t = torch.cos(theta).view(-1, 1, 1)
rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
eye_4[:, :3, :3] = rot
return eye_4
def scale3d_mat(s_x, s_y, s_z):
batch = s_x.shape[0]
mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
mat[:, 0, 0] = s_x
mat[:, 1, 1] = s_y
mat[:, 2, 2] = s_z
return mat
def luma_flip_mat(axis, i):
batch = i.shape[0]
eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
axis = torch.tensor(axis + (0,))
flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
return eye - flip
def saturation_mat(axis, i):
batch = i.shape[0]
eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
axis = torch.tensor(axis + (0,))
axis = torch.ger(axis, axis)
saturate = axis + (eye - axis) * i.view(-1, 1, 1)
return saturate
def lognormal_sample(size, mean=0, std=1, device="cpu"):
return torch.empty(size, device=device).log_normal_(mean=mean, std=std)
def category_sample(size, categories, device="cpu"):
category = torch.tensor(categories, device=device)
sample = torch.randint(high=len(categories), size=(size,), device=device)
return category[sample]
def uniform_sample(size, low, high, device="cpu"):
return torch.empty(size, device=device).uniform_(low, high)
def normal_sample(size, mean=0, std=1, device="cpu"):
return torch.empty(size, device=device).normal_(mean, std)
def bernoulli_sample(size, p, device="cpu"):
return torch.empty(size, device=device).bernoulli_(p)
def random_mat_apply(p, transform, prev, eye, device="cpu"):
size = transform.shape[0]
select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
select_transform = select * transform + (1 - select) * eye
return select_transform @ prev
def sample_affine(p, size, height, width, device="cpu"):
G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
eye = G
# flip
param = category_sample(size, (0, 1))
Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
G = random_mat_apply(p, Gc, G, eye, device=device)
# print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
# 90 rotate
param = category_sample(size, (0, 3))
Gc = rotate_mat(-math.pi / 2 * param, device=device)
G = random_mat_apply(p, Gc, G, eye, device=device)
# print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
# integer translate
param = uniform_sample((2, size), -0.125, 0.125)
param_height = torch.round(param[0] * height)
param_width = torch.round(param[1] * width)
Gc = translate_mat(param_width, param_height, device=device)
G = random_mat_apply(p, Gc, G, eye, device=device)
# print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
# isotropic scale
param = lognormal_sample(size, std=0.2 * math.log(2))
Gc = scale_mat(param, param, device=device)
G = random_mat_apply(p, Gc, G, eye, device=device)
# print('isotropic scale', G, scale_mat(param, param), sep='\n')
p_rot = 1 - math.sqrt(1 - p)
# pre-rotate
param = uniform_sample(size, -math.pi, math.pi)
Gc = rotate_mat(-param, device=device)
G = random_mat_apply(p_rot, Gc, G, eye, device=device)
# print('pre-rotate', G, rotate_mat(-param), sep='\n')
# anisotropic scale
param = lognormal_sample(size, std=0.2 * math.log(2))
Gc = scale_mat(param, 1 / param, device=device)
G = random_mat_apply(p, Gc, G, eye, device=device)
# print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
# post-rotate
param = uniform_sample(size, -math.pi, math.pi)
Gc = rotate_mat(-param, device=device)
G = random_mat_apply(p_rot, Gc, G, eye, device=device)
# print('post-rotate', G, rotate_mat(-param), sep='\n')
# fractional translate
param = normal_sample((2, size), std=0.125)
Gc = translate_mat(param[1] * width, param[0] * height, device=device)
G = random_mat_apply(p, Gc, G, eye, device=device)
# print('fractional translate', G, translate_mat(param, param), sep='\n')
return G
def sample_color(p, size):
C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
eye = C
axis_val = 1 / math.sqrt(3)
axis = (axis_val, axis_val, axis_val)
# brightness
param = normal_sample(size, std=0.2)
Cc = translate3d_mat(param, param, param)
C = random_mat_apply(p, Cc, C, eye)
# contrast
param = lognormal_sample(size, std=0.5 * math.log(2))
Cc = scale3d_mat(param, param, param)
C = random_mat_apply(p, Cc, C, eye)
# luma flip
param = category_sample(size, (0, 1))
Cc = luma_flip_mat(axis, param)
C = random_mat_apply(p, Cc, C, eye)
# hue rotation
param = uniform_sample(size, -math.pi, math.pi)
Cc = rotate3d_mat(axis, param)
C = random_mat_apply(p, Cc, C, eye)
# saturation
param = lognormal_sample(size, std=1 * math.log(2))
Cc = saturation_mat(axis, param)
C = random_mat_apply(p, Cc, C, eye)
return C
def make_grid(shape, x0, x1, y0, y1, device):
n, c, h, w = shape
grid = torch.empty(n, h, w, 3, device=device)
grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
grid[:, :, :, 2] = 1
return grid
def affine_grid(grid, mat):
n, h, w, _ = grid.shape
return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
def get_padding(G, height, width, kernel_size):
device = G.device
cx = (width - 1) / 2
cy = (height - 1) / 2
cp = torch.tensor(
[(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
)
cp = G @ cp.T
pad_k = kernel_size // 4
pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
pad = torch.cat((-pad, pad)).max(1).values
pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
pad = pad.max(torch.tensor([0, 0] * 2, device=device))
pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))
pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)
return pad_x1, pad_x2, pad_y1, pad_y2
def try_sample_affine_and_pad(img, p, kernel_size, G=None):
batch, _, height, width = img.shape
G_try = G
if G is None:
G_try = torch.inverse(sample_affine(p, batch, height, width))
pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")
return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
class GridSampleForward(autograd.Function):
@staticmethod
def forward(ctx, input, grid):
out = F.grid_sample(
input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
)
ctx.save_for_backward(input, grid)
return out
@staticmethod
def backward(ctx, grad_output):
input, grid = ctx.saved_tensors
grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
return grad_input, grad_grid
class GridSampleBackward(autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
ctx.save_for_backward(grid)
return grad_input, grad_grid
@staticmethod
def backward(ctx, grad_grad_input, grad_grad_grid):
(grid,) = ctx.saved_tensors
grad_grad_output = None
if ctx.needs_input_grad[0]:
grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
return grad_grad_output, None, None
grid_sample = GridSampleForward.apply
def scale_mat_single(s_x, s_y):
return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)
def translate_mat_single(t_x, t_y):
return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)
def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
kernel = antialiasing_kernel
len_k = len(kernel)
kernel = torch.as_tensor(kernel).to(img)
# kernel = torch.ger(kernel, kernel).to(img)
kernel_flip = torch.flip(kernel, (0,))
img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
img, p, len_k, G
)
G_inv = (
translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
@ G
)
up_pad = (
(len_k + 2 - 1) // 2,
(len_k - 2) // 2,
(len_k + 2 - 1) // 2,
(len_k - 2) // 2,
)
img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
batch_size, channel, height, width = img.shape
pad_k = len_k // 4
shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
G_inv = (
scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
@ G_inv
@ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
)
grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
img_affine = grid_sample(img_2x, grid)
d_p = -pad_k * 2
down_pad = (
d_p + (len_k - 2 + 1) // 2,
d_p + (len_k - 2) // 2,
d_p + (len_k - 2 + 1) // 2,
d_p + (len_k - 2) // 2,
)
img_down = upfirdn2d(
img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
)
img_down = upfirdn2d(
img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
)
return img_down, G
def apply_color(img, mat):
batch = img.shape[0]
img = img.permute(0, 2, 3, 1)
mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
img = img @ mat_mul + mat_add
img = img.permute(0, 3, 1, 2)
return img
def random_apply_color(img, p, C=None):
if C is None:
C = sample_color(p, img.shape[0])
img = apply_color(img, C.to(img))
return img, C
def augment(img, p, transform_matrix=(None, None)):
img, G = random_apply_affine(img, p, transform_matrix[0])
img, C = random_apply_color(img, p, transform_matrix[1])
return img, (G, C)
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d
import contextlib
import warnings
import torch
from torch import autograd
from torch.nn import functional as F
enabled = True
weight_gradients_disabled = False
@contextlib.contextmanager
def no_weight_gradients():
global weight_gradients_disabled
old = weight_gradients_disabled
weight_gradients_disabled = True
yield
weight_gradients_disabled = old
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if could_use_op(input):
return conv2d_gradfix(
transpose=False,
weight_shape=weight.shape,
stride=stride,
padding=padding,
output_padding=0,
dilation=dilation,
groups=groups,
).apply(input, weight, bias)
return F.conv2d(
input=input,
weight=weight,
bias=bias,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
def conv_transpose2d(
input,
weight,
bias=None,
stride=1,
padding=0,
output_padding=0,
groups=1,
dilation=1,
):
if could_use_op(input):
return conv2d_gradfix(
transpose=True,
weight_shape=weight.shape,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
).apply(input, weight, bias)
return F.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
stride=stride,
padding=padding,
output_padding=output_padding,
dilation=dilation,
groups=groups,
)
def could_use_op(input):
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if input.device.type != "cuda":
return False
if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
return True
warnings.warn(
f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
)
return False
def ensure_tuple(xs, ndim):
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
return xs
conv2d_gradfix_cache = dict()
def conv2d_gradfix(
transpose, weight_shape, stride, padding, output_padding, dilation, groups
):
ndim = 2
weight_shape = tuple(weight_shape)
stride = ensure_tuple(stride, ndim)
padding = ensure_tuple(padding, ndim)
output_padding = ensure_tuple(output_padding, ndim)
dilation = ensure_tuple(dilation, ndim)
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
if key in conv2d_gradfix_cache:
return conv2d_gradfix_cache[key]
common_kwargs = dict(
stride=stride, padding=padding, dilation=dilation, groups=groups
)
def calc_output_padding(input_shape, output_shape):
if transpose:
return [0, 0]
return [
input_shape[i + 2]
- (output_shape[i + 2] - 1) * stride[i]
- (1 - 2 * padding[i])
- dilation[i] * (weight_shape[i + 2] - 1)
for i in range(ndim)
]
class Conv2d(autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias):
if not transpose:
out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
else:
out = F.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
output_padding=output_padding,
**common_kwargs,
)
ctx.save_for_backward(input, weight)
return out
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input, grad_weight, grad_bias = None, None, None
if ctx.needs_input_grad[0]:
p = calc_output_padding(
input_shape=input.shape, output_shape=grad_output.shape
)
grad_input = conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs,
).apply(grad_output, weight, None)
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
grad_weight = Conv2dGradWeight.apply(grad_output, input)
if ctx.needs_input_grad[2]:
grad_bias = grad_output.sum((0, 2, 3))
return grad_input, grad_weight, grad_bias
class Conv2dGradWeight(autograd.Function):
@staticmethod
def forward(ctx, grad_output, input):
op = torch._C._jit_get_operation(
"aten::cudnn_convolution_backward_weight"
if not transpose
else "aten::cudnn_convolution_transpose_backward_weight"
)
flags = [
torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic,
torch.backends.cudnn.allow_tf32,
]
grad_weight = op(
weight_shape,
grad_output,
input,
padding,
stride,
dilation,
groups,
*flags,
)
ctx.save_for_backward(grad_output, input)
return grad_weight
@staticmethod
def backward(ctx, grad_grad_weight):
grad_output, input = ctx.saved_tensors
grad_grad_output, grad_grad_input = None, None
if ctx.needs_input_grad[0]:
grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
if ctx.needs_input_grad[1]:
p = calc_output_padding(
input_shape=input.shape, output_shape=grad_output.shape
)
grad_grad_input = conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs,
).apply(grad_output, grad_grad_weight, None)
return grad_grad_output, grad_grad_input
conv2d_gradfix_cache[key] = Conv2d
return Conv2d
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
fused = load(
"fused",
sources=[
os.path.join(module_path, "fused_bias_act.cpp"),
os.path.join(module_path, "fused_bias_act_kernel.cu"),
],
)
class FusedLeakyReLUFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, out, bias, negative_slope, scale):
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
empty = grad_output.new_empty(0)
grad_input = fused.fused_bias_act(
grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
)
dim = [0]
if grad_input.ndim > 2:
dim += list(range(2, grad_input.ndim))
if bias:
grad_bias = grad_input.sum(dim).detach()
else:
grad_bias = empty
return grad_input, grad_bias
@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
out, = ctx.saved_tensors
gradgrad_out = fused.fused_bias_act(
gradgrad_input.contiguous(),
gradgrad_bias,
out,
3,
1,
ctx.negative_slope,
ctx.scale,
)
return gradgrad_out, None, None, None, None
class FusedLeakyReLUFunction(Function):
@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
empty = input.new_empty(0)
ctx.bias = bias is not None
if bias is None:
bias = empty
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
return out
@staticmethod
def backward(ctx, grad_output):
out, = ctx.saved_tensors
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
)
if not ctx.bias:
grad_bias = None
return grad_input, grad_bias, None, None
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
super().__init__()
if bias:
self.bias = nn.Parameter(torch.zeros(channel))
else:
self.bias = None
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
if input.device.type == "cpu":
if bias is not None:
rest_dim = [1] * (input.ndim - bias.ndim - 1)
return (
F.leaky_relu(
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
)
* scale
)
else:
return F.leaky_relu(input, negative_slope=0.2) * scale
else:
return FusedLeakyReLUFunction.apply(
input.contiguous(), bias, negative_slope, scale
)
#include <ATen/ATen.h>
#include <torch/extension.h>
torch::Tensor fused_bias_act_op(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act, int grad,
float alpha, float scale);
#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
torch::Tensor fused_bias_act(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act, int grad,
float alpha, float scale) {
CHECK_INPUT(input);
CHECK_INPUT(bias);
at::DeviceGuard guard(input.device());
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
\ No newline at end of file
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
template <typename scalar_t>
static __global__ void
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
const scalar_t *p_ref, int act, int grad, scalar_t alpha,
scalar_t scale, int loop_x, int size_x, int step_b,
int size_b, int use_bias, int use_ref) {
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
scalar_t zero = 0.0;
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
loop_idx++, xi += blockDim.x) {
scalar_t x = p_x[xi];
if (use_bias) {
x += p_b[(xi / step_b) % size_b];
}
scalar_t ref = use_ref ? p_ref[xi] : zero;
scalar_t y;
switch (act * 10 + grad) {
default:
case 10:
y = x;
break;
case 11:
y = x;
break;
case 12:
y = 0.0;
break;
case 30:
y = (x > 0.0) ? x : x * alpha;
break;
case 31:
y = (ref > 0.0) ? x : x * alpha;
break;
case 32:
y = 0.0;
break;
}
out[xi] = y * scale;
}
}
torch::Tensor fused_bias_act_op(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act, int grad,
float alpha, float scale) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto x = input.contiguous();
auto b = bias.contiguous();
auto ref = refer.contiguous();
int use_bias = b.numel() ? 1 : 0;
int use_ref = ref.numel() ? 1 : 0;
int size_x = x.numel();
int size_b = b.numel();
int step_b = 1;
for (int i = 1 + 1; i < x.dim(); i++) {
step_b *= x.size(i);
}
int loop_x = 4;
int block_size = 4 * 32;
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
auto y = torch::empty_like(x);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x.scalar_type(), "fused_bias_act_kernel", [&] {
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
});
return y;
}
\ No newline at end of file
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/hip/HIPApplyUtils.cuh>
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
template <typename scalar_t>
static __global__ void
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
const scalar_t *p_ref, int act, int grad, scalar_t alpha,
scalar_t scale, int loop_x, int size_x, int step_b,
int size_b, int use_bias, int use_ref) {
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
scalar_t zero = 0.0;
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
loop_idx++, xi += blockDim.x) {
scalar_t x = p_x[xi];
if (use_bias) {
x += p_b[(xi / step_b) % size_b];
}
scalar_t ref = use_ref ? p_ref[xi] : zero;
scalar_t y;
switch (act * 10 + grad) {
default:
case 10:
y = x;
break;
case 11:
y = x;
break;
case 12:
y = 0.0;
break;
case 30:
y = (x > 0.0) ? x : x * alpha;
break;
case 31:
y = (ref > 0.0) ? x : x * alpha;
break;
case 32:
y = 0.0;
break;
}
out[xi] = y * scale;
}
}
torch::Tensor fused_bias_act_op(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act, int grad,
float alpha, float scale) {
int curDevice = -1;
hipGetDevice(&curDevice);
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
auto x = input.contiguous();
auto b = bias.contiguous();
auto ref = refer.contiguous();
int use_bias = b.numel() ? 1 : 0;
int use_ref = ref.numel() ? 1 : 0;
int size_x = x.numel();
int size_b = b.numel();
int step_b = 1;
for (int i = 1 + 1; i < x.dim(); i++) {
step_b *= x.size(i);
}
int loop_x = 4;
int block_size = 4 * 32;
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
auto y = torch::empty_like(x);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x.scalar_type(), "fused_bias_act_kernel", [&] {
hipLaunchKernelGGL(( fused_bias_act_kernel<scalar_t>), dim3(grid_size), dim3(block_size), 0, stream,
y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
});
return y;
}
\ No newline at end of file
#include <ATen/ATen.h>
#include <torch/extension.h>
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1);
#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
int up_x, int up_y, int down_x, int down_y, int pad_x0,
int pad_x1, int pad_y0, int pad_y1) {
CHECK_INPUT(input);
CHECK_INPUT(kernel);
at::DeviceGuard guard(input.device());
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
pad_y0, pad_y1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
\ No newline at end of file
from collections import abc
import os
import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
"upfirdn2d",
sources=[
os.path.join(module_path, "upfirdn2d.cpp"),
os.path.join(module_path, "upfirdn2d_kernel.cu"),
],
)
class UpFirDn2dBackward(Function):
@staticmethod
def forward(
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
):
up_x, up_y = up
down_x, down_y = down
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
grad_input = upfirdn2d_op.upfirdn2d(
grad_output,
grad_kernel,
down_x,
down_y,
up_x,
up_y,
g_pad_x0,
g_pad_x1,
g_pad_y0,
g_pad_y1,
)
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
ctx.save_for_backward(kernel)
pad_x0, pad_x1, pad_y0, pad_y1 = pad
ctx.up_x = up_x
ctx.up_y = up_y
ctx.down_x = down_x
ctx.down_y = down_y
ctx.pad_x0 = pad_x0
ctx.pad_x1 = pad_x1
ctx.pad_y0 = pad_y0
ctx.pad_y1 = pad_y1
ctx.in_size = in_size
ctx.out_size = out_size
return grad_input
@staticmethod
def backward(ctx, gradgrad_input):
kernel, = ctx.saved_tensors
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
gradgrad_out = upfirdn2d_op.upfirdn2d(
gradgrad_input,
kernel,
ctx.up_x,
ctx.up_y,
ctx.down_x,
ctx.down_y,
ctx.pad_x0,
ctx.pad_x1,
ctx.pad_y0,
ctx.pad_y1,
)
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
gradgrad_out = gradgrad_out.view(
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
)
return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
@staticmethod
def forward(ctx, input, kernel, up, down, pad):
up_x, up_y = up
down_x, down_y = down
pad_x0, pad_x1, pad_y0, pad_y1 = pad
kernel_h, kernel_w = kernel.shape
batch, channel, in_h, in_w = input.shape
ctx.in_size = input.shape
input = input.reshape(-1, in_h, in_w, 1)
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
ctx.out_size = (out_h, out_w)
ctx.up = (up_x, up_y)
ctx.down = (down_x, down_y)
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
g_pad_x0 = kernel_w - pad_x0 - 1
g_pad_y0 = kernel_h - pad_y0 - 1
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
out = upfirdn2d_op.upfirdn2d(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
)
# out = out.view(major, out_h, out_w, minor)
out = out.view(-1, channel, out_h, out_w)
return out
@staticmethod
def backward(ctx, grad_output):
kernel, grad_kernel = ctx.saved_tensors
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = UpFirDn2dBackward.apply(
grad_output,
kernel,
grad_kernel,
ctx.up,
ctx.down,
ctx.pad,
ctx.g_pad,
ctx.in_size,
ctx.out_size,
)
return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
if not isinstance(up, abc.Iterable):
up = (up, up)
if not isinstance(down, abc.Iterable):
down = (down, down)
if len(pad) == 2:
pad = (pad[0], pad[1], pad[0], pad[1])
if input.device.type == "cpu":
out = upfirdn2d_native(input, kernel, *up, *down, *pad)
else:
out = UpFirDn2d.apply(input, kernel, up, down, pad)
return out
def upfirdn2d_native(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
)
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
)
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
return out.view(-1, channel, out_h, out_w)
This diff is collapsed.
This diff is collapsed.
import argparse
import torch
from torch.nn import functional as F
import numpy as np
from tqdm import tqdm
import lpips
from model import Generator
def normalize(x):
return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))
def slerp(a, b, t):
a = normalize(a)
b = normalize(b)
d = (a * b).sum(-1, keepdim=True)
p = t * torch.acos(d)
c = normalize(b - d * a)
d = a * torch.cos(p) + c * torch.sin(p)
return normalize(d)
def lerp(a, b, t):
return a + (b - a) * t
if __name__ == "__main__":
device = "cuda"
parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")
parser.add_argument(
"--space", choices=["z", "w"], help="space that PPL calculated with"
)
parser.add_argument(
"--batch", type=int, default=64, help="batch size for the models"
)
parser.add_argument(
"--n_sample",
type=int,
default=5000,
help="number of the samples for calculating PPL",
)
parser.add_argument(
"--size", type=int, default=256, help="output image sizes of the generator"
)
parser.add_argument(
"--eps", type=float, default=1e-4, help="epsilon for numerical stability"
)
parser.add_argument(
"--crop", action="store_true", help="apply center crop to the images"
)
parser.add_argument(
"--sampling",
default="end",
choices=["end", "full"],
help="set endpoint sampling method",
)
parser.add_argument(
"ckpt", metavar="CHECKPOINT", help="path to the model checkpoints"
)
args = parser.parse_args()
latent_dim = 512
ckpt = torch.load(args.ckpt)
g = Generator(args.size, latent_dim, 8).to(device)
g.load_state_dict(ckpt["g_ema"])
g.eval()
percept = lpips.PerceptualLoss(
model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
)
distances = []
n_batch = args.n_sample // args.batch
resid = args.n_sample - (n_batch * args.batch)
batch_sizes = [args.batch] * n_batch + [resid]
with torch.no_grad():
for batch in tqdm(batch_sizes):
noise = g.make_noise()
inputs = torch.randn([batch * 2, latent_dim], device=device)
if args.sampling == "full":
lerp_t = torch.rand(batch, device=device)
else:
lerp_t = torch.zeros(batch, device=device)
if args.space == "w":
latent = g.get_latent(inputs)
latent_t0, latent_t1 = latent[::2], latent[1::2]
latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)
image, _ = g([latent_e], input_is_latent=True, noise=noise)
if args.crop:
c = image.shape[2] // 8
image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]
factor = image.shape[2] // 256
if factor > 1:
image = F.interpolate(
image, size=(256, 256), mode="bilinear", align_corners=False
)
dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
args.eps ** 2
)
distances.append(dist.to("cpu").numpy())
distances = np.concatenate(distances, 0)
lo = np.percentile(distances, 1, interpolation="lower")
hi = np.percentile(distances, 99, interpolation="higher")
filtered_dist = np.extract(
np.logical_and(lo <= distances, distances <= hi), distances
)
print("ppl:", filtered_dist.mean())
import argparse
from io import BytesIO
import multiprocessing
from functools import partial
from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn
def resize_and_convert(img, size, resample, quality=100):
img = trans_fn.resize(img, size, resample)
img = trans_fn.center_crop(img, size)
buffer = BytesIO()
img.save(buffer, format="jpeg", quality=quality)
val = buffer.getvalue()
return val
def resize_multiple(
img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
imgs = []
for size in sizes:
imgs.append(resize_and_convert(img, size, resample, quality))
return imgs
def resize_worker(img_file, sizes, resample):
i, file = img_file
img = Image.open(file)
img = img.convert("RGB")
out = resize_multiple(img, sizes=sizes, resample=resample)
return i, out
def prepare(
env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
files = sorted(dataset.imgs, key=lambda x: x[0])
files = [(i, file) for i, (file, label) in enumerate(files)]
total = 0
with multiprocessing.Pool(n_worker) as pool:
for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
for size, img in zip(sizes, imgs):
key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
with env.begin(write=True) as txn:
txn.put(key, img)
total += 1
with env.begin(write=True) as txn:
txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Preprocess images for model training")
parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
parser.add_argument(
"--size",
type=str,
default="128,256,512,1024",
help="resolutions of images for the dataset",
)
parser.add_argument(
"--n_worker",
type=int,
default=8,
help="number of workers for preparing dataset",
)
parser.add_argument(
"--resample",
type=str,
default="lanczos",
help="resampling methods for resizing images",
)
parser.add_argument("path", type=str, help="path to the image dataset")
args = parser.parse_args()
resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
resample = resample_map[args.resample]
sizes = [int(s.strip()) for s in args.size.split(",")]
print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
imgset = datasets.ImageFolder(args.path)
with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
This diff is collapsed.
lmdb
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment