// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include "filtered_lrelu.cu"
// Template/kernel specializations for sign write mode.
// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
// Copy filters to constant memory.
template cudaError_t copy_filters<true, false>(cudaStream_t stream);
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
import torch
#----------------------------------------------------------------------------
def fma(a, b, c): # => a * b + c
return _FusedMultiplyAdd.apply(a, b, c)
#----------------------------------------------------------------------------
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
@staticmethod
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
out = torch.addcmul(c, a, b)
ctx.save_for_backward(a, b)
ctx.c_shape = c.shape
return out
@staticmethod
def backward(ctx, dout): # pylint: disable=arguments-differ
a, b = ctx.saved_tensors
c_shape = ctx.c_shape
da = None
db = None
dc = None
if ctx.needs_input_grad[0]:
da = _unbroadcast(dout * b, a.shape)
if ctx.needs_input_grad[1]:
db = _unbroadcast(dout * a, b.shape)
if ctx.needs_input_grad[2]:
dc = _unbroadcast(dout, c_shape)
return da, db, dc
#----------------------------------------------------------------------------
def _unbroadcast(x, shape):
extra_dims = x.ndim - len(shape)
assert extra_dims >= 0
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
if len(dim):
x = x.sum(dim=dim, keepdim=True)
if extra_dims:
x = x.reshape(-1, *x.shape[extra_dims+1:])
assert x.shape == shape
return x
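#----------------------------------------------------------------------------
# Usage sketch (not part of the original op; shapes are illustrative
# assumptions): fma() matches a * b + c, and the gradient of the broadcast
# input is reduced back to its original shape by _unbroadcast().

if __name__ == '__main__':
    a = torch.randn(4, 8, requires_grad=True)
    b = torch.randn(4, 8, requires_grad=True)
    c = torch.randn(1, 8, requires_grad=True)  # broadcast over the first dim
    y = fma(a, b, c)
    assert torch.allclose(y, a * b + c)
    y.sum().backward()
    assert c.grad.shape == c.shape  # gradient unbroadcast back to [1, 8]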
#----------------------------------------------------------------------------
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Custom replacement for `torch.nn.functional.grid_sample` that
supports arbitrarily high order gradients between the input and output.
Only works on 2D images and assumes
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
import warnings
import torch
# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access
#----------------------------------------------------------------------------
enabled = False # Enable the custom op by setting this to true.
#----------------------------------------------------------------------------
def grid_sample(input, grid):
if _should_use_custom_op():
return _GridSample2dForward.apply(input, grid)
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
#----------------------------------------------------------------------------
def _should_use_custom_op():
if not enabled:
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
return True
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
return False
#----------------------------------------------------------------------------
class _GridSample2dForward(torch.autograd.Function):
@staticmethod
def forward(ctx, input, grid):
assert input.ndim == 4
assert grid.ndim == 4
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
ctx.save_for_backward(input, grid)
return output
@staticmethod
def backward(ctx, grad_output):
input, grid = ctx.saved_tensors
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
return grad_input, grad_grid
#----------------------------------------------------------------------------
class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
ctx.save_for_backward(grid)
return grad_input, grad_grid
@staticmethod
def backward(ctx, grad2_grad_input, grad2_grad_grid):
_ = grad2_grad_grid # unused
grid, = ctx.saved_tensors
grad2_grad_output = None
grad2_input = None
grad2_grid = None
if ctx.needs_input_grad[0]:
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
assert not ctx.needs_input_grad[2]
return grad2_grad_output, grad2_input, grad2_grid
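#----------------------------------------------------------------------------
# Usage sketch (not part of the original module; shapes and the `weight`
# tensor are illustrative assumptions, standing in for downstream model
# parameters). With `enabled = False` (the default) this exercises the
# standard PyTorch op; set `enabled = True` on a supported PyTorch version
# to route through the custom op. Note that `grid` must not require
# gradients when the custom op is used.

if __name__ == '__main__':
    image = torch.randn(1, 3, 8, 8, requires_grad=True)
    grid = torch.rand(1, 4, 4, 2) * 2 - 1                 # normalized coords in [-1, 1]
    weight = torch.randn(1, 3, 1, 1, requires_grad=True)  # stand-in for model parameters
    out = (grid_sample(image, grid) * weight).sum()
    (g_image,) = torch.autograd.grad(out, image, create_graph=True)
    g_image.square().sum().backward()  # R1-style penalty; second-order grad reaches `weight`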
#----------------------------------------------------------------------------
// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"
//------------------------------------------------------------------------
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
// Validate arguments.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
// Create output tensor.
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
// Initialize CUDA kernel parameters.
upfirdn2d_kernel_params p;
p.x = x.data_ptr();
p.f = f.data_ptr<float>();
p.y = y.data_ptr();
p.up = make_int2(upx, upy);
p.down = make_int2(downx, downy);
p.pad0 = make_int2(padx0, pady0);
p.flip = (flip) ? 1 : 0;
p.gain = gain;
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
// Choose CUDA kernel.
upfirdn2d_kernel_spec spec;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
{
spec = choose_upfirdn2d_kernel<scalar_t>(p);
});
// Set looping options.
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
p.loopMinor = spec.loopMinor;
p.loopX = spec.loopX;
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
// Compute grid size.
dim3 blockSize, gridSize;
if (spec.tileOutW < 0) // large
{
blockSize = dim3(4, 32, 1);
gridSize = dim3(
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
p.launchMajor);
}
else // small
{
blockSize = dim3(256, 1, 1);
gridSize = dim3(
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
p.launchMajor);
}
// Launch CUDA kernel.
void* args[] = {&p};
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
return y;
}
//------------------------------------------------------------------------
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("upfirdn2d", &upfirdn2d);
}
//------------------------------------------------------------------------
// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include "upfirdn2d.h"
//------------------------------------------------------------------------
// Helpers.
template <class T> struct InternalType;
template <> struct InternalType<double> { typedef double scalar_t; };
template <> struct InternalType<float> { typedef float scalar_t; };
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
static __device__ __forceinline__ int floor_div(int a, int b)
{
int t = 1 - a / b;
return (a + t * b) / b - t;
}
//------------------------------------------------------------------------
// Generic CUDA implementation for large filters.
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
{
typedef typename InternalType<T>::scalar_t scalar_t;
// Calculate thread index.
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
int outY = minorBase / p.launchMinor;
minorBase -= outY * p.launchMinor;
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
int majorBase = blockIdx.z * p.loopMajor;
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
return;
// Setup Y receptive field.
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
if (p.flip)
filterY = p.filterSize.y - 1 - filterY;
// Loop over major, minor, and X.
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
{
int nc = major * p.sizeMinor + minor;
int n = nc / p.inSize.z;
int c = nc - n * p.inSize.z;
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
{
// Setup X receptive field.
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
if (p.flip)
filterX = p.filterSize.x - 1 - filterX;
// Initialize pointers.
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
// Inner loop.
scalar_t v = 0;
for (int y = 0; y < h; y++)
{
for (int x = 0; x < w; x++)
{
v += (scalar_t)(*xp) * (scalar_t)(*fp);
xp += p.inStride.x;
fp += filterStepX;
}
xp += p.inStride.y - w * p.inStride.x;
fp += filterStepY - w * filterStepX;
}
// Store result.
v *= p.gain;
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
}
}
}
//------------------------------------------------------------------------
// Specialized CUDA implementation for small filters.
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
typedef typename InternalType<T>::scalar_t scalar_t;
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
__shared__ volatile scalar_t sf[filterH][filterW];
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
// Calculate tile index.
int minorBase = blockIdx.x;
int tileOutY = minorBase / p.launchMinor;
minorBase -= tileOutY * p.launchMinor;
minorBase *= loopMinor;
tileOutY *= tileOutH;
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
int majorBase = blockIdx.z * p.loopMajor;
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
return;
// Load filter (flipped).
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
{
int fy = tapIdx / filterW;
int fx = tapIdx - fy * filterW;
scalar_t v = 0;
if (fx < p.filterSize.x & fy < p.filterSize.y)
{
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
}
sf[fy][fx] = v;
}
// Loop over major and X.
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
{
int baseNC = major * p.sizeMinor + minorBase;
int n = baseNC / p.inSize.z;
int baseC = baseNC - n * p.inSize.z;
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
{
// Load input pixels.
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
int tileInX = floor_div(tileMidX, upx);
int tileInY = floor_div(tileMidY, upy);
__syncthreads();
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
{
int relC = inIdx;
int relInX = relC / loopMinor;
int relInY = relInX / tileInW;
relC -= relInX * loopMinor;
relInX -= relInY * tileInW;
int c = baseC + relC;
int inX = tileInX + relInX;
int inY = tileInY + relInY;
scalar_t v = 0;
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
sx[relInY][relInX][relC] = v;
}
// Loop over output pixels.
__syncthreads();
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
{
int relC = outIdx;
int relOutX = relC / loopMinor;
int relOutY = relOutX / tileOutW;
relC -= relOutX * loopMinor;
relOutX -= relOutY * tileOutW;
int c = baseC + relC;
int outX = tileOutX + relOutX;
int outY = tileOutY + relOutY;
// Setup receptive field.
int midX = tileMidX + relOutX * downx;
int midY = tileMidY + relOutY * downy;
int inX = floor_div(midX, upx);
int inY = floor_div(midY, upy);
int relInX = inX - tileInX;
int relInY = inY - tileInY;
int filterX = (inX + 1) * upx - midX - 1; // flipped
int filterY = (inY + 1) * upy - midY - 1; // flipped
// Inner loop.
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
{
scalar_t v = 0;
#pragma unroll
for (int y = 0; y < filterH / upy; y++)
#pragma unroll
for (int x = 0; x < filterW / upx; x++)
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
v *= p.gain;
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
}
}
}
}
}
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
{
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
{
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
}
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
{
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  16,16,8>,  16,16,8,  1};
        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  16,16,8>,  16,16,8,  1};
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
}
if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
{
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
}
if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
{
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
}
if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
{
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
}
if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
{
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
}
if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
{
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
}
if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
{
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
}
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
{
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
}
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
{
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
}
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
{
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
}
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
{
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
}
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
{
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
}
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
{
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
}
return spec;
}
//------------------------------------------------------------------------
// Template specializations.
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
//------------------------------------------------------------------------
// Copyright (c) SenseTime Research. All rights reserved.
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <cuda_runtime.h>
//------------------------------------------------------------------------
// CUDA kernel parameters.
struct upfirdn2d_kernel_params
{
const void* x;
const float* f;
void* y;
int2 up;
int2 down;
int2 pad0;
int flip;
float gain;
int4 inSize; // [width, height, channel, batch]
int4 inStride;
int2 filterSize; // [width, height]
int2 filterStride;
int4 outSize; // [width, height, channel, batch]
int4 outStride;
int sizeMinor;
int sizeMajor;
int loopMinor;
int loopMajor;
int loopX;
int launchMinor;
int launchMajor;
};
//------------------------------------------------------------------------
// CUDA kernel specialization.
struct upfirdn2d_kernel_spec
{
void* kernel;
int tileOutW;
int tileOutH;
int loopMinor;
int loopX;
};
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
//------------------------------------------------------------------------
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Custom PyTorch ops for efficient resampling of 2D images."""
import os
import warnings
import numpy as np
import torch
import traceback
from .. import custom_ops
from .. import misc
from . import conv2d_gradfix
#----------------------------------------------------------------------------
_inited = False
_plugin = None
def _init():
global _inited, _plugin
if not _inited:
sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
try:
_plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
except:
warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
return _plugin is not None
def _parse_scaling(scaling):
if isinstance(scaling, int):
scaling = [scaling, scaling]
assert isinstance(scaling, (list, tuple))
assert all(isinstance(x, int) for x in scaling)
sx, sy = scaling
assert sx >= 1 and sy >= 1
return sx, sy
def _parse_padding(padding):
if isinstance(padding, int):
padding = [padding, padding]
assert isinstance(padding, (list, tuple))
assert all(isinstance(x, int) for x in padding)
if len(padding) == 2:
padx, pady = padding
padding = [padx, padx, pady, pady]
padx0, padx1, pady0, pady1 = padding
return padx0, padx1, pady0, pady1
def _get_filter_size(f):
if f is None:
return 1, 1
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
fw = f.shape[-1]
fh = f.shape[0]
with misc.suppress_tracer_warnings():
fw = int(fw)
fh = int(fh)
misc.assert_shape(f, [fh, fw][:f.ndim])
assert fw >= 1 and fh >= 1
return fw, fh
#----------------------------------------------------------------------------
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
Args:
f: Torch tensor, numpy array, or python list of the shape
`[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable),
`[]` (impulse), or
`None` (identity).
device: Result device (default: cpu).
normalize: Normalize the filter so that it retains the magnitude
for constant input signal (DC)? (default: True).
flip_filter: Flip the filter? (default: False).
gain: Overall scaling factor for signal magnitude (default: 1).
separable: Return a separable filter? (default: select automatically).
Returns:
Float32 tensor of the shape
`[filter_height, filter_width]` (non-separable) or
`[filter_taps]` (separable).
"""
# Validate.
if f is None:
f = 1
f = torch.as_tensor(f, dtype=torch.float32)
assert f.ndim in [0, 1, 2]
assert f.numel() > 0
if f.ndim == 0:
f = f[np.newaxis]
# Separable?
if separable is None:
separable = (f.ndim == 1 and f.numel() >= 8)
if f.ndim == 1 and not separable:
f = f.ger(f)
assert f.ndim == (1 if separable else 2)
# Apply normalize, flip, gain, and device.
if normalize:
f /= f.sum()
if flip_filter:
f = f.flip(list(range(f.ndim)))
f = f * (gain ** (f.ndim / 2))
f = f.to(device=device)
return f
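# Example (sketch): a 4-tap binomial filter has fewer than 8 taps, so
# `separable` defaults to False and the outer product yields a 4x4 kernel:
#   setup_filter([1, 3, 3, 1])      # -> shape [4, 4], sums to 1
#   setup_filter([1, 3, 3, 1] * 2)  # 8 taps -> separable, shape [8], sums to 1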
#----------------------------------------------------------------------------
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
r"""Pad, upsample, filter, and downsample a batch of 2D images.
Performs the following sequence of operations for each channel:
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
2. Pad the image with the specified number of zeros on each side (`padding`).
Negative padding corresponds to cropping the image.
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
so that the footprint of all output pixels lies within the input image.
4. Downsample the image by keeping every Nth pixel (`down`).
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
The fused op is considerably more efficient than performing the same calculation
using standard PyTorch ops. It supports gradients of arbitrary order.
Args:
x: Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
f: Float32 FIR filter of the shape
`[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable), or
`None` (identity).
up: Integer upsampling factor. Can be a single int or a list/tuple
`[x, y]` (default: 1).
down: Integer downsampling factor. Can be a single int or a list/tuple
`[x, y]` (default: 1).
padding: Padding with respect to the upsampled image. Can be a single number
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
(default: 0).
flip_filter: False = convolution, True = correlation (default: False).
gain: Overall scaling factor for signal magnitude (default: 1).
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
Returns:
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
"""
assert isinstance(x, torch.Tensor)
assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
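# Worked example (sketch): for a 64x64 input with up=2, down=1, a 4x4 filter,
# and padding=1 on every side, the output width follows the same formula as the
# CUDA kernel:
#   outW = (inW*upx + padx0 + padx1 - fw + downx) // downx
#        = (64*2 + 1 + 1 - 4 + 1) // 1 = 127   (and likewise for the height).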
#----------------------------------------------------------------------------
@misc.profiled_function
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
"""
# Validate arguments.
assert isinstance(x, torch.Tensor) and x.ndim == 4
if f is None:
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
assert f.dtype == torch.float32 and not f.requires_grad
batch_size, num_channels, in_height, in_width = x.shape
upx, upy = _parse_scaling(up)
downx, downy = _parse_scaling(down)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
# Upsample by inserting zeros.
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
# Pad or crop.
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
# Setup filter.
f = f * (gain ** (f.ndim / 2))
f = f.to(x.dtype)
if not flip_filter:
f = f.flip(list(range(f.ndim)))
# Convolve with the filter.
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
if f.ndim == 4:
x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
else:
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
# Downsample by throwing away pixels.
x = x[:, :, ::downy, ::downx]
return x
#----------------------------------------------------------------------------
_upfirdn2d_cuda_cache = dict()
def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
"""
# Parse arguments.
upx, upy = _parse_scaling(up)
downx, downy = _parse_scaling(down)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
# Lookup from cache.
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
if key in _upfirdn2d_cuda_cache:
return _upfirdn2d_cuda_cache[key]
# Forward op.
class Upfirdn2dCuda(torch.autograd.Function):
@staticmethod
def forward(ctx, x, f): # pylint: disable=arguments-differ
assert isinstance(x, torch.Tensor) and x.ndim == 4
if f is None:
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
y = x
if f.ndim == 2:
y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
else:
y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
ctx.save_for_backward(f)
ctx.x_shape = x.shape
return y
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
f, = ctx.saved_tensors
_, _, ih, iw = ctx.x_shape
_, _, oh, ow = dy.shape
fw, fh = _get_filter_size(f)
p = [
fw - padx0 - 1,
iw * upx - ow * downx + padx0 - upx + 1,
fh - pady0 - 1,
ih * upy - oh * downy + pady0 - upy + 1,
]
dx = None
df = None
if ctx.needs_input_grad[0]:
dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
assert not ctx.needs_input_grad[1]
return dx, df
# Add to cache.
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
return Upfirdn2dCuda
#----------------------------------------------------------------------------
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
r"""Filter a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape matches the input.
User-specified padding is applied on top of that, with negative values
indicating cropping. Pixels outside the image are assumed to be zero.
Args:
x: Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
f: Float32 FIR filter of the shape
`[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable), or
`None` (identity).
padding: Padding with respect to the output. Can be a single number or a
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
(default: 0).
flip_filter: False = convolution, True = correlation (default: False).
gain: Overall scaling factor for signal magnitude (default: 1).
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
Returns:
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
"""
padx0, padx1, pady0, pady1 = _parse_padding(padding)
fw, fh = _get_filter_size(f)
p = [
padx0 + fw // 2,
padx1 + (fw - 1) // 2,
pady0 + fh // 2,
pady1 + (fh - 1) // 2,
]
return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
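# Worked example (sketch): with padding=0 the effective padding is
# [fw//2, (fw-1)//2, fh//2, (fh-1)//2], so for a 4x4 filter and a 64x64 input
# outW = 64 + 2 + 1 - 4 + 1 = 64, i.e. the output shape matches the input.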
#----------------------------------------------------------------------------
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
r"""Upsample a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape is a multiple of the input.
User-specified padding is applied on top of that, with negative values
indicating cropping. Pixels outside the image are assumed to be zero.
Args:
x: Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
f: Float32 FIR filter of the shape
`[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable), or
`None` (identity).
up: Integer upsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 2).
padding: Padding with respect to the output. Can be a single number or a
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
(default: 0).
flip_filter: False = convolution, True = correlation (default: False).
gain: Overall scaling factor for signal magnitude (default: 1).
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
Returns:
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
"""
upx, upy = _parse_scaling(up)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
fw, fh = _get_filter_size(f)
p = [
padx0 + (fw + upx - 1) // 2,
padx1 + (fw - upx) // 2,
pady0 + (fh + upy - 1) // 2,
pady1 + (fh - upy) // 2,
]
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
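# Worked example (sketch): with up=2 and a 4x4 filter, the effective padding is
# [(4+2-1)//2, (4-2)//2] = [2, 1] per axis, so a 64x64 input yields
# outW = (64*2 + 2 + 1 - 4 + 1) // 1 = 128. The extra gain of upx*upy = 4
# compensates for the zero-insertion so the signal magnitude is preserved.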
#----------------------------------------------------------------------------
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
r"""Downsample a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape is a fraction of the input.
User-specified padding is applied on top of that, with negative values
indicating cropping. Pixels outside the image are assumed to be zero.
Args:
x: Float32/float64/float16 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
f: Float32 FIR filter of the shape
`[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable), or
`None` (identity).
down: Integer downsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 2).
padding: Padding with respect to the input. Can be a single number or a
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
(default: 0).
flip_filter: False = convolution, True = correlation (default: False).
gain: Overall scaling factor for signal magnitude (default: 1).
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
Returns:
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
"""
downx, downy = _parse_scaling(down)
padx0, padx1, pady0, pady1 = _parse_padding(padding)
fw, fh = _get_filter_size(f)
p = [
padx0 + (fw - downx + 1) // 2,
padx1 + (fw - downx) // 2,
pady0 + (fh - downy + 1) // 2,
pady1 + (fh - downy) // 2,
]
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
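# Worked example (sketch): with down=2 and a 4x4 filter, the effective padding is
# [(4-2+1)//2, (4-2)//2] = [1, 1] per axis, so a 64x64 input yields
# outW = (64 + 1 + 1 - 4 + 2) // 2 = 32.

#----------------------------------------------------------------------------
# Usage sketch (illustrative; run as a module inside its package, since this
# file uses relative imports). Exercises the slow reference implementation on
# CPU; the shapes follow the worked examples above.

if __name__ == '__main__':
    x = torch.randn(1, 3, 64, 64)
    f = setup_filter([1, 3, 3, 1])  # 4x4 binomial filter
    assert filter2d(x, f, impl='ref').shape == (1, 3, 64, 64)
    assert upsample2d(x, f, up=2, impl='ref').shape == (1, 3, 128, 128)
    assert downsample2d(x, f, down=2, impl='ref').shape == (1, 3, 32, 32)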
#----------------------------------------------------------------------------
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Facilities for pickling Python code alongside other data.
The pickled code is automatically imported into a separate Python module
during unpickling. This way, any previously exported pickles will remain
usable even if the original code is no longer available, or if the current
version of the code is not consistent with what was originally pickled."""
import sys
import pickle
import io
import inspect
import copy
import uuid
import types
import dnnlib
#----------------------------------------------------------------------------
_version = 6 # internal version number
_decorators = set() # {decorator_class, ...}
_import_hooks = [] # [hook_function, ...]
_module_to_src_dict = dict() # {module: src, ...}
_src_to_module_dict = dict() # {src: module, ...}
#----------------------------------------------------------------------------
def persistent_class(orig_class):
r"""Class decorator that extends a given class to save its source code
when pickled.
Example:
from torch_utils import persistence
@persistence.persistent_class
class MyNetwork(torch.nn.Module):
def __init__(self, num_inputs, num_outputs):
super().__init__()
self.fc = MyLayer(num_inputs, num_outputs)
...
@persistence.persistent_class
class MyLayer(torch.nn.Module):
...
When pickled, any instance of `MyNetwork` and `MyLayer` will save its
source code alongside other internal state (e.g., parameters, buffers,
and submodules). This way, any previously exported pickle will remain
usable even if the class definitions have been modified or are no
longer available.
The decorator saves the source code of the entire Python module
containing the decorated class. It does *not* save the source code of
any imported modules. Thus, the imported modules must be available
during unpickling, also including `torch_utils.persistence` itself.
It is ok to call functions defined in the same module from the
decorated class. However, if the decorated class depends on other
classes defined in the same module, they must be decorated as well.
This is illustrated in the above example in the case of `MyLayer`.
It is also possible to employ the decorator just-in-time before
calling the constructor. For example:
cls = MyLayer
if want_to_make_it_persistent:
cls = persistence.persistent_class(cls)
layer = cls(num_inputs, num_outputs)
As an additional feature, the decorator also keeps track of the
arguments that were used to construct each instance of the decorated
class. The arguments can be queried via `obj.init_args` and
`obj.init_kwargs`, and they are automatically pickled alongside other
object state. A typical use case is to first unpickle a previous
instance of a persistent class, and then upgrade it to use the latest
version of the source code:
with open('old_pickle.pkl', 'rb') as f:
old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
misc.copy_params_and_buffers(old_net, new_net, require_all=True)
"""
assert isinstance(orig_class, type)
if is_persistent(orig_class):
return orig_class
assert orig_class.__module__ in sys.modules
orig_module = sys.modules[orig_class.__module__]
orig_module_src = _module_to_src(orig_module)
class Decorator(orig_class):
_orig_module_src = orig_module_src
_orig_class_name = orig_class.__name__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._init_args = copy.deepcopy(args)
self._init_kwargs = copy.deepcopy(kwargs)
assert orig_class.__name__ in orig_module.__dict__
_check_pickleable(self.__reduce__())
@property
def init_args(self):
return copy.deepcopy(self._init_args)
@property
def init_kwargs(self):
return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
def __reduce__(self):
fields = list(super().__reduce__())
fields += [None] * max(3 - len(fields), 0)
if fields[0] is not _reconstruct_persistent_obj:
meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
fields[0] = _reconstruct_persistent_obj # reconstruct func
fields[1] = (meta,) # reconstruct args
fields[2] = None # state dict
return tuple(fields)
Decorator.__name__ = orig_class.__name__
_decorators.add(Decorator)
return Decorator
#----------------------------------------------------------------------------
def is_persistent(obj):
r"""Test whether the given object or class is persistent, i.e.,
whether it will save its source code when pickled.
"""
try:
if obj in _decorators:
return True
except TypeError:
pass
return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
#----------------------------------------------------------------------------
def import_hook(hook):
r"""Register an import hook that is called whenever a persistent object
is being unpickled. A typical use case is to patch the pickled source
code to avoid errors and inconsistencies when the API of some imported
module has changed.
The hook should have the following signature:
hook(meta) -> modified meta
`meta` is an instance of `dnnlib.EasyDict` with the following fields:
type: Type of the persistent object, e.g. `'class'`.
version: Internal version number of `torch_utils.persistence`.
        module_src:  Original source code of the Python module.
class_name: Class name in the original Python module.
state: Internal state of the object.
Example:
@persistence.import_hook
def wreck_my_network(meta):
if meta.class_name == 'MyNetwork':
print('MyNetwork is being imported. I will wreck it!')
meta.module_src = meta.module_src.replace("True", "False")
return meta
"""
assert callable(hook)
_import_hooks.append(hook)
#----------------------------------------------------------------------------
def _reconstruct_persistent_obj(meta):
r"""Hook that is called internally by the `pickle` module to unpickle
a persistent object.
"""
meta = dnnlib.EasyDict(meta)
meta.state = dnnlib.EasyDict(meta.state)
for hook in _import_hooks:
meta = hook(meta)
assert meta is not None
assert meta.version == _version
module = _src_to_module(meta.module_src)
assert meta.type == 'class'
orig_class = module.__dict__[meta.class_name]
decorator_class = persistent_class(orig_class)
obj = decorator_class.__new__(decorator_class)
setstate = getattr(obj, '__setstate__', None)
if callable(setstate):
setstate(meta.state) # pylint: disable=not-callable
else:
obj.__dict__.update(meta.state)
return obj
#----------------------------------------------------------------------------
def _module_to_src(module):
r"""Query the source code of a given Python module.
"""
src = _module_to_src_dict.get(module, None)
if src is None:
src = inspect.getsource(module)
_module_to_src_dict[module] = src
_src_to_module_dict[src] = module
return src
def _src_to_module(src):
r"""Get or create a Python module for the given source code.
"""
module = _src_to_module_dict.get(src, None)
if module is None:
module_name = "_imported_module_" + uuid.uuid4().hex
module = types.ModuleType(module_name)
sys.modules[module_name] = module
_module_to_src_dict[module] = src
_src_to_module_dict[src] = module
exec(src, module.__dict__) # pylint: disable=exec-used
return module
#----------------------------------------------------------------------------
def _check_pickleable(obj):
r"""Check that the given object is pickleable, raising an exception if
it is not. This function is expected to be considerably more efficient
than actually pickling the object.
"""
def recurse(obj):
if isinstance(obj, (list, tuple, set)):
return [recurse(x) for x in obj]
if isinstance(obj, dict):
return [[recurse(x), recurse(y)] for x, y in obj.items()]
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
return None # Python primitive types are pickleable.
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
return None # NumPy arrays and PyTorch tensors are pickleable.
if is_persistent(obj):
return None # Persistent objects are pickleable, by virtue of the constructor check.
return obj
with io.BytesIO() as f:
pickle.dump(recurse(obj), f)
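#----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module; requires
# `dnnlib` to be importable): a decorated class survives a pickle round trip
# and records its constructor arguments. `_ExampleLayer` is a stand-in class
# defined here only for the demo below.

class _ExampleLayer:
    def __init__(self, value=1):
        self.value = value

if __name__ == '__main__':
    cls = persistent_class(_ExampleLayer)
    obj = cls(value=42)
    obj2 = pickle.loads(pickle.dumps(obj))
    assert obj2.value == 42 and obj2.init_kwargs['value'] == 42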
#----------------------------------------------------------------------------
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Facilities for reporting and collecting training statistics across
multiple processes and devices. The interface is designed to minimize
synchronization overhead as well as the amount of boilerplate in user
code."""
import re
import numpy as np
import torch
import dnnlib
from . import misc
#----------------------------------------------------------------------------
_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
_counter_dtype = torch.float64 # Data type to use for the internal counters.
_rank = 0 # Rank of the current process.
_sync_device = None # Device to use for multiprocess communication. None = single-process.
_sync_called = False # Has _sync() been called yet?
_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
#----------------------------------------------------------------------------
def init_multiprocessing(rank, sync_device):
r"""Initializes `torch_utils.training_stats` for collecting statistics
across multiple processes.
This function must be called after
`torch.distributed.init_process_group()` and before `Collector.update()`.
The call is not necessary if multi-process collection is not needed.
Args:
rank: Rank of the current process.
sync_device: PyTorch device to use for inter-process
communication, or None to disable multi-process
collection. Typically `torch.device('cuda', rank)`.
"""
global _rank, _sync_device
assert not _sync_called
_rank = rank
_sync_device = sync_device
#----------------------------------------------------------------------------
@misc.profiled_function
def report(name, value):
r"""Broadcasts the given set of scalars to all interested instances of
`Collector`, across device and process boundaries.
This function is expected to be extremely cheap and can be safely
called from anywhere in the training loop, loss function, or inside a
`torch.nn.Module`.
Warning: The current implementation expects the set of unique names to
be consistent across processes. Please make sure that `report()` is
called at least once for each unique name by each process, and in the
same order. If a given process has no scalars to broadcast, it can do
`report(name, [])` (empty list).
Args:
name: Arbitrary string specifying the name of the statistic.
Averages are accumulated separately for each unique name.
value: Arbitrary set of scalars. Can be a list, tuple,
NumPy array, PyTorch tensor, or Python scalar.
Returns:
The same `value` that was passed in.
"""
if name not in _counters:
_counters[name] = dict()
elems = torch.as_tensor(value)
if elems.numel() == 0:
return value
elems = elems.detach().flatten().to(_reduce_dtype)
moments = torch.stack([
torch.ones_like(elems).sum(),
elems.sum(),
elems.square().sum(),
])
assert moments.ndim == 1 and moments.shape[0] == _num_moments
moments = moments.to(_counter_dtype)
device = moments.device
if device not in _counters[name]:
_counters[name][device] = torch.zeros_like(moments)
_counters[name][device].add_(moments)
return value
#----------------------------------------------------------------------------
def report0(name, value):
r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
but ignores any scalars provided by the other processes.
See `report()` for further details.
"""
report(name, value if _rank == 0 else [])
return value
#----------------------------------------------------------------------------
class Collector:
r"""Collects the scalars broadcasted by `report()` and `report0()` and
computes their long-term averages (mean and standard deviation) over
user-defined periods of time.
The averages are first collected into internal counters that are not
directly visible to the user. They are then copied to the user-visible
state as a result of calling `update()` and can then be queried using
`mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
internal counters for the next round, so that the user-visible state
effectively reflects averages collected between the last two calls to
`update()`.
Args:
regex: Regular expression defining which statistics to
collect. The default is to collect everything.
keep_previous: Whether to retain the previous averages if no
scalars were collected on a given round
(default: True).
"""
def __init__(self, regex='.*', keep_previous=True):
self._regex = re.compile(regex)
self._keep_previous = keep_previous
self._cumulative = dict()
self._moments = dict()
self.update()
self._moments.clear()
def names(self):
r"""Returns the names of all statistics broadcasted so far that
match the regular expression specified at construction time.
"""
return [name for name in _counters if self._regex.fullmatch(name)]
def update(self):
r"""Copies current values of the internal counters to the
user-visible state and resets them for the next round.
If `keep_previous=True` was specified at construction time, the
operation is skipped for statistics that have received no scalars
since the last update, retaining their previous averages.
This method performs a number of GPU-to-CPU transfers and one
`torch.distributed.all_reduce()`. It is intended to be called
periodically in the main training loop, typically once every
N training steps.
"""
if not self._keep_previous:
self._moments.clear()
for name, cumulative in _sync(self.names()):
if name not in self._cumulative:
self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
delta = cumulative - self._cumulative[name]
self._cumulative[name].copy_(cumulative)
if float(delta[0]) != 0:
self._moments[name] = delta
def _get_delta(self, name):
r"""Returns the raw moments that were accumulated for the given
statistic between the last two calls to `update()`, or zero if
no scalars were collected.
"""
assert self._regex.fullmatch(name)
if name not in self._moments:
self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
return self._moments[name]
def num(self, name):
r"""Returns the number of scalars that were accumulated for the given
statistic between the last two calls to `update()`, or zero if
no scalars were collected.
"""
delta = self._get_delta(name)
return int(delta[0])
def mean(self, name):
r"""Returns the mean of the scalars that were accumulated for the
given statistic between the last two calls to `update()`, or NaN if
no scalars were collected.
"""
delta = self._get_delta(name)
if int(delta[0]) == 0:
return float('nan')
return float(delta[1] / delta[0])
def std(self, name):
r"""Returns the standard deviation of the scalars that were
accumulated for the given statistic between the last two calls to
`update()`, or NaN if no scalars were collected.
"""
delta = self._get_delta(name)
if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
return float('nan')
if int(delta[0]) == 1:
return float(0)
mean = float(delta[1] / delta[0])
raw_var = float(delta[2] / delta[0])
return np.sqrt(max(raw_var - np.square(mean), 0))
def as_dict(self):
r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict`. The contents are as follows:
dnnlib.EasyDict(
NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
...
)
"""
stats = dnnlib.EasyDict()
for name in self.names():
stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
return stats
def __getitem__(self, name):
r"""Convenience getter.
`collector[name]` is a synonym for `collector.mean(name)`.
"""
return self.mean(name)
#----------------------------------------------------------------------------
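# Example (a minimal sketch, single-process): report scalars from anywhere in
# the training loop and periodically fold them into averages with
# Collector.update(). The statistic name and the synthetic loss value below
# are illustrative.
def _example_report_and_collect(num_steps=20):
    collector = Collector(regex='Loss/.*')
    for step in range(num_steps):
        loss_value = float(np.exp(-step / 10))  # stand-in for a real training loss
        report('Loss/G/loss', loss_value)       # cheap; returns the value unchanged
        if (step + 1) % 10 == 0:
            collector.update()                  # performs the (optional) all_reduce
            print(step, collector.mean('Loss/G/loss'), collector.std('Loss/G/loss'))
#----------------------------------------------------------------------------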
def _sync(names):
r"""Synchronize the global cumulative counters across devices and
processes. Called internally by `Collector.update()`.
"""
if len(names) == 0:
return []
global _sync_called
_sync_called = True
# Collect deltas within current rank.
deltas = []
device = _sync_device if _sync_device is not None else torch.device('cpu')
for name in names:
delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
for counter in _counters[name].values():
delta.add_(counter.to(device))
counter.copy_(torch.zeros_like(counter))
deltas.append(delta)
deltas = torch.stack(deltas)
# Sum deltas across ranks.
if _sync_device is not None:
torch.distributed.all_reduce(deltas)
# Update cumulative values.
deltas = deltas.cpu()
for idx, name in enumerate(names):
if name not in _cumulative:
_cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
_cumulative[name].add_(deltas[idx])
# Return name-value pairs.
return [(name, _cumulative[name]) for name in names]
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# empty
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Augmentation pipeline from the paper
"Training Generative Adversarial Networks with Limited Data".
Matches the original implementation by Karras et al. at
https://github.com/NVlabs/stylegan2-ada/blob/main/training/augment.py"""
import numpy as np
import scipy.signal
import torch
from torch_utils import persistence
from torch_utils import misc
from torch_utils.ops import upfirdn2d
from torch_utils.ops import grid_sample_gradfix
from torch_utils.ops import conv2d_gradfix
#----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.
wavelets = {
'haar': [0.7071067811865476, 0.7071067811865476],
'db1': [0.7071067811865476, 0.7071067811865476],
'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}
#----------------------------------------------------------------------------
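# Quick numerical check (a sketch, not used by the pipeline): the low-pass
# filters above are orthonormal scaling filters, so their squared coefficients
# sum to 1 and the coefficients themselves sum to sqrt(2).
def _check_wavelet_normalization():
    for name, taps in wavelets.items():
        taps = np.asarray(taps)
        assert np.isclose(np.square(taps).sum(), 1.0, atol=1e-6), name
        assert np.isclose(taps.sum(), np.sqrt(2), atol=1e-6), name
#----------------------------------------------------------------------------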
# Helpers for constructing transformation matrices.
def matrix(*rows, device=None):
assert all(len(row) == len(rows[0]) for row in rows)
elems = [x for row in rows for x in row]
ref = [x for x in elems if isinstance(x, torch.Tensor)]
if len(ref) == 0:
return misc.constant(np.asarray(rows), device=device)
assert device is None or device == ref[0].device
elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
def translate2d(tx, ty, **kwargs):
return matrix(
[1, 0, tx],
[0, 1, ty],
[0, 0, 1],
**kwargs)
def translate3d(tx, ty, tz, **kwargs):
return matrix(
[1, 0, 0, tx],
[0, 1, 0, ty],
[0, 0, 1, tz],
[0, 0, 0, 1],
**kwargs)
def scale2d(sx, sy, **kwargs):
return matrix(
[sx, 0, 0],
[0, sy, 0],
[0, 0, 1],
**kwargs)
def scale3d(sx, sy, sz, **kwargs):
return matrix(
[sx, 0, 0, 0],
[0, sy, 0, 0],
[0, 0, sz, 0],
[0, 0, 0, 1],
**kwargs)
def rotate2d(theta, **kwargs):
return matrix(
[torch.cos(theta), torch.sin(-theta), 0],
[torch.sin(theta), torch.cos(theta), 0],
[0, 0, 1],
**kwargs)
def rotate3d(v, theta, **kwargs):
vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
return matrix(
[vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
[vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
[vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
[0, 0, 0, 1],
**kwargs)
def translate2d_inv(tx, ty, **kwargs):
return translate2d(-tx, -ty, **kwargs)
def scale2d_inv(sx, sy, **kwargs):
return scale2d(1 / sx, 1 / sy, **kwargs)
def rotate2d_inv(theta, **kwargs):
return rotate2d(-theta, **kwargs)
#----------------------------------------------------------------------------
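# Example (a minimal sketch): the helpers above build homogeneous transforms
# that compose with '@'. Rotating the point (1, 0) by 90 degrees and then
# translating it by (0, 1) should land on (0, 2).
def _example_matrix_helpers():
    theta = torch.as_tensor(np.pi / 2, dtype=torch.float32)
    M = translate2d(0, 1) @ rotate2d(theta)  # applied right-to-left: rotate, then translate
    p = torch.tensor([1.0, 0.0, 1.0])        # homogeneous 2D point
    q = M @ p
    assert torch.allclose(q[:2], torch.tensor([0.0, 2.0]), atol=1e-6)
#----------------------------------------------------------------------------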
# Versatile image augmentation pipeline from the paper
# "Training Generative Adversarial Networks with Limited Data".
#
# All augmentations are disabled by default; individual augmentations can
# be enabled by setting their probability multipliers to 1.
@persistence.persistent_class
class AugmentPipe(torch.nn.Module):
def __init__(self,
xflip=0, rotate90=0, xint=0, xint_max=0.125,
scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
noise=0, cutout=0, noise_std=0.1, cutout_size=0.5,
):
super().__init__()
self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
# Pixel blitting.
self.xflip = float(xflip) # Probability multiplier for x-flip.
self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
self.xint = float(xint) # Probability multiplier for integer translation.
self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
# General geometric transformations.
self.scale = float(scale) # Probability multiplier for isotropic scaling.
self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
        self.xfrac_std = float(xfrac_std) # Standard deviation of fractional translation, relative to image dimensions.
# Color transformations.
self.brightness = float(brightness) # Probability multiplier for brightness.
self.contrast = float(contrast) # Probability multiplier for contrast.
self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
self.hue = float(hue) # Probability multiplier for hue rotation.
self.saturation = float(saturation) # Probability multiplier for saturation.
self.brightness_std = float(brightness_std) # Standard deviation of brightness.
self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
# Image-space filtering.
self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
# Image-space corruptions.
self.noise = float(noise) # Probability multiplier for additive RGB noise.
self.cutout = float(cutout) # Probability multiplier for cutout.
self.noise_std = float(noise_std) # Standard deviation of additive RGB noise.
self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions.
        # Set up the orthogonal lowpass filter used for geometric augmentations.
self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
# Construct filter bank for image-space filtering.
Hz_lo = np.asarray(wavelets['sym2']) # H(z)
Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
for i in range(1, Hz_fbank.shape[0]):
Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
def forward(self, images, debug_percentile=None):
assert isinstance(images, torch.Tensor) and images.ndim == 4
batch_size, num_channels, height, width = images.shape
device = images.device
if debug_percentile is not None:
debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
# -------------------------------------
# Select parameters for pixel blitting.
# -------------------------------------
# Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
I_3 = torch.eye(3, device=device)
G_inv = I_3
# Apply x-flip with probability (xflip * strength).
if self.xflip > 0:
i = torch.floor(torch.rand([batch_size], device=device) * 2)
i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i))
if debug_percentile is not None:
i = torch.full_like(i, torch.floor(debug_percentile * 2))
G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)
# Apply 90 degree rotations with probability (rotate90 * strength).
if self.rotate90 > 0:
i = torch.floor(torch.rand([batch_size], device=device) * 4)
i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i))
if debug_percentile is not None:
i = torch.full_like(i, torch.floor(debug_percentile * 4))
G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)
# Apply integer translation with probability (xint * strength).
if self.xint > 0:
t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
if debug_percentile is not None:
t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
# --------------------------------------------------------
# Select parameters for general geometric transformations.
# --------------------------------------------------------
# Apply isotropic scaling with probability (scale * strength).
if self.scale > 0:
s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
if debug_percentile is not None:
s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
G_inv = G_inv @ scale2d_inv(s, s)
# Apply pre-rotation with probability p_rot.
p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
if self.rotate > 0:
theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
if debug_percentile is not None:
theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
# Apply anisotropic scaling with probability (aniso * strength).
if self.aniso > 0:
s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
if debug_percentile is not None:
s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
G_inv = G_inv @ scale2d_inv(s, 1 / s)
# Apply post-rotation with probability p_rot.
if self.rotate > 0:
theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
if debug_percentile is not None:
theta = torch.zeros_like(theta)
G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
# Apply fractional translation with probability (xfrac * strength).
if self.xfrac > 0:
t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
if debug_percentile is not None:
t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
# ----------------------------------
# Execute geometric transformations.
# ----------------------------------
# Execute if the transform is not identity.
if G_inv is not I_3:
# Calculate padding.
cx = (width - 1) / 2
cy = (height - 1) / 2
cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
cp = G_inv @ cp.t() # [batch, xyz, idx]
Hz_pad = self.Hz_geom.shape[0] // 4
margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
margin = margin.max(misc.constant([0, 0] * 2, device=device))
margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
# Pad image and adjust origin.
images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
# Upsample.
images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
# Execute transformation.
shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
images = grid_sample_gradfix.grid_sample(images, grid)
# Downsample and crop.
images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
# --------------------------------------------
# Select parameters for color transformations.
# --------------------------------------------
# Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
I_4 = torch.eye(4, device=device)
C = I_4
# Apply brightness with probability (brightness * strength).
if self.brightness > 0:
b = torch.randn([batch_size], device=device) * self.brightness_std
b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
if debug_percentile is not None:
b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
C = translate3d(b, b, b) @ C
# Apply contrast with probability (contrast * strength).
if self.contrast > 0:
c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
if debug_percentile is not None:
c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
C = scale3d(c, c, c) @ C
# Apply luma flip with probability (lumaflip * strength).
v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
if self.lumaflip > 0:
i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i))
if debug_percentile is not None:
i = torch.full_like(i, torch.floor(debug_percentile * 2))
C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection.
# Apply hue rotation with probability (hue * strength).
if self.hue > 0 and num_channels > 1:
theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
if debug_percentile is not None:
theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
C = rotate3d(v, theta) @ C # Rotate around v.
# Apply saturation with probability (saturation * strength).
if self.saturation > 0 and num_channels > 1:
s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
if debug_percentile is not None:
s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
# ------------------------------
# Execute color transformations.
# ------------------------------
# Execute if the transform is not identity.
if C is not I_4:
images = images.reshape([batch_size, num_channels, height * width])
if num_channels == 3:
images = C[:, :3, :3] @ images + C[:, :3, 3:]
elif num_channels == 1:
C = C[:, :3, :].mean(dim=1, keepdims=True)
images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
else:
raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
images = images.reshape([batch_size, num_channels, height, width])
# ----------------------
# Image-space filtering.
# ----------------------
if self.imgfilter > 0:
num_bands = self.Hz_fbank.shape[0]
assert len(self.imgfilter_bands) == num_bands
expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
# Apply amplification for each band with probability (imgfilter * strength * band_strength).
g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
for i, band_strength in enumerate(self.imgfilter_bands):
t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
if debug_percentile is not None:
t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
t[:, i] = t_i # Replace i'th element.
t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
g = g * t # Accumulate into global gain.
# Construct combined amplification filter.
Hz_prime = g @ self.Hz_fbank # [batch, tap]
Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
# Apply filter.
p = self.Hz_fbank.shape[1] // 2
images = images.reshape([1, batch_size * num_channels, height, width])
images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
images = images.reshape([batch_size, num_channels, height, width])
# ------------------------
# Image-space corruptions.
# ------------------------
# Apply additive RGB noise with probability (noise * strength).
if self.noise > 0:
sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma))
if debug_percentile is not None:
sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std)
images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma
# Apply cutout with probability (cutout * strength).
if self.cutout > 0:
size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size))
center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
if debug_percentile is not None:
size = torch.full_like(size, self.cutout_size)
center = torch.full_like(center, debug_percentile)
coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
images = images * mask
return images
#----------------------------------------------------------------------------
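# Example (a minimal sketch, loosely following the full blit + geom + color
# style preset from the ADA paper): enable each augmentation with multiplier 1,
# set the overall probability p, and run a batch through the pipeline.
def _example_augment_pipe():
    pipe = AugmentPipe(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
        brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1)
    pipe.p.copy_(torch.as_tensor(0.6))     # overall augmentation probability multiplier
    images = torch.randn([4, 3, 64, 64])   # NCHW batch, roughly in [-1, 1]
    augmented = pipe(images)
    assert augmented.shape == images.shape
#----------------------------------------------------------------------------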
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Streaming images and labels from datasets created with dataset_tool.py."""
import os
import numpy as np
import zipfile
import PIL.Image
import json
import torch
import dnnlib
try:
import pyspng
except ImportError:
pyspng = None
#----------------------------------------------------------------------------
class Dataset(torch.utils.data.Dataset):
def __init__(self,
name, # Name of the dataset.
raw_shape, # Shape of the raw image data (NCHW).
max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
use_labels = False, # Enable conditioning labels? False = label dimension is zero.
xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
random_seed = 0, # Random seed to use when applying max_size.
):
self._name = name
self._raw_shape = list(raw_shape)
self._use_labels = use_labels
self._raw_labels = None
self._label_shape = None
# Apply max_size.
self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
if (max_size is not None) and (self._raw_idx.size > max_size):
np.random.RandomState(random_seed).shuffle(self._raw_idx)
self._raw_idx = np.sort(self._raw_idx[:max_size])
# Apply xflip.
self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
if xflip:
self._raw_idx = np.tile(self._raw_idx, 2)
self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
def _get_raw_labels(self):
if self._raw_labels is None:
self._raw_labels = self._load_raw_labels() if self._use_labels else None
if self._raw_labels is None:
self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
assert isinstance(self._raw_labels, np.ndarray)
assert self._raw_labels.shape[0] == self._raw_shape[0]
assert self._raw_labels.dtype in [np.float32, np.int64]
if self._raw_labels.dtype == np.int64:
assert self._raw_labels.ndim == 1
assert np.all(self._raw_labels >= 0)
return self._raw_labels
def close(self): # to be overridden by subclass
pass
def _load_raw_image(self, raw_idx): # to be overridden by subclass
raise NotImplementedError
def _load_raw_labels(self): # to be overridden by subclass
raise NotImplementedError
def __getstate__(self):
return dict(self.__dict__, _raw_labels=None)
def __del__(self):
try:
self.close()
except:
pass
def __len__(self):
return self._raw_idx.size
def __getitem__(self, idx):
image = self._load_raw_image(self._raw_idx[idx])
assert isinstance(image, np.ndarray)
assert list(image.shape) == self.image_shape
assert image.dtype == np.uint8
if self._xflip[idx]:
assert image.ndim == 3 # CHW
image = image[:, :, ::-1]
return image.copy(), self.get_label(idx)
def get_label(self, idx):
label = self._get_raw_labels()[self._raw_idx[idx]]
if label.dtype == np.int64:
onehot = np.zeros(self.label_shape, dtype=np.float32)
onehot[label] = 1
label = onehot
return label.copy()
def get_details(self, idx):
d = dnnlib.EasyDict()
d.raw_idx = int(self._raw_idx[idx])
d.xflip = (int(self._xflip[idx]) != 0)
d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
return d
@property
def name(self):
return self._name
@property
def image_shape(self):
return list(self._raw_shape[1:])
@property
def num_channels(self):
assert len(self.image_shape) == 3 # CHW
return self.image_shape[0]
@property
def resolution(self):
assert len(self.image_shape) == 3 # CHW
assert self.image_shape[1] == self.image_shape[2]
return self.image_shape[1]
@property
def label_shape(self):
if self._label_shape is None:
raw_labels = self._get_raw_labels()
if raw_labels.dtype == np.int64:
self._label_shape = [int(np.max(raw_labels)) + 1]
else:
self._label_shape = raw_labels.shape[1:]
return list(self._label_shape)
@property
def label_dim(self):
assert len(self.label_shape) == 1
return self.label_shape[0]
@property
def has_labels(self):
return any(x != 0 for x in self.label_shape)
@property
def has_onehot_labels(self):
return self._get_raw_labels().dtype == np.int64
#----------------------------------------------------------------------------
class ImageFolderDataset(Dataset):
def __init__(self,
path, # Path to directory or zip.
resolution = None, # Ensure specific resolution, None = highest available.
**super_kwargs, # Additional arguments for the Dataset base class.
):
self._path = path
self._zipfile = None
if os.path.isdir(self._path):
self._type = 'dir'
self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
elif self._file_ext(self._path) == '.zip':
self._type = 'zip'
self._all_fnames = set(self._get_zipfile().namelist())
else:
raise IOError('Path must point to a directory or zip')
PIL.Image.init()
self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
if len(self._image_fnames) == 0:
raise IOError('No image files found in the specified path')
name = os.path.splitext(os.path.basename(self._path))[0]
raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
raise IOError('Image files do not match the specified resolution')
super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
@staticmethod
def _file_ext(fname):
return os.path.splitext(fname)[1].lower()
def _get_zipfile(self):
assert self._type == 'zip'
if self._zipfile is None:
self._zipfile = zipfile.ZipFile(self._path)
return self._zipfile
def _open_file(self, fname):
if self._type == 'dir':
return open(os.path.join(self._path, fname), 'rb')
if self._type == 'zip':
return self._get_zipfile().open(fname, 'r')
return None
def close(self):
try:
if self._zipfile is not None:
self._zipfile.close()
finally:
self._zipfile = None
def __getstate__(self):
return dict(super().__getstate__(), _zipfile=None)
def _load_raw_image(self, raw_idx):
fname = self._image_fnames[raw_idx]
with self._open_file(fname) as f:
if pyspng is not None and self._file_ext(fname) == '.png':
image = pyspng.load(f.read())
else:
image = np.array(PIL.Image.open(f))
if image.ndim == 2:
image = image[:, :, np.newaxis] # HW => HWC
image = image.transpose(2, 0, 1) # HWC => CHW
return image
def _load_raw_labels(self):
fname = 'dataset.json'
if fname not in self._all_fnames:
return None
with self._open_file(fname) as f:
labels = json.load(f)['labels']
if labels is None:
return None
labels = dict(labels)
labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
labels = np.array(labels)
labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
return labels
#----------------------------------------------------------------------------
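# Example (a minimal sketch; the zip path below is hypothetical): stream images
# and labels produced by dataset_tool.py through a standard DataLoader.
def _example_image_folder_dataset():
    dataset = ImageFolderDataset(path='datasets/example.zip', use_labels=True, xflip=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
    images, labels = next(iter(loader))    # uint8 NCHW images, float32 label vectors
    print(images.shape, labels.shape, dataset.label_dim)
#----------------------------------------------------------------------------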
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Loss functions."""
import numpy as np
import torch
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import upfirdn2d
#----------------------------------------------------------------------------
class Loss:
def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass
raise NotImplementedError()
#----------------------------------------------------------------------------
class StyleGAN2Loss(Loss):
def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0):
super().__init__()
self.device = device
self.G = G
self.D = D
self.augment_pipe = augment_pipe
self.r1_gamma = r1_gamma
self.style_mixing_prob = style_mixing_prob
self.pl_weight = pl_weight
self.pl_batch_shrink = pl_batch_shrink
self.pl_decay = pl_decay
self.pl_no_weight_grad = pl_no_weight_grad
self.pl_mean = torch.zeros([], device=device)
self.blur_init_sigma = blur_init_sigma
self.blur_fade_kimg = blur_fade_kimg
def run_G(self, z, c, update_emas=False):
ws = self.G.mapping(z, c, update_emas=update_emas)
if self.style_mixing_prob > 0:
with torch.autograd.profiler.record_function('style_mixing'):
cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
img = self.G.synthesis(ws, update_emas=update_emas)
return img, ws
def run_D(self, img, c, blur_sigma=0, update_emas=False):
blur_size = np.floor(blur_sigma * 3)
if blur_size > 0:
with torch.autograd.profiler.record_function('blur'):
f = torch.arange(-blur_size, blur_size + 1, device=img.device).div(blur_sigma).square().neg().exp2()
img = upfirdn2d.filter2d(img, f / f.sum())
if self.augment_pipe is not None:
img = self.augment_pipe(img)
logits = self.D(img, c, update_emas=update_emas)
return logits
def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
if self.pl_weight == 0:
phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
if self.r1_gamma == 0:
phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
# Gmain: Maximize logits for generated images.
if phase in ['Gmain', 'Gboth']:
with torch.autograd.profiler.record_function('Gmain_forward'):
gen_img, _gen_ws = self.run_G(gen_z, gen_c)
gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
training_stats.report('Loss/scores/fake', gen_logits)
training_stats.report('Loss/signs/fake', gen_logits.sign())
loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
training_stats.report('Loss/G/loss', loss_Gmain)
with torch.autograd.profiler.record_function('Gmain_backward'):
loss_Gmain.mean().mul(gain).backward()
# Gpl: Apply path length regularization.
if phase in ['Greg', 'Gboth']:
with torch.autograd.profiler.record_function('Gpl_forward'):
batch_size = gen_z.shape[0] // self.pl_batch_shrink
gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size])
pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad):
pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
self.pl_mean.copy_(pl_mean.detach())
pl_penalty = (pl_lengths - pl_mean).square()
training_stats.report('Loss/pl_penalty', pl_penalty)
loss_Gpl = pl_penalty * self.pl_weight
training_stats.report('Loss/G/reg', loss_Gpl)
with torch.autograd.profiler.record_function('Gpl_backward'):
loss_Gpl.mean().mul(gain).backward()
# Dmain: Minimize logits for generated images.
loss_Dgen = 0
if phase in ['Dmain', 'Dboth']:
with torch.autograd.profiler.record_function('Dgen_forward'):
gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True)
gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
training_stats.report('Loss/scores/fake', gen_logits)
training_stats.report('Loss/signs/fake', gen_logits.sign())
loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
with torch.autograd.profiler.record_function('Dgen_backward'):
loss_Dgen.mean().mul(gain).backward()
# Dmain: Maximize logits for real images.
# Dr1: Apply R1 regularization.
if phase in ['Dmain', 'Dreg', 'Dboth']:
name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
with torch.autograd.profiler.record_function(name + '_forward'):
real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
training_stats.report('Loss/scores/real', real_logits)
training_stats.report('Loss/signs/real', real_logits.sign())
loss_Dreal = 0
if phase in ['Dmain', 'Dboth']:
loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
loss_Dr1 = 0
if phase in ['Dreg', 'Dboth']:
with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
r1_penalty = r1_grads.square().sum([1,2,3])
loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
training_stats.report('Loss/r1_penalty', r1_penalty)
training_stats.report('Loss/D/reg', loss_Dr1)
with torch.autograd.profiler.record_function(name + '_backward'):
(loss_Dreal + loss_Dr1).mean().mul(gain).backward()
#----------------------------------------------------------------------------
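# The Gmain and Dgen terms above use softplus to express the non-saturating
# logistic loss: softplus(-x) = -log(sigmoid(x)) and softplus(x) = -log(1 - sigmoid(x)).
# Quick numerical check of the first identity (a sketch, not used by the trainer):
def _check_softplus_identity():
    logits = torch.randn([1024]) * 5
    lhs = torch.nn.functional.softplus(-logits)
    rhs = -torch.log(torch.sigmoid(logits))
    assert torch.allclose(lhs, rhs, atol=1e-5)
#----------------------------------------------------------------------------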
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Network architectures from the paper
"Analyzing and Improving the Image Quality of StyleGAN".
Matches the original implementation of configs E-F by Karras et al. at
https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py"""
import numpy as np
import torch
import torch.nn.functional as F
from torch_utils import misc
from torch_utils import persistence
from torch_utils.ops import conv2d_resample
from torch_utils.ops import upfirdn2d
from torch_utils.ops import bias_act
from torch_utils.ops import fma
#----------------------------------------------------------------------------
@misc.profiled_function
def normalize_2nd_moment(x, dim=1, eps=1e-8):
return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
#----------------------------------------------------------------------------
@misc.profiled_function
def modulated_conv2d(
x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
styles, # Modulation coefficients of shape [batch_size, in_channels].
noise = None, # Optional noise tensor to add to the output activations.
up = 1, # Integer upsampling factor.
down = 1, # Integer downsampling factor.
padding = 0, # Padding with respect to the upsampled image.
resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
demodulate = True, # Apply weight demodulation?
flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
):
batch_size = x.shape[0]
out_channels, in_channels, kh, kw = weight.shape
misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
# Pre-normalize inputs to avoid FP16 overflow.
if x.dtype == torch.float16 and demodulate:
weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
# Calculate per-sample weights and demodulation coefficients.
w = None
dcoefs = None
if demodulate or fused_modconv:
w = weight.unsqueeze(0) # [NOIkk]
w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
if demodulate:
dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
if demodulate and fused_modconv:
w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
# Execute by scaling the activations before and after the convolution.
if not fused_modconv:
x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
if demodulate and noise is not None:
x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
elif demodulate:
x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
elif noise is not None:
x = x.add_(noise.to(x.dtype))
return x
# Execute as one fused op using grouped convolution.
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
batch_size = int(batch_size)
misc.assert_shape(x, [batch_size, in_channels, None, None])
x = x.reshape(1, -1, *x.shape[2:])
w = w.reshape(-1, in_channels, kh, kw)
x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
x = x.reshape(batch_size, -1, *x.shape[2:])
if noise is not None:
x = x.add_(noise)
return x
#----------------------------------------------------------------------------
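# Example (a minimal sketch, shapes are illustrative): per-sample style
# modulation with weight demodulation, executed as a single grouped convolution.
def _example_modulated_conv2d():
    x = torch.randn([2, 3, 8, 8])          # [batch, in_channels, H, W]
    weight = torch.randn([5, 3, 3, 3])     # [out_channels, in_channels, kh, kw]
    styles = torch.randn([2, 3])           # [batch, in_channels]
    y = modulated_conv2d(x=x, weight=weight, styles=styles, padding=1)
    assert y.shape == (2, 5, 8, 8)
#----------------------------------------------------------------------------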
@persistence.persistent_class
class FullyConnectedLayer(torch.nn.Module):
def __init__(self,
in_features, # Number of input features.
out_features, # Number of output features.
bias = True, # Apply additive bias before the activation function?
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
lr_multiplier = 1, # Learning rate multiplier.
bias_init = 0, # Initial value for the additive bias.
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.activation = activation
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
self.weight_gain = lr_multiplier / np.sqrt(in_features)
self.bias_gain = lr_multiplier
def forward(self, x):
w = self.weight.to(x.dtype) * self.weight_gain
b = self.bias
if b is not None:
b = b.to(x.dtype)
if self.bias_gain != 1:
b = b * self.bias_gain
if self.activation == 'linear' and b is not None:
x = torch.addmm(b.unsqueeze(0), x, w.t())
else:
x = x.matmul(w.t())
x = bias_act.bias_act(x, b, act=self.activation)
return x
def extra_repr(self):
return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
#----------------------------------------------------------------------------
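# Example (a minimal sketch): the layer stores weights at unit scale and applies
# the equalized learning rate factor lr_multiplier / sqrt(in_features) at runtime.
def _example_fully_connected_layer():
    fc = FullyConnectedLayer(in_features=512, out_features=256, activation='lrelu', lr_multiplier=0.01)
    y = fc(torch.randn([4, 512]))
    assert y.shape == (4, 256)
    assert np.isclose(fc.weight_gain, 0.01 / np.sqrt(512))
#----------------------------------------------------------------------------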
@persistence.persistent_class
class Conv2dLayer(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels.
out_channels, # Number of output channels.
kernel_size, # Width and height of the convolution kernel.
bias = True, # Apply additive bias before the activation function?
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
up = 1, # Integer upsampling factor.
down = 1, # Integer downsampling factor.
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
channels_last = False, # Expect the input to have memory_format=channels_last?
trainable = True, # Update the weights of this layer during training?
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.activation = activation
self.up = up
self.down = down
self.conv_clamp = conv_clamp
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
self.padding = kernel_size // 2
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
self.act_gain = bias_act.activation_funcs[activation].def_gain
memory_format = torch.channels_last if channels_last else torch.contiguous_format
weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
bias = torch.zeros([out_channels]) if bias else None
if trainable:
self.weight = torch.nn.Parameter(weight)
self.bias = torch.nn.Parameter(bias) if bias is not None else None
else:
self.register_buffer('weight', weight)
if bias is not None:
self.register_buffer('bias', bias)
else:
self.bias = None
def forward(self, x, gain=1):
w = self.weight * self.weight_gain
b = self.bias.to(x.dtype) if self.bias is not None else None
flip_weight = (self.up == 1) # slightly faster
x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
act_gain = self.act_gain * gain
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
return x
def extra_repr(self):
return ' '.join([
f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},',
f'up={self.up}, down={self.down}'])
#----------------------------------------------------------------------------
@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
w_dim, # Intermediate latent (W) dimensionality.
num_ws, # Number of intermediate latents to output, None = do not broadcast.
num_layers = 8, # Number of mapping layers.
embed_features = None, # Label embedding dimensionality, None = same as w_dim.
layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track.
):
super().__init__()
self.z_dim = z_dim
self.c_dim = c_dim
self.w_dim = w_dim
self.num_ws = num_ws
self.num_layers = num_layers
self.w_avg_beta = w_avg_beta
if embed_features is None:
embed_features = w_dim
if c_dim == 0:
embed_features = 0
if layer_features is None:
layer_features = w_dim
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
if c_dim > 0:
self.embed = FullyConnectedLayer(c_dim, embed_features)
for idx in range(num_layers):
in_features = features_list[idx]
out_features = features_list[idx + 1]
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
setattr(self, f'fc{idx}', layer)
if num_ws is not None and w_avg_beta is not None:
self.register_buffer('w_avg', torch.zeros([w_dim]))
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
# Embed, normalize, and concat inputs.
x = None
with torch.autograd.profiler.record_function('input'):
if self.z_dim > 0:
misc.assert_shape(z, [None, self.z_dim])
x = normalize_2nd_moment(z.to(torch.float32))
if self.c_dim > 0:
misc.assert_shape(c, [None, self.c_dim])
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
x = torch.cat([x, y], dim=1) if x is not None else y
# Main layers.
for idx in range(self.num_layers):
layer = getattr(self, f'fc{idx}')
x = layer(x)
# Update moving average of W.
if update_emas and self.w_avg_beta is not None:
with torch.autograd.profiler.record_function('update_w_avg'):
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
# Broadcast.
if self.num_ws is not None:
with torch.autograd.profiler.record_function('broadcast'):
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
# Apply truncation.
if truncation_psi != 1:
with torch.autograd.profiler.record_function('truncate'):
assert self.w_avg_beta is not None
if self.num_ws is None or truncation_cutoff is None:
x = self.w_avg.lerp(x, truncation_psi)
else:
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
return x
def extra_repr(self):
return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
#----------------------------------------------------------------------------
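# Example (a minimal sketch, unconditional case): map latents to W, update the
# moving average of W, and truncate towards it.
def _example_mapping_network():
    mapping = MappingNetwork(z_dim=64, c_dim=0, w_dim=64, num_ws=8, num_layers=2)
    z = torch.randn([4, 64])
    ws = mapping(z, c=None, truncation_psi=0.7, update_emas=True)
    assert ws.shape == (4, 8, 64)          # one w vector per synthesis layer input
#----------------------------------------------------------------------------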
@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels.
out_channels, # Number of output channels.
w_dim, # Intermediate latent (W) dimensionality.
resolution, # Resolution of this layer.
kernel_size = 3, # Convolution kernel size.
up = 1, # Integer upsampling factor.
use_noise = True, # Enable noise input?
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
channels_last = False, # Use channels_last format for the weights?
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.w_dim = w_dim
self.resolution = resolution
self.up = up
self.use_noise = use_noise
self.activation = activation
self.conv_clamp = conv_clamp
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
self.padding = kernel_size // 2
self.act_gain = bias_act.activation_funcs[activation].def_gain
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
memory_format = torch.channels_last if channels_last else torch.contiguous_format
self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
if use_noise:
self.register_buffer('noise_const', torch.randn([resolution, resolution]))
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
assert noise_mode in ['random', 'const', 'none']
in_resolution = self.resolution // self.up
misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
styles = self.affine(w)
noise = None
if self.use_noise and noise_mode == 'random':
noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
if self.use_noise and noise_mode == 'const':
noise = self.noise_const * self.noise_strength
flip_weight = (self.up == 1) # slightly faster
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)
act_gain = self.act_gain * gain
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
return x
def extra_repr(self):
return ' '.join([
f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},',
f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}'])
#----------------------------------------------------------------------------
@persistence.persistent_class
class ToRGBLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.w_dim = w_dim
self.conv_clamp = conv_clamp
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
memory_format = torch.channels_last if channels_last else torch.contiguous_format
self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
def forward(self, x, w, fused_modconv=True):
styles = self.affine(w) * self.weight_gain
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
return x
def extra_repr(self):
return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}'
#----------------------------------------------------------------------------
@persistence.persistent_class
class SynthesisBlock(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels, 0 = first block.
out_channels, # Number of output channels.
w_dim, # Intermediate latent (W) dimensionality.
resolution, # Resolution of this block.
img_channels, # Number of output color channels.
is_last, # Is this the last block?
architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
use_fp16 = False, # Use FP16 for this block?
fp16_channels_last = False, # Use channels-last memory format with FP16?
fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
**layer_kwargs, # Arguments for SynthesisLayer.
):
assert architecture in ['orig', 'skip', 'resnet']
super().__init__()
self.in_channels = in_channels
self.w_dim = w_dim
self.resolution = resolution
self.img_channels = img_channels
self.is_last = is_last
self.architecture = architecture
self.use_fp16 = use_fp16
self.channels_last = (use_fp16 and fp16_channels_last)
self.fused_modconv_default = fused_modconv_default
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
self.num_conv = 0
self.num_torgb = 0
if in_channels == 0:
self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
if in_channels != 0:
self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
self.num_conv += 1
self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
self.num_conv += 1
if is_last or architecture == 'skip':
self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
conv_clamp=conv_clamp, channels_last=self.channels_last)
self.num_torgb += 1
if in_channels != 0 and architecture == 'resnet':
self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
resample_filter=resample_filter, channels_last=self.channels_last)
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
_ = update_emas # unused
misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
w_iter = iter(ws.unbind(dim=1))
if ws.device.type != 'cuda':
force_fp32 = True
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
if fused_modconv is None:
fused_modconv = self.fused_modconv_default
if fused_modconv == 'inference_only':
fused_modconv = (not self.training)
# Input.
if self.in_channels == 0:
x = self.const.to(dtype=dtype, memory_format=memory_format)
x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
else:
misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
x = x.to(dtype=dtype, memory_format=memory_format)
# Main layers.
if self.in_channels == 0:
x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
elif self.architecture == 'resnet':
y = self.skip(x, gain=np.sqrt(0.5))
x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
x = y.add_(x)
else:
x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
# ToRGB.
if img is not None:
misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
img = upfirdn2d.upsample2d(img, self.resample_filter)
if self.is_last or self.architecture == 'skip':
y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
img = img.add_(y) if img is not None else y
assert x.dtype == dtype
assert img is None or img.dtype == torch.float32
return x, img
def extra_repr(self):
return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
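# Minimal sketch of how a block consumes its slice of ws (hypothetical tiny sizes; never called).
# The first block (in_channels=0) starts from the learned constant and needs
# num_conv + num_torgb latents; it returns the feature map x and the running skip image img.
def _example_synthesis_block():
    block = SynthesisBlock(in_channels=0, out_channels=8, w_dim=16, resolution=4,
                           img_channels=3, is_last=True)
    ws = torch.randn([2, block.num_conv + block.num_torgb, 16])
    x, img = block(x=None, img=None, ws=ws)
    assert list(x.shape) == [2, 8, 4, 4] and list(img.shape) == [2, 3, 4, 4]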
#----------------------------------------------------------------------------
@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
def __init__(self,
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output image resolution.
img_channels, # Number of color channels.
channel_base = 32768, # Overall multiplier for the number of channels.
channel_max = 512, # Maximum number of channels in any layer.
num_fp16_res = 4, # Use FP16 for the N highest resolutions.
**block_kwargs, # Arguments for SynthesisBlock.
):
assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
super().__init__()
self.w_dim = w_dim
self.img_resolution = img_resolution
self.img_resolution_log2 = int(np.log2(img_resolution))
self.img_channels = img_channels
self.num_fp16_res = num_fp16_res
self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
self.num_ws = 0
for res in self.block_resolutions:
in_channels = channels_dict[res // 2] if res > 4 else 0
out_channels = channels_dict[res]
use_fp16 = (res >= fp16_resolution)
is_last = (res == self.img_resolution)
block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
self.num_ws += block.num_conv
if is_last:
self.num_ws += block.num_torgb
setattr(self, f'b{res}', block)
def forward(self, ws, return_feature=False, **block_kwargs):
block_ws = []
features = []
with torch.autograd.profiler.record_function('split_ws'):
misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
ws = ws.to(torch.float32)
w_idx = 0
for res in self.block_resolutions:
block = getattr(self, f'b{res}')
block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
w_idx += block.num_conv
x = img = None
for res, cur_ws in zip(self.block_resolutions, block_ws):
block = getattr(self, f'b{res}')
x, img = block(x, img, cur_ws, **block_kwargs)
features.append(x)
if return_feature:
return img, features
else:
return img
def extra_repr(self):
return ' '.join([
f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
f'num_fp16_res={self.num_fp16_res:d}'])
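# Minimal sketch of the w budget (hypothetical tiny config; never called). Each block owns
# num_conv latents and re-uses the following block's first latent for its ToRGB layer,
# which is why the split loop above advances w_idx by num_conv only.
def _example_synthesis_network():
    synth = SynthesisNetwork(w_dim=16, img_resolution=16, img_channels=3,
                             channel_base=256, channel_max=16)
    ws = torch.randn([2, synth.num_ws, 16])
    img = synth(ws)
    assert list(img.shape) == [2, 3, 16, 16]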
#----------------------------------------------------------------------------
@persistence.persistent_class
class Generator(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality.
c_dim, # Conditioning label (C) dimensionality.
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output resolution.
img_channels, # Number of output color channels.
mapping_kwargs = {}, # Arguments for MappingNetwork.
synthesis_kwargs = {}, # Arguments for SynthesisNetwork.
resize=None,
# **synthesis_kwargs, # Arguments for SynthesisNetwork.
):
super().__init__()
self.z_dim = z_dim
self.c_dim = c_dim
self.w_dim = w_dim
self.img_resolution = img_resolution
self.img_channels = img_channels
self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
self.num_ws = self.synthesis.num_ws
self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
self.resize = resize
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, input_is_w=False, return_feature=False, **synthesis_kwargs):
if input_is_w:
ws = z
if ws.dim() == 2:
ws = ws.unsqueeze(1).repeat([1, self.mapping.num_ws, 1])
else:
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
img = self.synthesis(ws, update_emas=update_emas, return_feature=return_feature, **synthesis_kwargs)
if self.resize is not None:
img = imresize(img, [self.resize, self.resize])
return img
def imresize(image, size):
dim = image.dim()
if dim == 3:
image = image.unsqueeze(1)
b, _, h, w = image.shape
if size[0] > h:
image = F.interpolate(image, size, mode='bilinear')
elif size[0] < h:
image = F.interpolate(image, size, mode='area')
if dim == 3:
image = image.squeeze(1)
return image
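# Minimal sketch of the two Generator entry points (hypothetical tiny config; never called):
# a z latent routed through the mapping network, or a pre-computed w (input_is_w=True)
# broadcast over num_ws; when resize is set, imresize rescales the final image.
def _example_generator():
    G = Generator(z_dim=16, c_dim=0, w_dim=16, img_resolution=16, img_channels=3,
                  synthesis_kwargs=dict(channel_base=256, channel_max=16), resize=8)
    z = torch.randn([2, G.z_dim])
    c = torch.zeros([2, G.c_dim])
    img = G(z, c)                      # z -> mapping -> synthesis -> resized to 8x8
    w = torch.randn([2, G.w_dim])
    img_w = G(w, c, input_is_w=True)   # single w broadcast to every num_ws slot
    assert list(img.shape) == [2, 3, 8, 8] and list(img_w.shape) == [2, 3, 8, 8]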
#----------------------------------------------------------------------------
@persistence.persistent_class
class DiscriminatorBlock(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels, 0 = first block.
tmp_channels, # Number of intermediate channels.
out_channels, # Number of output channels.
resolution, # Resolution of this block.
img_channels, # Number of input color channels.
first_layer_idx, # Index of the first layer.
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
use_fp16 = False, # Use FP16 for this block?
fp16_channels_last = False, # Use channels-last memory format with FP16?
freeze_layers = 0, # Freeze-D: Number of layers to freeze.
):
assert in_channels in [0, tmp_channels]
assert architecture in ['orig', 'skip', 'resnet']
super().__init__()
self.in_channels = in_channels
self.resolution = resolution
self.img_channels = img_channels
self.first_layer_idx = first_layer_idx
self.architecture = architecture
self.use_fp16 = use_fp16
self.channels_last = (use_fp16 and fp16_channels_last)
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
self.num_layers = 0
def trainable_gen():
while True:
layer_idx = self.first_layer_idx + self.num_layers
trainable = (layer_idx >= freeze_layers)
self.num_layers += 1
yield trainable
trainable_iter = trainable_gen()
if in_channels == 0 or architecture == 'skip':
self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)
if architecture == 'resnet':
self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
def forward(self, x, img, force_fp32=False):
if (x if x is not None else img).device.type != 'cuda':
force_fp32 = True
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
# Input.
if x is not None:
misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
x = x.to(dtype=dtype, memory_format=memory_format)
# FromRGB.
if self.in_channels == 0 or self.architecture == 'skip':
misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
img = img.to(dtype=dtype, memory_format=memory_format)
y = self.fromrgb(img)
x = x + y if x is not None else y
img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
# Main layers.
if self.architecture == 'resnet':
y = self.skip(x, gain=np.sqrt(0.5))
x = self.conv0(x)
x = self.conv1(x, gain=np.sqrt(0.5))
x = y.add_(x)
else:
x = self.conv0(x)
x = self.conv1(x)
assert x.dtype == dtype
return x, img
def extra_repr(self):
return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
#----------------------------------------------------------------------------
@persistence.persistent_class
class MinibatchStdLayer(torch.nn.Module):
def __init__(self, group_size, num_channels=1):
super().__init__()
self.group_size = group_size
self.num_channels = num_channels
def forward(self, x):
N, C, H, W = x.shape
with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
F = self.num_channels
c = C // F
y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
return x
def extra_repr(self):
return f'group_size={self.group_size}, num_channels={self.num_channels:d}'
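# Minimal shape trace for MinibatchStdLayer (hypothetical sizes; never called): a batch of 8
# with group_size=4 is split into 2 groups, one stddev statistic per group is averaged over
# channels and pixels, and it is appended to every sample as a single extra feature map.
def _example_minibatch_std():
    layer = MinibatchStdLayer(group_size=4, num_channels=1)
    x = torch.randn([8, 16, 8, 8])
    y = layer(x)
    assert list(y.shape) == [8, 17, 8, 8]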
#----------------------------------------------------------------------------
@persistence.persistent_class
class DiscriminatorEpilogue(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels.
cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
resolution, # Resolution of this block.
img_channels, # Number of input color channels.
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
):
assert architecture in ['orig', 'skip', 'resnet']
super().__init__()
self.in_channels = in_channels
self.cmap_dim = cmap_dim
self.resolution = resolution
self.img_channels = img_channels
self.architecture = architecture
if architecture == 'skip':
self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)
def forward(self, x, img, cmap, force_fp32=False):
misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
_ = force_fp32 # unused
dtype = torch.float32
memory_format = torch.contiguous_format
# FromRGB.
x = x.to(dtype=dtype, memory_format=memory_format)
if self.architecture == 'skip':
misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
img = img.to(dtype=dtype, memory_format=memory_format)
x = x + self.fromrgb(img)
# Main layers.
if self.mbstd is not None:
x = self.mbstd(x)
x = self.conv(x)
x = self.fc(x.flatten(1))
x = self.out(x)
# Conditioning.
if self.cmap_dim > 0:
misc.assert_shape(cmap, [None, self.cmap_dim])
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
assert x.dtype == dtype
return x
def extra_repr(self):
return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
#----------------------------------------------------------------------------
@persistence.persistent_class
class Discriminator(torch.nn.Module):
def __init__(self,
c_dim, # Conditioning label (C) dimensionality.
img_resolution, # Input resolution.
img_channels, # Number of input color channels.
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
channel_base = 32768, # Overall multiplier for the number of channels.
channel_max = 512, # Maximum number of channels in any layer.
num_fp16_res = 4, # Use FP16 for the N highest resolutions.
conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
block_kwargs = {}, # Arguments for DiscriminatorBlock.
mapping_kwargs = {}, # Arguments for MappingNetwork.
epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
):
super().__init__()
self.c_dim = c_dim
self.img_resolution = img_resolution
self.img_resolution_log2 = int(np.log2(img_resolution))
self.img_channels = img_channels
self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
if cmap_dim is None:
cmap_dim = channels_dict[4]
if c_dim == 0:
cmap_dim = 0
common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
cur_layer_idx = 0
for res in self.block_resolutions:
in_channels = channels_dict[res] if res < img_resolution else 0
tmp_channels = channels_dict[res]
out_channels = channels_dict[res // 2]
use_fp16 = (res >= fp16_resolution)
block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
setattr(self, f'b{res}', block)
cur_layer_idx += block.num_layers
if c_dim > 0:
self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
def forward(self, img, c, update_emas=False, **block_kwargs):
_ = update_emas # unused
x = None
for res in self.block_resolutions:
block = getattr(self, f'b{res}')
x, img = block(x, img, **block_kwargs)
cmap = None
if self.c_dim > 0:
cmap = self.mapping(None, c)
x = self.b4(x, img, cmap)
return x
def extra_repr(self):
return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
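# Minimal usage sketch for Discriminator (hypothetical tiny config; never called): the image
# is walked down the resolution pyramid block by block, and the 4x4 epilogue maps the final
# features to one logit per sample (or to a cmap projection when c_dim > 0).
def _example_discriminator():
    D = Discriminator(c_dim=0, img_resolution=16, img_channels=3,
                      channel_base=256, channel_max=16)
    img = torch.randn([2, 3, 16, 16])
    c = torch.zeros([2, 0])
    logits = D(img, c)
    assert list(logits.shape) == [2, 1]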
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Generator architecture from the paper
"Alias-Free Generative Adversarial Networks"."""
import numpy as np
import scipy.signal
import scipy.optimize
import scipy.special # Explicit import for the jinc (j1) kernel used in design_lowpass_filter.
import torch
import torch.nn.functional as F
from torch_utils import misc
from torch_utils import persistence
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import filtered_lrelu
from torch_utils.ops import bias_act
#----------------------------------------------------------------------------
@misc.profiled_function
def modulated_conv2d(
x, # Input tensor: [batch_size, in_channels, in_height, in_width]
w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width]
s, # Style tensor: [batch_size, in_channels]
demodulate = True, # Apply weight demodulation?
padding = 0, # Padding: int or [padH, padW]
input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
):
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
batch_size = int(x.shape[0])
out_channels, in_channels, kh, kw = w.shape
misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk]
misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
misc.assert_shape(s, [batch_size, in_channels]) # [NI]
# Pre-normalize inputs.
if demodulate:
w = w * w.square().mean([1,2,3], keepdim=True).rsqrt()
s = s * s.square().mean().rsqrt()
# Modulate weights.
w = w.unsqueeze(0) # [NOIkk]
w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
# Demodulate weights.
if demodulate:
dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk]
# Apply input scaling.
if input_gain is not None:
input_gain = input_gain.expand(batch_size, in_channels) # [NI]
w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
# Execute as one fused op using grouped convolution.
x = x.reshape(1, -1, *x.shape[2:])
w = w.reshape(-1, in_channels, kh, kw)
x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size)
x = x.reshape(batch_size, -1, *x.shape[2:])
return x
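# Minimal sketch of the fused formulation above (hypothetical shapes; never called):
# per-sample styles scale the shared weight, demodulation re-normalizes each output channel,
# and the batch is folded into the group dimension of a single grouped conv2d call.
def _example_modulated_conv2d():
    x = torch.randn([2, 4, 8, 8])   # [batch, in_channels, H, W]
    w = torch.randn([6, 4, 3, 3])   # [out_channels, in_channels, kh, kw]
    s = torch.randn([2, 4])         # [batch, in_channels]
    y = modulated_conv2d(x, w, s, padding=1)
    assert list(y.shape) == [2, 6, 8, 8]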
#----------------------------------------------------------------------------
@persistence.persistent_class
class FullyConnectedLayer(torch.nn.Module):
def __init__(self,
in_features, # Number of input features.
out_features, # Number of output features.
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
bias = True, # Apply additive bias before the activation function?
lr_multiplier = 1, # Learning rate multiplier.
weight_init = 1, # Initial standard deviation of the weight tensor.
bias_init = 0, # Initial value of the additive bias.
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.activation = activation
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))
bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
self.weight_gain = lr_multiplier / np.sqrt(in_features)
self.bias_gain = lr_multiplier
def forward(self, x):
w = self.weight.to(x.dtype) * self.weight_gain
b = self.bias
if b is not None:
b = b.to(x.dtype)
if self.bias_gain != 1:
b = b * self.bias_gain
if self.activation == 'linear' and b is not None:
x = torch.addmm(b.unsqueeze(0), x, w.t())
else:
x = x.matmul(w.t())
x = bias_act.bias_act(x, b, act=self.activation)
return x
def extra_repr(self):
return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
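# Minimal sketch of the equalized-learning-rate trick used above (hypothetical sizes; never
# called): weights are stored at unit scale and multiplied by
# weight_gain = lr_multiplier / sqrt(in_features) at runtime, so the effective initialization
# stays He-like while the optimizer sees gradients rescaled by lr_multiplier.
def _example_fully_connected():
    fc = FullyConnectedLayer(in_features=8, out_features=4, activation='lrelu', lr_multiplier=0.01)
    x = torch.randn([2, 8])
    y = fc(x)
    assert list(y.shape) == [2, 4]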
#----------------------------------------------------------------------------
@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality.
c_dim, # Conditioning label (C) dimensionality, 0 = no labels.
w_dim, # Intermediate latent (W) dimensionality.
num_ws, # Number of intermediate latents to output.
num_layers = 2, # Number of mapping layers.
lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
w_avg_beta = 0.998, # Decay for tracking the moving average of W during training.
):
super().__init__()
self.z_dim = z_dim
self.c_dim = c_dim
self.w_dim = w_dim
self.num_ws = num_ws
self.num_layers = num_layers
self.w_avg_beta = w_avg_beta
# Construct layers.
self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None
features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers
for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
setattr(self, f'fc{idx}', layer)
self.register_buffer('w_avg', torch.zeros([w_dim]))
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
misc.assert_shape(z, [None, self.z_dim])
if truncation_cutoff is None:
truncation_cutoff = self.num_ws
# Embed, normalize, and concatenate inputs.
x = z.to(torch.float32)
x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
if self.c_dim > 0:
misc.assert_shape(c, [None, self.c_dim])
y = self.embed(c.to(torch.float32))
y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
x = torch.cat([x, y], dim=1) if x is not None else y
# Execute layers.
for idx in range(self.num_layers):
x = getattr(self, f'fc{idx}')(x)
# Update moving average of W.
if update_emas:
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
# Broadcast and apply truncation.
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
if truncation_psi != 1:
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
return x
def extra_repr(self):
return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
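# Minimal sketch of mapping + truncation (hypothetical sizes; never called): one w per sample
# is broadcast to num_ws copies, and with truncation_psi < 1 the first truncation_cutoff
# entries are lerped towards the tracked average w_avg.
def _example_mapping_network():
    mapping = MappingNetwork(z_dim=16, c_dim=0, w_dim=16, num_ws=6)
    z = torch.randn([2, 16])
    ws = mapping(z, c=None, truncation_psi=0.7, truncation_cutoff=4)
    assert list(ws.shape) == [2, 6, 16]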
#----------------------------------------------------------------------------
@persistence.persistent_class
class SynthesisInput(torch.nn.Module):
def __init__(self,
w_dim, # Intermediate latent (W) dimensionality.
channels, # Number of output channels.
size, # Output spatial size: int or [width, height].
sampling_rate, # Output sampling rate.
bandwidth, # Output bandwidth.
):
super().__init__()
self.w_dim = w_dim
self.channels = channels
self.size = np.broadcast_to(np.asarray(size), [2])
self.sampling_rate = sampling_rate
self.bandwidth = bandwidth
# Draw random frequencies from uniform 2D disc.
freqs = torch.randn([self.channels, 2])
radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
freqs /= radii * radii.square().exp().pow(0.25)
freqs *= bandwidth
phases = torch.rand([self.channels]) - 0.5
# Setup parameters and buffers.
self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels]))
self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0])
self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image.
self.register_buffer('freqs', freqs)
self.register_buffer('phases', phases)
def forward(self, w):
# Introduce batch dimension.
transforms = self.transform.unsqueeze(0) # [batch, row, col]
freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
phases = self.phases.unsqueeze(0) # [batch, channel]
# Apply learned transformation.
t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
m_r[:, 0, 0] = t[:, 0] # r'_c
m_r[:, 0, 1] = -t[:, 1] # r'_s
m_r[:, 1, 0] = t[:, 1] # r'_s
m_r[:, 1, 1] = t[:, 0] # r'_c
m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
m_t[:, 0, 2] = -t[:, 2] # t'_x
m_t[:, 1, 2] = -t[:, 3] # t'_y
transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.
# Transform frequencies.
phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
freqs = freqs @ transforms[:, :2, :2]
# Dampen out-of-band frequencies that may occur due to the user-specified transform.
amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)
# Construct sampling grid.
theta = torch.eye(2, 3, device=w.device)
theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)
# Compute Fourier features.
x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
x = x + phases.unsqueeze(1).unsqueeze(2)
x = torch.sin(x * (np.pi * 2))
x = x * amplitudes.unsqueeze(1).unsqueeze(2)
# Apply trainable mapping.
weight = self.weight / np.sqrt(self.channels)
x = x @ weight.t()
# Ensure correct shape.
x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
return x
def extra_repr(self):
return '\n'.join([
f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'])
#----------------------------------------------------------------------------
@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
def __init__(self,
w_dim, # Intermediate latent (W) dimensionality.
is_torgb, # Is this the final ToRGB layer?
is_critically_sampled, # Does this layer use critical sampling?
use_fp16, # Does this layer use FP16?
# Input & output specifications.
in_channels, # Number of input channels.
out_channels, # Number of output channels.
in_size, # Input spatial size: int or [width, height].
out_size, # Output spatial size: int or [width, height].
in_sampling_rate, # Input sampling rate (s).
out_sampling_rate, # Output sampling rate (s).
in_cutoff, # Input cutoff frequency (f_c).
out_cutoff, # Output cutoff frequency (f_c).
in_half_width, # Input transition band half-width (f_h).
        out_half_width,                 # Output transition band half-width (f_h).
# Hyperparameters.
        conv_kernel         = 3,        # Convolution kernel size. Ignored for the final ToRGB layer.
filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling.
        lrelu_upsampling    = 2,        # Relative sampling rate for leaky ReLU. Ignored for the final ToRGB layer.
use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping.
magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes.
):
super().__init__()
self.w_dim = w_dim
self.is_torgb = is_torgb
self.is_critically_sampled = is_critically_sampled
self.use_fp16 = use_fp16
self.in_channels = in_channels
self.out_channels = out_channels
self.in_size = np.broadcast_to(np.asarray(in_size), [2])
self.out_size = np.broadcast_to(np.asarray(out_size), [2])
self.in_sampling_rate = in_sampling_rate
self.out_sampling_rate = out_sampling_rate
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)
self.in_cutoff = in_cutoff
self.out_cutoff = out_cutoff
self.in_half_width = in_half_width
self.out_half_width = out_half_width
self.conv_kernel = 1 if is_torgb else conv_kernel
self.conv_clamp = conv_clamp
self.magnitude_ema_beta = magnitude_ema_beta
# Setup parameters and buffers.
self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1)
self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel]))
self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
self.register_buffer('magnitude_ema', torch.ones([]))
# Design upsampling filter.
self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1
self.register_buffer('up_filter', self.design_lowpass_filter(
numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate))
# Design downsampling filter.
self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate
self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1
self.down_radial = use_radial_filters and not self.is_critically_sampled
self.register_buffer('down_filter', self.design_lowpass_filter(
numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial))
# Compute padding.
pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling.
pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling.
pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters.
pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3).
pad_hi = pad_total - pad_lo
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]
def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False):
assert noise_mode in ['random', 'const', 'none'] # unused
misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])])
misc.assert_shape(w, [x.shape[0], self.w_dim])
# Track input magnitude.
if update_emas:
with torch.autograd.profiler.record_function('update_magnitude_ema'):
magnitude_cur = x.detach().to(torch.float32).square().mean()
self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta))
input_gain = self.magnitude_ema.rsqrt()
# Execute affine layer.
styles = self.affine(w)
if self.is_torgb:
weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2))
styles = styles * weight_gain
# Execute modulated conv2d.
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles,
padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain)
# Execute bias, filtered leaky ReLU, and clamping.
gain = 1 if self.is_torgb else np.sqrt(2)
slope = 1 if self.is_torgb else 0.2
x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype),
up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp)
# Ensure correct shape and dtype.
misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])])
assert x.dtype == dtype
return x
@staticmethod
def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
assert numtaps >= 1
# Identity filter.
if numtaps == 1:
return None
# Separable Kaiser low-pass filter.
if not radial:
f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
return torch.as_tensor(f, dtype=torch.float32)
# Radially symmetric jinc-based filter.
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
r = np.hypot(*np.meshgrid(x, x))
f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
w = np.kaiser(numtaps, beta)
f *= np.outer(w, w)
f /= np.sum(f)
return torch.as_tensor(f, dtype=torch.float32)
def extra_repr(self):
return '\n'.join([
f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},',
f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},',
f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},',
f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},',
f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},',
f'in_size={list(self.in_size)}, out_size={list(self.out_size)},',
f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}'])
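# Minimal sketch of the static filter designer above (hypothetical filter parameters; never
# called): the separable variant returns a 1D Kaiser-windowed sinc, radial=True returns a 2D
# jinc-based kernel, and numtaps == 1 means "no filtering" (None).
def _example_design_lowpass_filter():
    f_sep = SynthesisLayer.design_lowpass_filter(numtaps=12, cutoff=4.0, width=2.0, fs=32.0)
    f_rad = SynthesisLayer.design_lowpass_filter(numtaps=12, cutoff=4.0, width=2.0, fs=32.0, radial=True)
    assert list(f_sep.shape) == [12] and list(f_rad.shape) == [12, 12]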
#----------------------------------------------------------------------------
@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
def __init__(self,
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output image resolution.
img_channels, # Number of color channels.
channel_base = 32768, # Overall multiplier for the number of channels.
channel_max = 512, # Maximum number of channels in any layer.
num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB.
num_critical = 2, # Number of critically sampled layers at the end.
first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}).
first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}).
last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff.
margin_size = 10, # Number of additional pixels outside the image.
output_scale = 0.25, # Scale factor for the output image.
num_fp16_res = 4, # Use FP16 for the N highest resolutions.
**layer_kwargs, # Arguments for SynthesisLayer.
):
super().__init__()
self.w_dim = w_dim
self.num_ws = num_layers + 2
self.img_resolution = img_resolution
self.img_channels = img_channels
self.num_layers = num_layers
self.num_critical = num_critical
self.margin_size = margin_size
self.output_scale = output_scale
self.num_fp16_res = num_fp16_res
# Geometric progression of layer cutoffs and min. stopbands.
last_cutoff = self.img_resolution / 2 # f_{c,N}
last_stopband = last_cutoff * last_stopband_rel # f_{t,N}
exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1)
cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i]
stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i]
# Compute remaining layer parameters.
sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i]
half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i]
sizes = sampling_rates + self.margin_size * 2
sizes[-2:] = self.img_resolution
channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max))
channels[-1] = self.img_channels
# Construct layers.
self.input = SynthesisInput(
w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]),
sampling_rate=sampling_rates[0], bandwidth=cutoffs[0])
self.layer_names = []
for idx in range(self.num_layers + 1):
prev = max(idx - 1, 0)
is_torgb = (idx == self.num_layers)
is_critically_sampled = (idx >= self.num_layers - self.num_critical)
use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution)
layer = SynthesisLayer(
w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16,
in_channels=int(channels[prev]), out_channels= int(channels[idx]),
in_size=int(sizes[prev]), out_size=int(sizes[idx]),
in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]),
in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx],
in_half_width=half_widths[prev], out_half_width=half_widths[idx],
**layer_kwargs)
name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
setattr(self, name, layer)
self.layer_names.append(name)
def forward(self, ws, **layer_kwargs):
misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
ws = ws.to(torch.float32).unbind(dim=1)
# Execute layers.
x = self.input(ws[0])
for name, w in zip(self.layer_names, ws[1:]):
x = getattr(self, name)(x, w, **layer_kwargs)
if self.output_scale != 1:
x = x * self.output_scale
# Ensure correct shape and dtype.
misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution])
x = x.to(torch.float32)
return x
def extra_repr(self):
return '\n'.join([
f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},',
f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}'])
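# Minimal worked example of the frequency schedule computed in __init__ (hypothetical small
# config; never called): cutoffs grow geometrically from first_cutoff to img_resolution / 2
# and saturate over the last num_critical layers, which is what makes those layers
# critically sampled.
def _example_cutoff_schedule():
    num_layers, num_critical, first_cutoff, img_resolution = 6, 2, 2, 64
    last_cutoff = img_resolution / 2
    exponents = np.minimum(np.arange(num_layers + 1) / (num_layers - num_critical), 1)
    cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents
    assert cutoffs[0] == first_cutoff and cutoffs[-1] == last_cutoff
    assert np.all(cutoffs[-(num_critical + 1):] == last_cutoff)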
#----------------------------------------------------------------------------
@persistence.persistent_class
class Generator(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality.
c_dim, # Conditioning label (C) dimensionality.
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output resolution.
img_channels, # Number of output color channels.
mapping_kwargs = {}, # Arguments for MappingNetwork.
resize=None,
**synthesis_kwargs, # Arguments for SynthesisNetwork.
):
super().__init__()
self.z_dim = z_dim
self.c_dim = c_dim
self.w_dim = w_dim
self.img_resolution = img_resolution
self.img_channels = img_channels
self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
self.num_ws = self.synthesis.num_ws
self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
self.resize = resize
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, input_is_w=False, **synthesis_kwargs):
if input_is_w:
ws = z
if ws.dim() == 2:
ws = ws.unsqueeze(1).repeat([1, self.mapping.num_ws, 1])
else:
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
if self.resize is not None:
img = imresize(img, [self.resize, self.resize])
return img
#----------------------------------------------------------------------------
def imresize(image, size):
dim = image.dim()
if dim == 3:
image = image.unsqueeze(1)
b, _, h, w = image.shape
if size[0] > h:
image = F.interpolate(image, size, mode='bilinear')
elif size[0] < h:
image = F.interpolate(image, size, mode='area')
if dim == 3:
image = image.squeeze(1)
return image
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Main training loop."""
import os
import time
import copy
import json
import pickle
import psutil
import PIL.Image
import numpy as np
import torch
import dnnlib
from torch_utils import misc
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import grid_sample_gradfix
import legacy
from metrics import metric_main
#----------------------------------------------------------------------------
def setup_snapshot_image_grid(training_set, random_seed=0):
rnd = np.random.RandomState(random_seed)
gw = np.clip(7680 // training_set.image_shape[2], 7, 32)
gh = np.clip(4320 // training_set.image_shape[1], 4, 32)
# No labels => show random subset of training samples.
if not training_set.has_labels:
all_indices = list(range(len(training_set)))
rnd.shuffle(all_indices)
grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)]
else:
# Group training samples by label.
label_groups = dict() # label => [idx, ...]
for idx in range(len(training_set)):
label = tuple(training_set.get_details(idx).raw_label.flat[::-1])
if label not in label_groups:
label_groups[label] = []
label_groups[label].append(idx)
# Reorder.
label_order = sorted(label_groups.keys())
for label in label_order:
rnd.shuffle(label_groups[label])
# Organize into grid.
grid_indices = []
for y in range(gh):
label = label_order[y % len(label_order)]
indices = label_groups[label]
grid_indices += [indices[x % len(indices)] for x in range(gw)]
label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))]
# Load data.
images, labels = zip(*[training_set[i] for i in grid_indices])
return (gw, gh), np.stack(images), np.stack(labels)
#----------------------------------------------------------------------------
def save_image_grid(img, fname, drange, grid_size):
lo, hi = drange
img = np.asarray(img, dtype=np.float32)
img = (img - lo) * (255 / (hi - lo))
img = np.rint(img).clip(0, 255).astype(np.uint8)
gw, gh = grid_size
_N, C, H, W = img.shape
img = img.reshape([gh, gw, C, H, W])
img = img.transpose(0, 3, 1, 4, 2)
img = img.reshape([gh * H, gw * W, C])
assert C in [1, 3]
if C == 1:
PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
if C == 3:
PIL.Image.fromarray(img, 'RGB').save(fname)
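# Minimal usage sketch for save_image_grid (hypothetical values and output path; never
# called): the input is an [N, C, H, W] array in the given dynamic range, and it is tiled
# into a gw x gh mosaic before being written as a PNG.
def _example_save_image_grid():
    imgs = np.random.uniform(-1, 1, size=[6, 3, 32, 32]).astype(np.float32)
    save_image_grid(imgs, 'example_grid.png', drange=[-1, 1], grid_size=(3, 2))  # gw=3, gh=2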
#----------------------------------------------------------------------------
def training_loop(
run_dir = '.', # Output directory.
training_set_kwargs = {}, # Options for training set.
data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
G_kwargs = {}, # Options for generator network.
D_kwargs = {}, # Options for discriminator network.
G_opt_kwargs = {}, # Options for generator optimizer.
D_opt_kwargs = {}, # Options for discriminator optimizer.
augment_kwargs = None, # Options for augmentation pipeline. None = disable.
loss_kwargs = {}, # Options for loss function.
metrics = [], # Metrics to evaluate during training.
random_seed = 0, # Global random seed.
num_gpus = 1, # Number of GPUs participating in the training.
    rank                    = 0,        # Rank of the current process in [0, num_gpus).
batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
batch_gpu = 4, # Number of samples processed at a time by one GPU.
ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights.
ema_rampup = 0.05, # EMA ramp-up coefficient. None = no rampup.
G_reg_interval = None, # How often to perform regularization for G? None = disable lazy regularization.
D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization.
augment_p = 0, # Initial value of augmentation probability.
ada_target = None, # ADA target value. None = fixed p.
ada_interval = 4, # How often to perform ADA adjustment?
ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
total_kimg = 25000, # Total length of the training, measured in thousands of real images.
kimg_per_tick = 4, # Progress snapshot interval.
image_snapshot_ticks = 50, # How often to save image snapshots? None = disable.
network_snapshot_ticks = 50, # How often to save network snapshots? None = disable.
resume_pkl = None, # Network pickle to resume training from.
resume_kimg = 0, # First kimg to report when resuming training.
cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks.
progress_fn = None, # Callback function for updating training progress. Called for all ranks.
):
# Initialize.
start_time = time.time()
device = torch.device('cuda', rank)
np.random.seed(random_seed * num_gpus + rank)
torch.manual_seed(random_seed * num_gpus + rank)
torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed.
torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy.
torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy.
conv2d_gradfix.enabled = True # Improves training speed.
grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe.
# Load training set.
if rank == 0:
print('Loading training set...')
training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
if rank == 0:
print()
print('Num images: ', len(training_set))
print('Image shape:', training_set.image_shape)
print('Label shape:', training_set.label_shape)
print()
# Construct networks.
if rank == 0:
print('Constructing networks...')
common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
G_ema = copy.deepcopy(G).eval()
# Resume from existing pickle.
if (resume_pkl is not None) and (rank == 0):
print(f'Resuming from "{resume_pkl}"')
with dnnlib.util.open_url(resume_pkl) as f:
resume_data = legacy.load_network_pkl(f)
for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
# Print network summary tables.
if rank == 0:
z = torch.empty([batch_gpu, G.z_dim], device=device)
c = torch.empty([batch_gpu, G.c_dim], device=device)
img = misc.print_module_summary(G, [z, c])
misc.print_module_summary(D, [img, c])
# Setup augmentation.
if rank == 0:
print('Setting up augmentation...')
augment_pipe = None
ada_stats = None
if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
augment_pipe.p.copy_(torch.as_tensor(augment_p))
if ada_target is not None:
ada_stats = training_stats.Collector(regex='Loss/signs/real')
# Distribute across GPUs.
if rank == 0:
print(f'Distributing across {num_gpus} GPUs...')
for module in [G, D, G_ema, augment_pipe]:
if module is not None and num_gpus > 1:
for param in misc.params_and_buffers(module):
torch.distributed.broadcast(param, src=0)
# Setup training phases.
if rank == 0:
print('Setting up training phases...')
loss = dnnlib.util.construct_class_by_name(device=device, G=G, D=D, augment_pipe=augment_pipe, **loss_kwargs) # subclass of training.loss.Loss
phases = []
for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
if reg_interval is None:
opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
else: # Lazy regularization.
mb_ratio = reg_interval / (reg_interval + 1)
opt_kwargs = dnnlib.EasyDict(opt_kwargs)
opt_kwargs.lr = opt_kwargs.lr * mb_ratio
opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)]
phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)]
for phase in phases:
phase.start_event = None
phase.end_event = None
if rank == 0:
phase.start_event = torch.cuda.Event(enable_timing=True)
phase.end_event = torch.cuda.Event(enable_timing=True)
# Export sample images.
grid_size = None
grid_z = None
grid_c = None
if rank == 0:
print('Exporting sample images...')
grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
# Initialize logs.
if rank == 0:
print('Initializing logs...')
stats_collector = training_stats.Collector(regex='.*')
stats_metrics = dict()
stats_jsonl = None
stats_tfevents = None
if rank == 0:
stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
try:
import torch.utils.tensorboard as tensorboard
stats_tfevents = tensorboard.SummaryWriter(run_dir)
except ImportError as err:
print('Skipping tfevents export:', err)
# Train.
if rank == 0:
print(f'Training for {total_kimg} kimg...')
print()
cur_nimg = resume_kimg * 1000
cur_tick = 0
tick_start_nimg = cur_nimg
tick_start_time = time.time()
maintenance_time = tick_start_time - start_time
batch_idx = 0
if progress_fn is not None:
progress_fn(0, total_kimg)
while True:
# Fetch training data.
with torch.autograd.profiler.record_function('data_fetch'):
phase_real_img, phase_real_c = next(training_set_iterator)
phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
phase_real_c = phase_real_c.to(device).split(batch_gpu)
all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]
# Execute training phases.
for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
if batch_idx % phase.interval != 0:
continue
if phase.start_event is not None:
phase.start_event.record(torch.cuda.current_stream(device))
# Accumulate gradients.
phase.opt.zero_grad(set_to_none=True)
phase.module.requires_grad_(True)
for real_img, real_c, gen_z, gen_c in zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c):
loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
phase.module.requires_grad_(False)
# Update weights.
with torch.autograd.profiler.record_function(phase.name + '_opt'):
params = [param for param in phase.module.parameters() if param.grad is not None]
if len(params) > 0:
flat = torch.cat([param.grad.flatten() for param in params])
if num_gpus > 1:
torch.distributed.all_reduce(flat)
flat /= num_gpus
misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
grads = flat.split([param.numel() for param in params])
for param, grad in zip(params, grads):
param.grad = grad.reshape(param.shape)
phase.opt.step()
# Phase done.
if phase.end_event is not None:
phase.end_event.record(torch.cuda.current_stream(device))
# Update G_ema.
with torch.autograd.profiler.record_function('Gema'):
ema_nimg = ema_kimg * 1000
if ema_rampup is not None:
ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
for p_ema, p in zip(G_ema.parameters(), G.parameters()):
p_ema.copy_(p.lerp(p_ema, ema_beta))
for b_ema, b in zip(G_ema.buffers(), G.buffers()):
b_ema.copy_(b)
# Update state.
cur_nimg += batch_size
batch_idx += 1
# Execute ADA heuristic.
if (ada_stats is not None) and (batch_idx % ada_interval == 0):
ada_stats.update()
adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))
# Perform maintenance tasks once per tick.
done = (cur_nimg >= total_kimg * 1000)
if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
continue
# Print status line, accumulating the same information in training_stats.
tick_end_time = time.time()
fields = []
fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
torch.cuda.reset_peak_memory_stats()
fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
if rank == 0:
print(' '.join(fields))
# Check for abort.
if (not done) and (abort_fn is not None) and abort_fn():
done = True
if rank == 0:
print()
print('Aborting...')
# Save image snapshot.
if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
# Save network snapshot.
snapshot_pkl = None
snapshot_data = None
if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
snapshot_data = dict(G=G, D=D, G_ema=G_ema, augment_pipe=augment_pipe, training_set_kwargs=dict(training_set_kwargs))
for key, value in snapshot_data.items():
if isinstance(value, torch.nn.Module):
value = copy.deepcopy(value).eval().requires_grad_(False)
if num_gpus > 1:
misc.check_ddp_consistency(value, ignore_regex=r'.*\.[^.]+_(avg|ema)')
for param in misc.params_and_buffers(value):
torch.distributed.broadcast(param, src=0)
snapshot_data[key] = value.cpu()
del value # conserve memory
snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
if rank == 0:
with open(snapshot_pkl, 'wb') as f:
pickle.dump(snapshot_data, f)
# Evaluate metrics.
if (snapshot_data is not None) and (len(metrics) > 0):
if rank == 0:
print('Evaluating metrics...')
for metric in metrics:
result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
if rank == 0:
metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
stats_metrics.update(result_dict.results)
del snapshot_data # conserve memory
# Collect statistics.
for phase in phases:
value = []
if (phase.start_event is not None) and (phase.end_event is not None):
phase.end_event.synchronize()
value = phase.start_event.elapsed_time(phase.end_event)
training_stats.report0('Timing/' + phase.name, value)
stats_collector.update()
stats_dict = stats_collector.as_dict()
# Update logs.
timestamp = time.time()
if stats_jsonl is not None:
fields = dict(stats_dict, timestamp=timestamp)
stats_jsonl.write(json.dumps(fields) + '\n')
stats_jsonl.flush()
if stats_tfevents is not None:
global_step = int(cur_nimg / 1e3)
walltime = timestamp - start_time
for name, value in stats_dict.items():
stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
for name, value in stats_metrics.items():
stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
stats_tfevents.flush()
if progress_fn is not None:
progress_fn(cur_nimg // 1000, total_kimg)
# Update state.
cur_tick += 1
tick_start_nimg = cur_nimg
tick_start_time = time.time()
maintenance_time = tick_start_time - tick_end_time
if done:
break
# Done.
if rank == 0:
print()
print('Exiting...')
#----------------------------------------------------------------------------
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Train a GAN using the techniques described in the paper
"Training Generative Adversarial Networks with Limited Data"."""
import os
import click
import re
import json
import tempfile
import torch
import dnnlib
import ast
from training import training_loop
from metrics import metric_main
from torch_utils import training_stats
from torch_utils import custom_ops
#----------------------------------------------------------------------------
class UserError(Exception):
pass
#----------------------------------------------------------------------------
def setup_training_loop_kwargs(
# General options (not included in desc).
gpus = None, # Number of GPUs: <int>, default = 1 gpu
snap = None, # Snapshot interval: <int>, default = 50 ticks
metrics = None, # List of metric names: [], ['fid50k_full'] (default), ...
seed = None, # Random seed: <int>, default = 0
# Dataset.
data = None, # Training dataset (required): <path>
cond = None, # Train conditional model based on dataset labels: <bool>, default = False
subset = None, # Train with only N images: <int>, default = all
mirror = None, # Augment dataset with x-flips: <bool>, default = False
square = None, # Generate square images instead of 2:1 rectangles: <bool>, default = False
# Base config.
cfg = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'shhq'
gamma = None, # Override R1 gamma: <float>
kimg = None, # Override training duration: <int>
batch = None, # Override batch size: <int>
# Discriminator augmentation.
aug = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
p = None, # Specify p for 'fixed' (required): <float>
target = None, # Override ADA target for 'ada': <float>, default = depends on aug
augpipe = None, # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc', 'body'
# Transfer learning.
resume = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
freezed = None, # Freeze-D: <int>, default = 0 discriminator layers
# Performance options (not included in desc).
fp32 = None, # Disable mixed-precision training: <bool>, default = False
nhwc = None, # Use NHWC memory format with FP16: <bool>, default = False
allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
nobench = None, # Disable cuDNN benchmarking: <bool>, default = False
workers = None, # Override number of DataLoader workers: <int>, default = 3
):
args = dnnlib.EasyDict()
# ------------------------------------------
# General options: gpus, snap, metrics, seed
# ------------------------------------------
if gpus is None:
gpus = 1
assert isinstance(gpus, int)
if not (gpus >= 1 and gpus & (gpus - 1) == 0):
raise UserError('--gpus must be a power of two')
args.num_gpus = gpus
if snap is None:
snap = 50
assert isinstance(snap, int)
if snap < 1:
raise UserError('--snap must be at least 1')
args.image_snapshot_ticks = snap
args.network_snapshot_ticks = snap
if metrics is None:
metrics = ['fid50k_full']
assert isinstance(metrics, list)
if not all(metric_main.is_valid_metric(metric) for metric in metrics):
raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
args.metrics = metrics
if seed is None:
seed = 0
assert isinstance(seed, int)
args.random_seed = seed
# -------------------------------------------
# Dataset: data, cond, subset, mirror, square
# -------------------------------------------
print('square : ', square)
assert data is not None
assert isinstance(data, str)
args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False, square=square)
args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
try:
training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution
args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels
args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size
desc = training_set.name
print('desc: ', desc)
del training_set # conserve memory
except IOError as err:
raise UserError(f'--data: {err}')
if square: desc += '-square'
else: desc += '-rectangle'
if cond is None:
cond = False
assert isinstance(cond, bool)
if cond:
if not args.training_set_kwargs.use_labels:
raise UserError('--cond=True requires labels specified in dataset.json')
desc += '-cond'
else:
args.training_set_kwargs.use_labels = False
if subset is not None:
assert isinstance(subset, int)
if not 1 <= subset <= args.training_set_kwargs.max_size:
raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}')
desc += f'-subset{subset}'
if subset < args.training_set_kwargs.max_size:
args.training_set_kwargs.max_size = subset
args.training_set_kwargs.random_seed = args.random_seed
if mirror is None:
mirror = False
assert isinstance(mirror, bool)
if mirror:
desc += '-mirror'
args.training_set_kwargs.xflip = True
# ------------------------------------
# Base config: cfg, gamma, kimg, batch
# ------------------------------------
if cfg is None:
cfg = 'auto'
assert isinstance(cfg, str)
desc += f'-{cfg}'
cfg_specs = {
'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2),
'shhq': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=8), # Populated dynamically based on resolution and GPU count.
'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8),
'cifar': dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
}
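# The -1 entries in 'auto' and 'shhq' are placeholders; they are filled in below from the dataset resolution and the number of GPUs.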
assert cfg in cfg_specs
spec = dnnlib.EasyDict(cfg_specs[cfg])
if cfg == 'auto' or cfg == 'shhq':
desc += f'{gpus:d}'
spec.ref_gpus = gpus
res = args.training_set_kwargs.resolution
spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
spec.mbstd = min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
spec.fmaps = 1 if res >= 512 else 0.5
spec.lrate = 0.002 if res >= 1024 else 0.0025
spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula
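# e.g. res=1024 with 8 GPUs gives mb=32, so gamma = 0.0002 * 1024**2 / 32 ~= 6.55.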
spec.ema = spec.mb * 10 / 32
args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict(),square=square)
args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict(),square=square)
args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
args.G_kwargs.mapping_kwargs.num_layers = spec.map
args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training
args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd
args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)
args.total_kimg = spec.kimg
args.batch_size = spec.mb
args.batch_gpu = spec.mb // spec.ref_gpus
args.ema_kimg = spec.ema
args.ema_rampup = spec.ramp
if cfg == 'cifar':
args.loss_kwargs.pl_weight = 0 # disable path length regularization
args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
args.D_kwargs.architecture = 'orig' # disable residual skip connections
if gamma is not None:
assert isinstance(gamma, float)
if not gamma >= 0:
raise UserError('--gamma must be non-negative')
desc += f'-gamma{gamma:g}'
args.loss_kwargs.r1_gamma = gamma
if kimg is not None:
assert isinstance(kimg, int)
if not kimg >= 1:
raise UserError('--kimg must be at least 1')
desc += f'-kimg{kimg:d}'
args.total_kimg = kimg
if batch is not None:
assert isinstance(batch, int)
if not (batch >= 1 and batch % gpus == 0):
raise UserError('--batch must be at least 1 and divisible by --gpus')
desc += f'-batch{batch}'
args.batch_size = batch
args.batch_gpu = batch // gpus
# ---------------------------------------------------
# Discriminator augmentation: aug, p, target, augpipe
# ---------------------------------------------------
if aug is None:
aug = 'ada'
else:
assert isinstance(aug, str)
desc += f'-{aug}'
if aug == 'ada':
args.ada_target = 0.6
elif aug == 'noaug':
pass
elif aug == 'fixed':
if p is None:
raise UserError(f'--aug={aug} requires specifying --p')
else:
raise UserError(f'--aug={aug} not supported')
if p is not None:
assert isinstance(p, float)
if aug != 'fixed':
raise UserError('--p can only be specified with --aug=fixed')
if not 0 <= p <= 1:
raise UserError('--p must be between 0 and 1')
desc += f'-p{p:g}'
args.augment_p = p
if target is not None:
assert isinstance(target, float)
if aug != 'ada':
raise UserError('--target can only be specified with --aug=ada')
if not 0 <= target <= 1:
raise UserError('--target must be between 0 and 1')
desc += f'-target{target:g}'
args.ada_target = target
assert augpipe is None or isinstance(augpipe, str)
if augpipe is None:
augpipe = 'bgc'
else:
if aug == 'noaug':
raise UserError('--augpipe cannot be specified with --aug=noaug')
desc += f'-{augpipe}'
augpipe_specs = {
'blit': dict(xflip=1, rotate90=1, xint=1),
'geom': dict(scale=1, rotate=1, aniso=1, xfrac=1),
'color': dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
'filter': dict(imgfilter=1),
'noise': dict(noise=1),
'cutout': dict(cutout=1),
'bg': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
'bgc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
'bgcf': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
'bgcfn': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
'body': dict(xflip=1, rotate90=0, xint=1, scale=1, rotate=0, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1)
}
assert augpipe in augpipe_specs
if aug != 'noaug':
args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])
# ----------------------------------
# Transfer learning: resume, freezed
# ----------------------------------
resume_specs = {
'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
}
assert resume is None or isinstance(resume, str)
if resume is None:
resume = 'noresume'
elif resume == 'noresume':
desc += '-noresume'
elif resume in resume_specs:
desc += f'-resume{resume}'
args.resume_pkl = resume_specs[resume] # predefined url
else:
desc += '-resumecustom'
args.resume_pkl = resume # custom path or url
if resume != 'noresume':
args.ada_kimg = 100 # make ADA react faster at the beginning
args.ema_rampup = None # disable EMA rampup
if freezed is not None:
assert isinstance(freezed, int)
if not freezed >= 0:
raise UserError('--freezed must be non-negative')
desc += f'-freezed{freezed:d}'
args.D_kwargs.block_kwargs.freeze_layers = freezed
# -------------------------------------------------------------
# Performance options: fp32, nhwc, nobench, allow_tf32, workers
# -------------------------------------------------------------
if fp32 is None:
fp32 = False
assert isinstance(fp32, bool)
if fp32:
args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None
if nhwc is None:
nhwc = False
assert isinstance(nhwc, bool)
if nhwc:
args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True
if nobench is None:
nobench = False
assert isinstance(nobench, bool)
if nobench:
args.cudnn_benchmark = False
if allow_tf32 is None:
allow_tf32 = False
assert isinstance(allow_tf32, bool)
if allow_tf32:
args.allow_tf32 = True
if workers is not None:
assert isinstance(workers, int)
if not workers >= 1:
raise UserError('--workers must be at least 1')
args.data_loader_kwargs.num_workers = workers
return desc, args
#----------------------------------------------------------------------------
def subprocess_fn(rank, args, temp_dir):
dnnlib.util.Logger(file_name=os.path.join(args.run_dir, 'log.txt'), file_mode='a', should_flush=True)
# Init torch.distributed.
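# Rendezvous via a shared file in temp_dir; use the gloo backend on Windows and nccl elsewhere.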
if args.num_gpus > 1:
init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
if os.name == 'nt':
init_method = 'file:///' + init_file.replace('\\', '/')
torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
else:
init_method = f'file://{init_file}'
torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
# Init torch_utils.
sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
if rank != 0:
custom_ops.verbosity = 'none'
# Execute training loop.
training_loop.training_loop(rank=rank, **args)
#----------------------------------------------------------------------------
class CommaSeparatedList(click.ParamType):
name = 'list'
def convert(self, value, param, ctx):
_ = param, ctx
if value is None or value.lower() == 'none' or value == '':
return []
return value.split(',')
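# e.g. '--metrics=fid50k_full,kid50k_full' yields ['fid50k_full', 'kid50k_full']; '--metrics=none' yields [].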
#----------------------------------------------------------------------------
@click.command()
@click.pass_context
# General options.
@click.option('--outdir', help='Where to save the results', required=True, metavar='DIR')
@click.option('--gpus', help='Number of GPUs to use [default: 1]', type=int, metavar='INT')
@click.option('--snap', help='Snapshot interval [default: 50 ticks]', type=int, metavar='INT')
@click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList())
@click.option('--seed', help='Random seed [default: 0]', type=int, metavar='INT')
@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
# Dataset.
@click.option('--data', help='Training data (directory or zip)', metavar='PATH', required=True)
@click.option('--cond', help='Train conditional model based on dataset labels [default: false]', type=bool, metavar='BOOL')
@click.option('--subset', help='Train with only N images [default: all]', type=int, metavar='INT')
@click.option('--mirror', help='Enable dataset x-flips [default: false]', type=bool, metavar='BOOL')
@click.option('--square', help='True for square, False for 2:1 rectangle [default: false]', type=bool, metavar='BOOL', default=False)
# Base config.
@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar','shhq']))
@click.option('--gamma', help='Override R1 gamma', type=float)
@click.option('--kimg', help='Override training duration', type=int, metavar='INT')
@click.option('--batch', help='Override batch size', type=int, metavar='INT')
# Discriminator augmentation.
@click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed']))
@click.option('--p', help='Augmentation probability for --aug=fixed', type=float)
@click.option('--target', help='ADA target value for --aug=ada', type=float)
@click.option('--augpipe', help='Augmentation pipeline [default: bgc]', type=click.Choice(['blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc', 'bgcf', 'bgcfn', 'bgcfnc', 'body']))
# Transfer learning.
@click.option('--resume', help='Resume training [default: noresume]', metavar='PKL')
@click.option('--freezed', help='Freeze-D [default: 0 layers]', type=int, metavar='INT')
# Performance options.
@click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
@click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
@click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
@click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')
def main(ctx, outdir, dry_run, **config_kwargs):
"""Train a GAN using the techniques described in the paper
"Training Generative Adversarial Networks with Limited Data".
Examples:
\b
# Train with custom dataset using 1 GPU.
python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1
\b
# Train class-conditional CIFAR-10 using 2 GPUs.
python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
--gpus=2 --cfg=cifar --cond=1
\b
# Transfer learn MetFaces from FFHQ using 4 GPUs.
python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
--gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10
\b
# Reproduce original StyleGAN2 config F.
python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
--gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug
\b
Base configs (--cfg):
auto Automatically select reasonable defaults based on resolution
and GPU count. Good starting point for new datasets.
stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024.
paper256 Reproduce results for FFHQ and LSUN Cat at 256x256.
paper512 Reproduce results for BreCaHAD and AFHQ at 512x512.
paper1024 Reproduce results for MetFaces at 1024x1024.
cifar Reproduce results for CIFAR-10 at 32x32.
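shhq Same as 'auto', but with an 8-layer mapping network.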
\b
Transfer learning source networks (--resume):
ffhq256 FFHQ trained at 256x256 resolution.
ffhq512 FFHQ trained at 512x512 resolution.
ffhq1024 FFHQ trained at 1024x1024 resolution.
celebahq256 CelebA-HQ trained at 256x256 resolution.
lsundog256 LSUN Dog trained at 256x256 resolution.
<PATH or URL> Custom network pickle.
"""
dnnlib.util.Logger(should_flush=True)
# Setup training options.
try:
run_desc, args = setup_training_loop_kwargs(**config_kwargs)
except UserError as err:
ctx.fail(err)
# Pick output directory.
prev_run_dirs = []
if os.path.isdir(outdir):
prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
cur_run_id = max(prev_run_ids, default=-1) + 1
args.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
assert not os.path.exists(args.run_dir)
# Print options.
print()
print('Training options:')
print(json.dumps(args, indent=2))
print()
print(f'Output directory: {args.run_dir}')
print(f'Training data: {args.training_set_kwargs.path}')
print(f'Training duration: {args.total_kimg} kimg')
print(f'Number of GPUs: {args.num_gpus}')
print(f'Number of images: {args.training_set_kwargs.max_size}')
print(f'Image resolution: {args.training_set_kwargs.resolution}')
print(f'Conditional model: {args.training_set_kwargs.use_labels}')
print(f'Dataset x-flips: {args.training_set_kwargs.xflip}')
print()
# Dry run?
if dry_run:
print('Dry run; exiting.')
return
# Create output directory.
print('Creating output directory...')
os.makedirs(args.run_dir, exist_ok=True)
with open(os.path.join(args.run_dir, 'training_options.json'), 'wt') as f:
json.dump(args, f, indent=2)
# Launch processes.
print('Launching processes...')
torch.multiprocessing.set_start_method('spawn')
with tempfile.TemporaryDirectory() as temp_dir:
if args.num_gpus == 1:
subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
else:
torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
#----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import numpy as np
import zipfile
import PIL.Image
import json
import torch
import dnnlib
import cv2
from collections import Counter
try:
import pyspng
except ImportError:
pyspng = None
#----------------------------------------------------------------------------
class Dataset(torch.utils.data.Dataset):
def __init__(self,
name, # Name of the dataset.
raw_shape, # Shape of the raw image data (NCHW).
max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
use_labels = False, # Enable conditioning labels? False = label dimension is zero.
xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
random_seed = 0, # Random seed to use when applying max_size.
square = False, # Expect square images? False = 2:1 rectangle (height = 2 * width).
):
self._name = name
self._raw_shape = list(raw_shape)
self._use_labels = use_labels
self._raw_labels = None
self._label_shape = None
self._square = square
# Apply max_size.
self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
if (max_size is not None) and (self._raw_idx.size > max_size):
np.random.RandomState(random_seed).shuffle(self._raw_idx)
self._raw_idx = np.sort(self._raw_idx[:max_size])
# Apply xflip.
self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
if xflip:
self._raw_idx = np.tile(self._raw_idx, 2)
self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
def _get_raw_labels(self):
if self._raw_labels is None:
self._raw_labels = self._load_raw_labels() if self._use_labels else None
if self._raw_labels is None:
self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
assert isinstance(self._raw_labels, np.ndarray)
assert self._raw_labels.shape[0] == self._raw_shape[0]
assert self._raw_labels.dtype in [np.float32, np.int64]
if self._raw_labels.dtype == np.int64:
assert self._raw_labels.ndim == 1
assert np.all(self._raw_labels >= 0)
return self._raw_labels
def close(self): # to be overridden by subclass
pass
def _load_raw_image(self, raw_idx): # to be overridden by subclass
raise NotImplementedError
def _load_raw_labels(self): # to be overridden by subclass
raise NotImplementedError
def __getstate__(self):
return dict(self.__dict__, _raw_labels=None)
def __del__(self):
try:
self.close()
except:
pass
def __len__(self):
return self._raw_idx.size
def __getitem__(self, idx):
image = self._load_raw_image(self._raw_idx[idx])
assert isinstance(image, np.ndarray)
assert list(image.shape) == self.image_shape
assert image.dtype == np.uint8
if self._xflip[idx]:
assert image.ndim == 3 # CHW
image = image[:, :, ::-1]
return image.copy(), self.get_label(idx)
def get_label(self, idx):
label = self._get_raw_labels()[self._raw_idx[idx]]
if label.dtype == np.int64:
onehot = np.zeros(self.label_shape, dtype=np.float32)
onehot[label] = 1
label = onehot
return label.copy()
def get_details(self, idx):
d = dnnlib.EasyDict()
d.raw_idx = int(self._raw_idx[idx])
d.xflip = (int(self._xflip[idx]) != 0)
d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
return d
@property
def name(self):
return self._name
@property
def image_shape(self):
return list(self._raw_shape[1:])
@property
def num_channels(self):
assert len(self.image_shape) == 3 # CHW
return self.image_shape[0]
@property
def resolution(self):
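# For rectangular data, height = 2 * width and 'resolution' refers to the height.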
assert len(self.image_shape) == 3 # CHW
if self._square:
assert self.image_shape[1] == self.image_shape[2]
else:
assert self.image_shape[1] == self.image_shape[2] * 2
return self.image_shape[1]
@property
def label_shape(self):
if self._label_shape is None:
raw_labels = self._get_raw_labels()
if raw_labels.dtype == np.int64:
self._label_shape = [int(np.max(raw_labels)) + 1]
else:
self._label_shape = raw_labels.shape[1:]
return list(self._label_shape)
@property
def label_dim(self):
assert len(self.label_shape) == 1
return self.label_shape[0]
@property
def has_labels(self):
return any(x != 0 for x in self.label_shape)
@property
def has_onehot_labels(self):
return self._get_raw_labels().dtype == np.int64
#----------------------------------------------------------------------------
class ImageFolderDataset(Dataset):
def __init__(self,
path, # Path to directory or zip.
resolution = None, # Ensure specific resolution, None = highest available.
square = False, # Expect square images? False = 2:1 rectangle.
**super_kwargs, # Additional arguments for the Dataset base class.
):
self._path = path
self._zipfile = None
self._square = square
if os.path.isdir(self._path):
self._type = 'dir'
self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
elif self._file_ext(self._path) == '.zip':
self._type = 'zip'
self._all_fnames = set(self._get_zipfile().namelist())
else:
raise IOError('Path must point to a directory or zip')
PIL.Image.init()
self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
if len(self._image_fnames) == 0:
raise IOError('No image files found in the specified path')
name = os.path.splitext(os.path.basename(self._path))[0]
raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
# The original square-only resolution check is relaxed below to support 2:1 rectangular images.
if resolution is not None:
if self._square:
raw_shape[2] = raw_shape[3] = resolution
else:
raw_shape[2] = resolution
raw_shape[3] = resolution // 2
super().__init__(name=name, raw_shape=raw_shape,square=square, **super_kwargs)
@staticmethod
def _file_ext(fname):
return os.path.splitext(fname)[1].lower()
def _get_zipfile(self):
assert self._type == 'zip'
if self._zipfile is None:
self._zipfile = zipfile.ZipFile(self._path)
return self._zipfile
def _open_file(self, fname):
if self._type == 'dir':
return open(os.path.join(self._path, fname), 'rb')
if self._type == 'zip':
return self._get_zipfile().open(fname, 'r')
return None
def close(self):
try:
if self._zipfile is not None:
self._zipfile.close()
finally:
self._zipfile = None
def __getstate__(self):
return dict(super().__getstate__(), _zipfile=None)
def _load_raw_image(self, raw_idx): #load single image
fname = self._image_fnames[raw_idx]
with self._open_file(fname) as f:
if pyspng is not None and self._file_ext(fname) == '.png':
image = pyspng.load(f.read())
else:
image = np.array(PIL.Image.open(f))
if image.ndim == 2:
image = image[:, :, np.newaxis] # HW => HWC
image = image.transpose(2, 0, 1) # HWC => CHW
return image
def _load_raw_labels(self):
fname = 'dataset.json'
if fname not in self._all_fnames:
return None
with self._open_file(fname) as f:
labels = json.load(f)['labels']
if labels is None:
return None
labels = dict(labels)
labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
labels = np.array(labels)
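# 1-D arrays hold integer class labels; 2-D arrays hold float label vectors.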
labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
return labels
#----------------------------------------------------------------------------
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
from torch_utils import misc
from torch_utils import persistence
from torch_utils.ops import conv2d_resample
from torch_utils.ops import upfirdn2d
from torch_utils.ops import bias_act
from torch_utils.ops import fma
#----------------------------------------------------------------------------
@misc.profiled_function
def normalize_2nd_moment(x, dim=1, eps=1e-8):
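# Scale x so that its mean square along 'dim' is 1 (a.k.a. pixel normalization).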
return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
#----------------------------------------------------------------------------
@misc.profiled_function
def modulated_conv2d(
x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
styles, # Modulation coefficients of shape [batch_size, in_channels].
noise = None, # Optional noise tensor to add to the output activations.
up = 1, # Integer upsampling factor.
down = 1, # Integer downsampling factor.
padding = 0, # Padding with respect to the upsampled image.
resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
demodulate = True, # Apply weight demodulation?
flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
):
batch_size = x.shape[0]
out_channels, in_channels, kh, kw = weight.shape
misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
# Pre-normalize inputs to avoid FP16 overflow.
if x.dtype == torch.float16 and demodulate:
weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
# Calculate per-sample weights and demodulation coefficients.
w = None
dcoefs = None
if demodulate or fused_modconv:
w = weight.unsqueeze(0) # [NOIkk]
w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
if demodulate:
dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
if demodulate and fused_modconv:
w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
# Execute by scaling the activations before and after the convolution.
if not fused_modconv:
x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
if demodulate and noise is not None:
x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
elif demodulate:
x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
elif noise is not None:
x = x.add_(noise.to(x.dtype))
return x
# Execute as one fused op using grouped convolution.
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
batch_size = int(batch_size)
misc.assert_shape(x, [batch_size, in_channels, None, None])
x = x.reshape(1, -1, *x.shape[2:])
w = w.reshape(-1, in_channels, kh, kw)
x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
x = x.reshape(batch_size, -1, *x.shape[2:])
if noise is not None:
x = x.add_(noise)
return x
#----------------------------------------------------------------------------
@persistence.persistent_class
class FullyConnectedLayer(torch.nn.Module):
def __init__(self,
in_features, # Number of input features.
out_features, # Number of output features.
bias = True, # Apply additive bias before the activation function?
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
lr_multiplier = 1, # Learning rate multiplier.
bias_init = 0, # Initial value for the additive bias.
):
super().__init__()
self.activation = activation
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
self.weight_gain = lr_multiplier / np.sqrt(in_features)
self.bias_gain = lr_multiplier
def forward(self, x):
w = self.weight.to(x.dtype) * self.weight_gain
b = self.bias
if b is not None:
b = b.to(x.dtype)
if self.bias_gain != 1:
b = b * self.bias_gain
if self.activation == 'linear' and b is not None:
x = torch.addmm(b.unsqueeze(0), x, w.t())
else:
x = x.matmul(w.t())
x = bias_act.bias_act(x, b, act=self.activation)
return x
#----------------------------------------------------------------------------
@persistence.persistent_class
class Conv2dLayer(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels.
out_channels, # Number of output channels.
kernel_size, # Width and height of the convolution kernel.
bias = True, # Apply additive bias before the activation function?
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
up = 1, # Integer upsampling factor.
down = 1, # Integer downsampling factor.
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
channels_last = False, # Expect the input to have memory_format=channels_last?
trainable = True, # Update the weights of this layer during training?
):
super().__init__()
self.activation = activation
self.up = up
self.down = down
self.conv_clamp = conv_clamp
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
self.padding = kernel_size // 2
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
self.act_gain = bias_act.activation_funcs[activation].def_gain
memory_format = torch.channels_last if channels_last else torch.contiguous_format
weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
bias = torch.zeros([out_channels]) if bias else None
if trainable:
self.weight = torch.nn.Parameter(weight)
self.bias = torch.nn.Parameter(bias) if bias is not None else None
else:
self.register_buffer('weight', weight)
if bias is not None:
self.register_buffer('bias', bias)
else:
self.bias = None
def forward(self, x, gain=1):
w = self.weight * self.weight_gain
b = self.bias.to(x.dtype) if self.bias is not None else None
flip_weight = (self.up == 1) # slightly faster
x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
act_gain = self.act_gain * gain
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
return x
#----------------------------------------------------------------------------
@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
w_dim, # Intermediate latent (W) dimensionality.
num_ws, # Number of intermediate latents to output, None = do not broadcast.
num_layers = 8, # Number of mapping layers.
embed_features = None, # Label embedding dimensionality, None = same as w_dim.
layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.
):
super().__init__()
self.z_dim = z_dim
self.c_dim = c_dim
self.w_dim = w_dim
self.num_ws = num_ws
self.num_layers = num_layers
self.w_avg_beta = w_avg_beta
if embed_features is None:
embed_features = w_dim
if c_dim == 0:
embed_features = 0
if layer_features is None:
layer_features = w_dim
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
if c_dim > 0:
self.embed = FullyConnectedLayer(c_dim, embed_features)
for idx in range(num_layers):
in_features = features_list[idx]
out_features = features_list[idx + 1]
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
setattr(self, f'fc{idx}', layer)
if num_ws is not None and w_avg_beta is not None:
self.register_buffer('w_avg', torch.zeros([w_dim]))
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
# Embed, normalize, and concat inputs.
x = None
with torch.autograd.profiler.record_function('input'):
if self.z_dim > 0:
misc.assert_shape(z, [None, self.z_dim])
x = normalize_2nd_moment(z.to(torch.float32))
if self.c_dim > 0:
misc.assert_shape(c, [None, self.c_dim])
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
x = torch.cat([x, y], dim=1) if x is not None else y
# Main layers.
for idx in range(self.num_layers):
layer = getattr(self, f'fc{idx}')
x = layer(x)
# Update moving average of W.
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
with torch.autograd.profiler.record_function('update_w_avg'):
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
# Broadcast.
if self.num_ws is not None:
with torch.autograd.profiler.record_function('broadcast'):
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
# Apply truncation.
if truncation_psi != 1:
with torch.autograd.profiler.record_function('truncate'):
assert self.w_avg_beta is not None
if self.num_ws is None or truncation_cutoff is None:
x = self.w_avg.lerp(x, truncation_psi)
else:
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
return x
#----------------------------------------------------------------------------
@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels.
out_channels, # Number of output channels.
w_dim, # Intermediate latent (W) dimensionality.
resolution, # Resolution of this layer.
kernel_size = 3, # Convolution kernel size.
up = 1, # Integer upsampling factor.
use_noise = True, # Enable noise input?
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
channels_last = False, # Use channels_last format for the weights?
square = False, # default is for rectangle images
):
super().__init__()
self.resolution = resolution
self.up = up
self.use_noise = use_noise
self.activation = activation
self.conv_clamp = conv_clamp
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
self.padding = kernel_size // 2
self.act_gain = bias_act.activation_funcs[activation].def_gain
self.square=square
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
memory_format = torch.channels_last if channels_last else torch.contiguous_format
self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
if use_noise:
if self.square:
self.register_buffer('noise_const', torch.randn([resolution, resolution]))
else:
self.register_buffer('noise_const', torch.randn([resolution, resolution // 2]))
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
assert noise_mode in ['random', 'const', 'none']
in_resolution = self.resolution // self.up
if self.square:
misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
else:
misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution // 2])
styles = self.affine(w)
noise = None
if self.use_noise and noise_mode == 'random':
if self.square:
noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
else:
noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution // 2], device=x.device) * self.noise_strength
if self.use_noise and noise_mode == 'const':
noise = self.noise_const * self.noise_strength
flip_weight = (self.up == 1) # slightly faster
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)
act_gain = self.act_gain * gain
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
return x
#----------------------------------------------------------------------------
@persistence.persistent_class
class ToRGBLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
super().__init__()
self.conv_clamp = conv_clamp
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
memory_format = torch.channels_last if channels_last else torch.contiguous_format
self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
def forward(self, x, w, fused_modconv=True):
styles = self.affine(w) * self.weight_gain
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
return x
#----------------------------------------------------------------------------
@persistence.persistent_class
class SynthesisBlock(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels, 0 = first block.
out_channels, # Number of output channels.
w_dim, # Intermediate latent (W) dimensionality.
resolution, # Resolution of this block.
img_channels, # Number of output color channels.
is_last, # Is this the last block?
architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
use_fp16 = False, # Use FP16 for this block?
fp16_channels_last = False, # Use channels-last memory format with FP16?
square = False, # default is for rectangle images
**layer_kwargs, # Arguments for SynthesisLayer.
):
assert architecture in ['orig', 'skip', 'resnet']
super().__init__()
self.in_channels = in_channels
self.w_dim = w_dim
self.resolution = resolution
self.img_channels = img_channels
self.is_last = is_last
self.architecture = architecture
self.use_fp16 = use_fp16
self.channels_last = (use_fp16 and fp16_channels_last)
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
self.num_conv = 0
self.num_torgb = 0
self.square = square
if in_channels == 0:
if self.square:
self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
else: # rectangle
self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution // 2]))
if in_channels != 0:
self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, square=square,**layer_kwargs)
self.num_conv += 1
self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
conv_clamp=conv_clamp, channels_last=self.channels_last, square=square, **layer_kwargs)
self.num_conv += 1
if is_last or architecture == 'skip':
self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
conv_clamp=conv_clamp, channels_last=self.channels_last)
self.num_torgb += 1
if in_channels != 0 and architecture == 'resnet':
self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
resample_filter=resample_filter, channels_last=self.channels_last)
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
w_iter = iter(ws.unbind(dim=1))
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
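# By default, use the fused grouped-convolution path only at inference time, and only when the data type is fp32 or the batch size is 1.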
if fused_modconv is None:
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
# Input.
if self.in_channels == 0:
x = self.const.to(dtype=dtype, memory_format=memory_format)
x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
else:
if self.square:
misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
else: # rectangle
misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 4])
x = x.to(dtype=dtype, memory_format=memory_format)
# Main layers.
if self.in_channels == 0:
x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
elif self.architecture == 'resnet':
y = self.skip(x, gain=np.sqrt(0.5))
x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
x = y.add_(x)
else:
x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
# ToRGB.
if img is not None:
if self.square:
misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
else:
misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 4])
img = upfirdn2d.upsample2d(img, self.resample_filter)
if self.is_last or self.architecture == 'skip':
y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
img = img.add_(y) if img is not None else y
assert x.dtype == dtype
assert img is None or img.dtype == torch.float32
return x, img
#----------------------------------------------------------------------------
@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
def __init__(self,
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output image resolution.
img_channels, # Number of color channels.
square, # Generate square images? False = 2:1 rectangle.
channel_base = 32768, # Overall multiplier for the number of channels.
channel_max = 512, # Maximum number of channels in any layer.
num_fp16_res = 0, # Use FP16 for the N highest resolutions.
**block_kwargs, # Arguments for SynthesisBlock.
):
assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
super().__init__()
self.w_dim = w_dim
self.img_resolution = img_resolution
self.img_resolution_log2 = int(np.log2(img_resolution))
self.img_channels = img_channels
self.square=square
self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
self.num_ws = 0
for res in self.block_resolutions:
in_channels = channels_dict[res // 2] if res > 4 else 0
out_channels = channels_dict[res]
use_fp16 = (res >= fp16_resolution)
is_last = (res == self.img_resolution)
block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
img_channels=img_channels, is_last=is_last, use_fp16=use_fp16,square=square, **block_kwargs)
self.num_ws += block.num_conv
if is_last:
self.num_ws += block.num_torgb
setattr(self, f'b{res}', block)
def forward(self, ws, return_feature=False, **block_kwargs):
block_ws = []
features = []
with torch.autograd.profiler.record_function('split_ws'):
misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
ws = ws.to(torch.float32)
w_idx = 0
for res in self.block_resolutions:
block = getattr(self, f'b{res}')
block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
w_idx += block.num_conv
x = img = None
for res, cur_ws in zip(self.block_resolutions, block_ws):
block = getattr(self, f'b{res}')
x, img = block(x, img, cur_ws, **block_kwargs)
features.append(x)
if return_feature:
return img, features
else:
return img
#----------------------------------------------------------------------------
@persistence.persistent_class
class Generator(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality.
c_dim, # Conditioning label (C) dimensionality.
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output resolution.
square, # Generate square images? False = 2:1 rectangle.
img_channels, # Number of output color channels.
mapping_kwargs = {}, # Arguments for MappingNetwork.
synthesis_kwargs = {}, # Arguments for SynthesisNetwork.
padding = False, # Pad the rectangular output to a square along the width?
):
super().__init__()
self.z_dim = z_dim
self.c_dim = c_dim
self.w_dim = w_dim
self.square = square
self.img_resolution = img_resolution
self.img_channels = img_channels
self.padding = padding
self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels,square=square,**synthesis_kwargs)
self.num_ws = self.synthesis.num_ws
self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, input_is_w=False, return_feature=False, **synthesis_kwargs):
if input_is_w:
ws = z
if ws.dim() == 2:
ws = ws.unsqueeze(1).repeat([1, self.mapping.num_ws, 1])
else:
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
img = self.synthesis(ws, return_feature=return_feature, **synthesis_kwargs)
if return_feature:
img, feature = img
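# Optionally pad the rectangular output symmetrically along the width to make it square (image padded with 1, features with 0).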
if self.padding:
pad = (img.size(2) - img.size(3)) // 2
img = torch.nn.functional.pad(img, (pad, pad), "constant", 1)
if return_feature:
for i, feat in enumerate(feature):
pad = (feat.size(2) - feat.size(3)) // 2
feature[i] = torch.nn.functional.pad(feat, (pad, pad), "constant", 0)
if return_feature:
return img, feature
else:
return img
#----------------------------------------------------------------------------
@persistence.persistent_class
class DiscriminatorBlock(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels, 0 = first block.
tmp_channels, # Number of intermediate channels.
out_channels, # Number of output channels.
resolution, # Resolution of this block.
img_channels, # Number of input color channels.
first_layer_idx, # Index of the first layer.
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
use_fp16 = False, # Use FP16 for this block?
fp16_channels_last = False, # Use channels-last memory format with FP16?
freeze_layers = 0, # Freeze-D: Number of layers to freeze.
square = False, # Expect square input? False = 2:1 rectangle.
):
assert in_channels in [0, tmp_channels]
assert architecture in ['orig', 'skip', 'resnet']
super().__init__()
self.in_channels = in_channels
self.resolution = resolution
self.img_channels = img_channels
self.first_layer_idx = first_layer_idx
self.architecture = architecture
self.use_fp16 = use_fp16
self.channels_last = (use_fp16 and fp16_channels_last)
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
self.square = square
self.num_layers = 0
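# Freeze-D: layers whose global index is below freeze_layers are created with trainable=False, so their weights become fixed buffers.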
def trainable_gen():
while True:
layer_idx = self.first_layer_idx + self.num_layers
trainable = (layer_idx >= freeze_layers)
self.num_layers += 1
yield trainable
trainable_iter = trainable_gen()
if in_channels == 0 or architecture == 'skip':
self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)
if architecture == 'resnet':
self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
def forward(self, x, img, force_fp32=False):
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
# Input.
if x is not None:
if self.square:
misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
else:
misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution // 2])
x = x.to(dtype=dtype, memory_format=memory_format)
# FromRGB.
if self.in_channels == 0 or self.architecture == 'skip':
if self.square:
misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
else:
misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution // 2])
img = img.to(dtype=dtype, memory_format=memory_format)
y = self.fromrgb(img)
x = x + y if x is not None else y
img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
# Main layers.
if self.architecture == 'resnet':
y = self.skip(x, gain=np.sqrt(0.5))
x = self.conv0(x)
x = self.conv1(x, gain=np.sqrt(0.5))
x = y.add_(x)
else:
x = self.conv0(x)
x = self.conv1(x)
assert x.dtype == dtype
return x, img
#----------------------------------------------------------------------------
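# Why the 'resnet' branch above applies gain=sqrt(0.5) to both paths: summing two roughly
# unit-variance signals doubles the variance, and the gain restores it. A minimal standalone
# check of that arithmetic on toy tensors (assumed shapes, not the actual convolution outputs):
def _example_residual_gain():
    a = torch.randn([8, 512, 16, 16])
    b = torch.randn([8, 512, 16, 16])
    plain_sum = a + b                                       # variance ~2
    scaled_sum = a * np.sqrt(0.5) + b * np.sqrt(0.5)        # variance ~1
    return plain_sum.var().item(), scaled_sum.var().item()  # approx. 2.0 vs approx. 1.0
#----------------------------------------------------------------------------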
@persistence.persistent_class
class MinibatchStdLayer(torch.nn.Module):
def __init__(self, group_size, num_channels=1):
super().__init__()
self.group_size = group_size
self.num_channels = num_channels
def forward(self, x):
N, C, H, W = x.shape
with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
F = self.num_channels
c = C // F
y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
return x
#----------------------------------------------------------------------------
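# Illustrative sketch of what MinibatchStdLayer appends: one extra feature map per sample holding
# the average cross-sample standard deviation within its group. The batch and group sizes below
# are assumptions for demonstration only.
def _example_minibatch_std():
    layer = MinibatchStdLayer(group_size=4, num_channels=1)
    x = torch.randn([8, 64, 16, 16])        # N=8 splits into two groups of G=4
    y = layer(x)
    assert y.shape == (8, 65, 16, 16)       # one stddev channel appended
    return y
#----------------------------------------------------------------------------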
@persistence.persistent_class
class DiscriminatorEpilogue(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels.
cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
resolution, # Resolution of this block.
img_channels, # Number of input color channels.
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
square = False, # False = rectangular (height x height/2) input, True = square input.
):
assert architecture in ['orig', 'skip', 'resnet']
super().__init__()
self.in_channels = in_channels
self.cmap_dim = cmap_dim
self.resolution = resolution
self.img_channels = img_channels
self.architecture = architecture
self.square = square
if architecture == 'skip':
self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
if self.square:
self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
else:
self.fc = FullyConnectedLayer(in_channels * (resolution ** 2 // 2), in_channels, activation=activation)
self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)
def forward(self, x, img, cmap, force_fp32=False):
if self.square:
misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
else:
misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution // 2]) # [NCHW]
_ = force_fp32 # unused
dtype = torch.float32
memory_format = torch.contiguous_format
# FromRGB.
x = x.to(dtype=dtype, memory_format=memory_format)
if self.architecture == 'skip':
if self.square:
misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
else:
misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution // 2])
img = img.to(dtype=dtype, memory_format=memory_format)
x = x + self.fromrgb(img)
# Main layers.
if self.mbstd is not None:
x = self.mbstd(x)
x = self.conv(x)
x = self.fc(x.flatten(1))
x = self.out(x)
# Conditioning.
if self.cmap_dim > 0:
misc.assert_shape(cmap, [None, self.cmap_dim])
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
assert x.dtype == dtype
return x
#----------------------------------------------------------------------------
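# Illustrative sketch of the conditioning step above: the final score is the inner product
# between the epilogue output and the mapped label embedding, scaled by 1/sqrt(cmap_dim) so its
# magnitude stays independent of the embedding width. The dimensions below are assumptions for
# demonstration only.
def _example_projection_conditioning():
    cmap_dim = 512
    x = torch.randn([8, cmap_dim])          # per-sample features from self.out
    cmap = torch.randn([8, cmap_dim])       # per-sample mapped labels
    score = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(cmap_dim))
    assert score.shape == (8, 1)
    return score
#----------------------------------------------------------------------------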
@persistence.persistent_class
class Discriminator(torch.nn.Module):
def __init__(self,
c_dim, # Conditioning label (C) dimensionality.
img_resolution, # Input resolution.
img_channels, # Number of input color channels.
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
channel_base = 32768, # Overall multiplier for the number of channels.
channel_max = 512, # Maximum number of channels in any layer.
num_fp16_res = 0, # Use FP16 for the N highest resolutions.
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
square = False, # False = rectangular (height x height/2) images, True = square images.
block_kwargs = {}, # Arguments for DiscriminatorBlock.
mapping_kwargs = {}, # Arguments for MappingNetwork.
epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
):
super().__init__()
self.c_dim = c_dim
self.img_resolution = img_resolution
self.img_resolution_log2 = int(np.log2(img_resolution))
self.img_channels = img_channels
self.square = square
self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
if cmap_dim is None:
cmap_dim = channels_dict[4]
if c_dim == 0:
cmap_dim = 0
common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
cur_layer_idx = 0
for res in self.block_resolutions:
in_channels = channels_dict[res] if res < img_resolution else 0
tmp_channels = channels_dict[res]
out_channels = channels_dict[res // 2]
use_fp16 = (res >= fp16_resolution)
block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
first_layer_idx=cur_layer_idx, use_fp16=use_fp16, square=square, **block_kwargs, **common_kwargs)
setattr(self, f'b{res}', block)
cur_layer_idx += block.num_layers
if c_dim > 0:
self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, square=square, **epilogue_kwargs, **common_kwargs)
def forward(self, img, c, **block_kwargs):
x = None
for res in self.block_resolutions:
block = getattr(self, f'b{res}')
x, img = block(x, img, **block_kwargs)
cmap = None
if self.c_dim > 0:
cmap = self.mapping(None, c)
x = self.b4(x, img, cmap)
return x
#----------------------------------------------------------------------------
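# Illustrative sketch of the capacity schedule computed in Discriminator.__init__: block
# resolutions run from img_resolution down to 8, and each block's width is channel_base // res
# capped at channel_max. The numbers below assume img_resolution=256, channel_base=32768,
# channel_max=512 for demonstration.
def _example_discriminator_capacity():
    img_resolution, channel_base, channel_max = 256, 32768, 512
    log2_res = int(np.log2(img_resolution))
    block_resolutions = [2 ** i for i in range(log2_res, 2, -1)]    # [256, 128, 64, 32, 16, 8]
    channels = {res: min(channel_base // res, channel_max) for res in block_resolutions + [4]}
    # -> {256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
    return block_resolutions, channels
#----------------------------------------------------------------------------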
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Train a GAN using the techniques described in the paper
"Alias-Free Generative Adversarial Networks"."""
import os
import click
import re
import json
import tempfile
import torch
import dnnlib
from training import training_loop
from metrics import metric_main
from torch_utils import training_stats
from torch_utils import custom_ops
import ast
#----------------------------------------------------------------------------
def subprocess_fn(rank, c, temp_dir):
dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)
# Init torch.distributed.
if c.num_gpus > 1:
init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
if os.name == 'nt':
init_method = 'file:///' + init_file.replace('\\', '/')
torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=c.num_gpus)
else:
init_method = f'file://{init_file}'
torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=c.num_gpus)
# Init torch_utils.
sync_device = torch.device('cuda', rank) if c.num_gpus > 1 else None
training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
if rank != 0:
custom_ops.verbosity = 'none'
# Execute training loop.
training_loop.training_loop(rank=rank, **c)
#----------------------------------------------------------------------------
def launch_training(c, desc, outdir, dry_run):
dnnlib.util.Logger(should_flush=True)
# Pick output directory.
prev_run_dirs = []
if os.path.isdir(outdir):
prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
cur_run_id = max(prev_run_ids, default=-1) + 1
c.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{desc}')
assert not os.path.exists(c.run_dir)
# Print options.
print()
print('Training options:')
print(json.dumps(c, indent=2))
print()
print(f'Output directory: {c.run_dir}')
print(f'Number of GPUs: {c.num_gpus}')
print(f'Batch size: {c.batch_size} images')
print(f'Training duration: {c.total_kimg} kimg')
print(f'Dataset path: {c.training_set_kwargs.path}')
print(f'Dataset size: {c.training_set_kwargs.max_size} images')
print(f'Dataset resolution: {c.training_set_kwargs.resolution}')
print(f'Dataset labels: {c.training_set_kwargs.use_labels}')
print(f'Dataset x-flips: {c.training_set_kwargs.xflip}')
print()
# Dry run?
if dry_run:
print('Dry run; exiting.')
return
# Create output directory.
print('Creating output directory...')
os.makedirs(c.run_dir)
with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
json.dump(c, f, indent=2)
# Launch processes.
print('Launching processes...')
torch.multiprocessing.set_start_method('spawn')
with tempfile.TemporaryDirectory() as temp_dir:
if c.num_gpus == 1:
subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
else:
torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=c.num_gpus)
#----------------------------------------------------------------------------
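# Illustrative sketch of how launch_training() numbers result directories: it scans outdir for
# subdirectories whose names start with digits and continues from the highest id, so runs land in
# '00000-<desc>', '00001-<desc>', and so on. The directory names below are assumptions for
# demonstration only.
def _example_run_dir_numbering():
    prev_run_dirs = ['00000-stylegan2-ffhq', '00003-stylegan2-ffhq', 'notes']
    prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
    prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
    cur_run_id = max(prev_run_ids, default=-1) + 1
    return f'{cur_run_id:05d}-mydesc'       # -> '00004-mydesc'
#----------------------------------------------------------------------------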
def init_dataset_kwargs(data, square=False):
# dataset
try:
dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False, square=square)
dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size.
return dataset_kwargs, dataset_obj.name
except IOError as err:
raise click.ClickException(f'--data: {err}')
#----------------------------------------------------------------------------
def parse_comma_separated_list(s):
if isinstance(s, list):
return s
if s is None or s.lower() == 'none' or s == '':
return []
return s.split(',')
#----------------------------------------------------------------------------
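# Illustrative examples of the option parser above (assumed inputs, shown for clarity):
#   parse_comma_separated_list('fid50k_full,kid50k_full') -> ['fid50k_full', 'kid50k_full']
#   parse_comma_separated_list('none')                    -> []
#   parse_comma_separated_list(None)                      -> []
#----------------------------------------------------------------------------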
@click.command()
# Required.
@click.option('--outdir', help='Where to save the results', metavar='DIR', required=True)
@click.option('--cfg', help='Base configuration', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), required=True)
@click.option('--data', help='Training data', metavar='PATH', required=True)
@click.option('--gpus', help='Number of GPUs to use', metavar='INT', type=click.IntRange(min=1), required=True)
@click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), required=True)
@click.option('--gamma', help='R1 regularization weight', metavar='FLOAT', type=click.FloatRange(min=0), required=True)
@click.option('--square', help='Train on square images (True) or rectangular height x height/2 images (False)', metavar='BOOL', type=bool, default=False, show_default=True)
# Optional features.
@click.option('--cond', help='Train conditional model', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--mirror', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--aug', help='Augmentation mode', type=click.Choice(['noaug', 'ada', 'fixed']), default='ada', show_default=True)
@click.option('--resume', help='Resume from given network pickle', metavar='[PATH|URL]', type=str)
@click.option('--freezed', help='Freeze first layers of D', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
# Misc hyperparameters.
@click.option('--p', help='Probability for --aug=fixed', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.2, show_default=True)
@click.option('--target', help='Target value for --aug=ada', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.6, show_default=True)
@click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
@click.option('--cbase', help='Capacity multiplier', metavar='INT', type=click.IntRange(min=1), default=32768, show_default=True)
@click.option('--cmax', help='Max. feature maps', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
@click.option('--glr', help='G learning rate [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0))
@click.option('--dlr', help='D learning rate', metavar='FLOAT', type=click.FloatRange(min=0), default=0.002, show_default=True)
@click.option('--map-depth', help='Mapping network depth [default: varies]', metavar='INT', type=click.IntRange(min=1))
@click.option('--mbstd-group', help='Minibatch std group size', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True)
# Misc settings.
@click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
@click.option('--kimg', help='Total training duration', metavar='KIMG', type=click.IntRange(min=1), default=25000, show_default=True)
@click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=4, show_default=True)
@click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)
@click.option('--seed', help='Random seed', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
@click.option('--fp32', help='Disable mixed-precision', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--nobench', help='Disable cuDNN benchmarking', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=3, show_default=True)
@click.option('-n','--dry-run', help='Print training options and exit', is_flag=True)
def main(**kwargs):
"""Train a GAN using the techniques described in the paper
"Alias-Free Generative Adversarial Networks".
Examples:
\b
# Train StyleGAN3-T for AFHQv2 using 8 GPUs.
python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \\
--gpus=8 --batch=32 --gamma=8.2 --mirror=1
\b
# Fine-tune StyleGAN3-R for MetFaces-U using 8 GPUs, starting from the pre-trained FFHQ-U pickle.
python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \\
--gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \\
--resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl
\b
# Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs.
python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \\
--gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug
"""
# Initialize config.
opts = dnnlib.EasyDict(kwargs) # Command line arguments.
c = dnnlib.EasyDict() # Main config dict.
print('---- square: ', opts.square)
c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), square=opts.square)
c.D_kwargs = dnnlib.EasyDict(class_name='training.networks_stylegan2.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict(), square=opts.square)
c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss')
c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)
# Training set.
c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data, square=opts.square)
if opts.cond and not c.training_set_kwargs.use_labels:
raise click.ClickException('--cond=True requires labels specified in dataset.json')
c.training_set_kwargs.use_labels = opts.cond
c.training_set_kwargs.xflip = opts.mirror
# Hyperparameters & settings.
c.num_gpus = opts.gpus
c.batch_size = opts.batch
c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus
c.G_kwargs.channel_base = c.D_kwargs.channel_base = opts.cbase
c.G_kwargs.channel_max = c.D_kwargs.channel_max = opts.cmax
c.G_kwargs.mapping_kwargs.num_layers = (8 if opts.cfg == 'stylegan2' else 2) if opts.map_depth is None else opts.map_depth
c.D_kwargs.block_kwargs.freeze_layers = opts.freezed
c.D_kwargs.epilogue_kwargs.mbstd_group_size = opts.mbstd_group
c.loss_kwargs.r1_gamma = opts.gamma
c.G_opt_kwargs.lr = (0.002 if opts.cfg == 'stylegan2' else 0.0025) if opts.glr is None else opts.glr
c.D_opt_kwargs.lr = opts.dlr
c.metrics = opts.metrics
c.total_kimg = opts.kimg
c.kimg_per_tick = opts.tick
c.image_snapshot_ticks = c.network_snapshot_ticks = opts.snap
c.random_seed = c.training_set_kwargs.random_seed = opts.seed
c.data_loader_kwargs.num_workers = opts.workers
# Sanity checks.
if c.batch_size % c.num_gpus != 0:
raise click.ClickException('--batch must be a multiple of --gpus')
if c.batch_size % (c.num_gpus * c.batch_gpu) != 0:
raise click.ClickException('--batch must be a multiple of --gpus times --batch-gpu')
if c.batch_gpu < c.D_kwargs.epilogue_kwargs.mbstd_group_size:
raise click.ClickException('--batch-gpu cannot be smaller than --mbstd')
if any(not metric_main.is_valid_metric(metric) for metric in c.metrics):
raise click.ClickException('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
# Base configuration.
c.ema_kimg = c.batch_size * 10 / 32
if opts.cfg == 'stylegan2':
c.G_kwargs.class_name = 'training.networks_stylegan2.Generator'
c.loss_kwargs.style_mixing_prob = 0.9 # Enable style mixing regularization.
c.loss_kwargs.pl_weight = 2 # Enable path length regularization.
c.G_reg_interval = 4 # Enable lazy regularization for G.
c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions.
c.loss_kwargs.pl_no_weight_grad = True # Speed up path length regularization by skipping gradient computation wrt. conv2d weights.
else:
c.G_kwargs.class_name = 'training.networks_stylegan3.Generator'
c.G_kwargs.magnitude_ema_beta = 0.5 ** (c.batch_size / (20 * 1e3))
if opts.cfg == 'stylegan3-r':
c.G_kwargs.conv_kernel = 1 # Use 1x1 convolutions.
c.G_kwargs.channel_base *= 2 # Double the number of feature maps.
c.G_kwargs.channel_max *= 2
c.G_kwargs.use_radial_filters = True # Use radially symmetric downsampling filters.
c.loss_kwargs.blur_init_sigma = 10 # Blur the images seen by the discriminator.
c.loss_kwargs.blur_fade_kimg = c.batch_size * 200 / 32 # Fade out the blur during the first N kimg.
# Augmentation.
if opts.aug != 'noaug':
c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1)
if opts.aug == 'ada':
c.ada_target = opts.target
if opts.aug == 'fixed':
c.augment_p = opts.p
# Resume.
if opts.resume is not None:
c.resume_pkl = opts.resume
c.ada_kimg = 100 # Make ADA react faster at the beginning.
c.ema_rampup = None # Disable EMA rampup.
c.loss_kwargs.blur_init_sigma = 0 # Disable blur rampup.
# Performance-related toggles.
if opts.fp32:
c.G_kwargs.num_fp16_res = c.D_kwargs.num_fp16_res = 0
c.G_kwargs.conv_clamp = c.D_kwargs.conv_clamp = None
if opts.nobench:
c.cudnn_benchmark = False
# Description string.
desc = f'{opts.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'
if opts.desc is not None:
desc += f'-{opts.desc}'
# Launch.
launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)
#----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------
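# Example invocation sketch for this fork's rectangle support (paths, dataset name, and
# hyperparameters below are assumptions for demonstration only): --square=False keeps the default
# rectangular height x height/2 layout, --square=True trains on square images.
#
#   python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/body-512x256.zip \
#       --gpus=1 --batch=8 --gamma=10 --square=False
#
#----------------------------------------------------------------------------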