Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmgeneration
Commits
b7536f78
Commit
b7536f78
authored
Jun 16, 2025
by
limm
Browse files
add a to another part of mmgeneration code
parent
57e0e891
Pipeline
#2777
canceled with stages
Changes
185
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
6398 additions
and
0 deletions
+6398
-0
mmgen/models/architectures/pix2pix/modules.py
mmgen/models/architectures/pix2pix/modules.py
+172
-0
mmgen/models/architectures/positional_encoding.py
mmgen/models/architectures/positional_encoding.py
+211
-0
mmgen/models/architectures/singan/__init__.py
mmgen/models/architectures/singan/__init__.py
+9
-0
mmgen/models/architectures/singan/generator_discriminator.py
mmgen/models/architectures/singan/generator_discriminator.py
+262
-0
mmgen/models/architectures/singan/modules.py
mmgen/models/architectures/singan/modules.py
+230
-0
mmgen/models/architectures/singan/positional_encoding.py
mmgen/models/architectures/singan/positional_encoding.py
+237
-0
mmgen/models/architectures/sngan_proj/__init__.py
mmgen/models/architectures/sngan_proj/__init__.py
+8
-0
mmgen/models/architectures/sngan_proj/generator_discriminator.py
...odels/architectures/sngan_proj/generator_discriminator.py
+756
-0
mmgen/models/architectures/sngan_proj/modules.py
mmgen/models/architectures/sngan_proj/modules.py
+610
-0
mmgen/models/architectures/stylegan/__init__.py
mmgen/models/architectures/stylegan/__init__.py
+13
-0
mmgen/models/architectures/stylegan/ada/augment.py
mmgen/models/architectures/stylegan/ada/augment.py
+784
-0
mmgen/models/architectures/stylegan/ada/grid_sample_gradfix.py
.../models/architectures/stylegan/ada/grid_sample_gradfix.py
+108
-0
mmgen/models/architectures/stylegan/ada/misc.py
mmgen/models/architectures/stylegan/ada/misc.py
+31
-0
mmgen/models/architectures/stylegan/ada/upfirdn2d.py
mmgen/models/architectures/stylegan/ada/upfirdn2d.py
+189
-0
mmgen/models/architectures/stylegan/generator_discriminator_v1.py
...dels/architectures/stylegan/generator_discriminator_v1.py
+523
-0
mmgen/models/architectures/stylegan/generator_discriminator_v2.py
...dels/architectures/stylegan/generator_discriminator_v2.py
+704
-0
mmgen/models/architectures/stylegan/generator_discriminator_v3.py
...dels/architectures/stylegan/generator_discriminator_v3.py
+197
-0
mmgen/models/architectures/stylegan/modules/__init__.py
mmgen/models/architectures/stylegan/modules/__init__.py
+12
-0
mmgen/models/architectures/stylegan/modules/styleganv1_modules.py
...dels/architectures/stylegan/modules/styleganv1_modules.py
+174
-0
mmgen/models/architectures/stylegan/modules/styleganv2_modules.py
...dels/architectures/stylegan/modules/styleganv2_modules.py
+1168
-0
No files found.
mmgen/models/architectures/pix2pix/modules.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
,
kaiming_init
,
normal_init
,
xavier_init
from
torch.nn
import
init
def generation_init_weights(module, init_type='normal', init_gain=0.02):
    """Default initialization of network weights for image generation.

    By default, we use normal init, but xavier and kaiming might work
    better for some applications.

    Args:
        module (nn.Module): Module to be initialized.
        init_type (str): The name of an initialization method:
            normal | xavier | kaiming | orthogonal.
        init_gain (float): Scaling factor for normal, xavier and
            orthogonal.
    """

    def _init_single_module(m):
        """Initialize one submodule according to ``init_type``.

        Args:
            m (nn.Module): Module to be initialized.
        """
        cls_name = m.__class__.__name__
        has_weight = hasattr(m, 'weight')
        if has_weight and ('Conv' in cls_name or 'Linear' in cls_name):
            if init_type == 'normal':
                normal_init(m, 0.0, init_gain)
            elif init_type == 'xavier':
                xavier_init(m, gain=init_gain, distribution='normal')
            elif init_type == 'kaiming':
                kaiming_init(
                    m,
                    a=0,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    distribution='normal')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight, gain=init_gain)
                init.constant_(m.bias.data, 0.0)
            else:
                raise NotImplementedError(
                    f"Initialization method '{init_type}' is not implemented")
        elif 'BatchNorm2d' in cls_name:
            # BatchNorm Layer's weight is not a matrix;
            # only normal distribution applies.
            normal_init(m, 1.0, init_gain)

    module.apply(_init_single_module)
class UnetSkipConnectionBlock(nn.Module):
    """Construct a Unet submodule with skip connections, with the following.

    structure: downsampling - `submodule` - upsampling.

    Args:
        outer_channels (int): Number of channels at the outer conv layer.
        inner_channels (int): Number of channels at the inner conv layer.
        in_channels (int): Number of channels in input images/features. If is
            None, equals to `outer_channels`. Default: None.
        submodule (UnetSkipConnectionBlock): Previously constructed submodule.
            Default: None.
        is_outermost (bool): Whether this module is the outermost module.
            Default: False.
        is_innermost (bool): Whether this module is the innermost module.
            Default: False.
        norm_cfg (dict): Config dict to build norm layer. Default:
            `dict(type='BN')`.
        use_dropout (bool): Whether to use dropout layers. Default: False.
    """

    def __init__(self,
                 outer_channels,
                 inner_channels,
                 in_channels=None,
                 submodule=None,
                 is_outermost=False,
                 is_innermost=False,
                 norm_cfg=dict(type='BN'),
                 use_dropout=False):
        super().__init__()
        # cannot be both outermost and innermost
        # (fixed: message strings previously concatenated without a space)
        assert not (is_outermost and is_innermost), (
            "'is_outermost' and 'is_innermost' cannot be True "
            'at the same time.')
        self.is_outermost = is_outermost
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"

        # We use norm layers in the unet skip connection block.
        # Only for IN, use bias since it does not have affine parameters.
        use_bias = norm_cfg['type'] == 'IN'

        kernel_size = 4
        stride = 2
        padding = 1
        if in_channels is None:
            in_channels = outer_channels
        down_conv_cfg = dict(type='Conv2d')
        down_norm_cfg = norm_cfg
        down_act_cfg = dict(type='LeakyReLU', negative_slope=0.2)
        up_conv_cfg = dict(type='deconv')
        up_norm_cfg = norm_cfg
        up_act_cfg = dict(type='ReLU')
        # by default the up conv consumes the concatenation of the
        # submodule output and the skip features, hence 2x channels
        up_in_channels = inner_channels * 2
        up_bias = use_bias
        middle = [submodule]
        upper = []

        if is_outermost:
            # outermost block: raw input goes down (no act/norm), and the
            # final output is squashed by Tanh
            down_act_cfg = None
            down_norm_cfg = None
            up_bias = True
            up_norm_cfg = None
            upper = [nn.Tanh()]
        elif is_innermost:
            # innermost block: no submodule, so the up conv only sees
            # `inner_channels` features (no concatenated skip)
            down_norm_cfg = None
            up_in_channels = inner_channels
            middle = []
        else:
            # intermediate blocks optionally append dropout on top
            upper = [nn.Dropout(0.5)] if use_dropout else []

        down = [
            ConvModule(
                in_channels=in_channels,
                out_channels=inner_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=use_bias,
                conv_cfg=down_conv_cfg,
                norm_cfg=down_norm_cfg,
                act_cfg=down_act_cfg,
                order=('act', 'conv', 'norm'))
        ]
        up = [
            ConvModule(
                in_channels=up_in_channels,
                out_channels=outer_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=up_bias,
                conv_cfg=up_conv_cfg,
                norm_cfg=up_norm_cfg,
                act_cfg=up_act_cfg,
                order=('act', 'conv', 'norm'))
        ]

        model = down + middle + up + upper

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.is_outermost:
            return self.model(x)

        # add skip connections: concatenate input with the submodule output
        # along the channel dimension
        return torch.cat([x, self.model(x)], 1)
mmgen/models/architectures/positional_encoding.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmgen.models.builder
import
MODULES
@MODULES.register_module('SPE')
@MODULES.register_module('SPE2d')
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal Positional Embedding 1D or 2D (SPE/SPE2d).

    This module is a modified from:
    https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py # noqa

    Based on the original SPE in single dimension, we implement a 2D sinusoidal
    positional encoding (SPE2d), as introduced in Positional Encoding as
    Spatial Inductive Bias in GANs, CVPR'2021.

    Args:
        embedding_dim (int): The number of dimensions for the positional
            encoding.
        padding_idx (int | list[int]): The index for the padding contents. The
            padding positions will obtain an encoding vector filling in zeros.
        init_size (int, optional): The initial size of the positional buffer.
            Defaults to 1024.
        div_half_dim (bool, optional): If true, the embedding will be divided
            by :math:`d/2`. Otherwise, it will be divided by
            :math:`(d/2 -1)`. Defaults to False.
        center_shift (int | None, optional): Shift the center point to some
            index. Defaults to None.
    """

    def __init__(self,
                 embedding_dim,
                 padding_idx,
                 init_size=1024,
                 div_half_dim=False,
                 center_shift=None):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.div_half_dim = div_half_dim
        self.center_shift = center_shift

        # Pre-compute an embedding table; `forward` re-generates it lazily
        # whenever a longer sequence is requested.
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx, self.div_half_dim)

        # Dummy buffer used only to track the module's device/dtype so the
        # (non-buffer) `weights` table can be moved to match it in `forward`.
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

        self.max_positions = int(1e5)

    @staticmethod
    def get_embedding(num_embeddings,
                      embedding_dim,
                      padding_idx=None,
                      div_half_dim=False):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Returns a ``(num_embeddings, embedding_dim)`` float tensor where each
        row is ``[sin(pos * f_0..f_{d/2-1}), cos(pos * f_0..f_{d/2-1})]`` and
        the ``padding_idx`` row (if given) is zeroed out.
        """
        assert embedding_dim % 2 == 0, (
            'In this version, we request '
            f'embedding_dim divisible by 2 but got {embedding_dim}')

        # there is a little difference from the original paper.
        half_dim = embedding_dim // 2
        if not div_half_dim:
            emb = np.log(10000) / (half_dim - 1)
        else:
            emb = np.log(1e4) / half_dim
        # compute exp(-log10000 / d * i)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        # outer product: positions (num_embeddings, 1) x freqs (1, half_dim)
        emb = torch.arange(
            num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)],
                        dim=1).view(num_embeddings, -1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, **kwargs):
        """Input is expected to be of size [bsz x seqlen].

        Returned tensor is expected to be of size [bsz x seq_len x emb_dim]

        A 4D input is interpreted as an image-like tensor and dispatched to
        :meth:`make_grid2d_like` instead.
        """
        assert input.dim() == 2 or input.dim(
        ) == 4, 'Input dimension should be 2 (1D) or 4(2D)'

        if input.dim() == 4:
            return self.make_grid2d_like(input, **kwargs)

        b, seq_len = input.shape
        max_pos = self.padding_idx + 1 + seq_len

        # Grow the cached table when the requested sequence exceeds it.
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx)
        # Align the cached table with the module's current device/dtype.
        self.weights = self.weights.to(self._float_tensor)

        positions = self.make_positions(input, self.padding_idx).to(
            self._float_tensor.device)

        return self.weights.index_select(0, positions.view(-1)).view(
            b, seq_len, self.embedding_dim).detach()

    def make_positions(self, input, padding_idx):
        # Positions count non-padding entries cumulatively (starting at
        # padding_idx + 1); padding entries map back to padding_idx.
        mask = input.ne(padding_idx).int()
        return (torch.cumsum(mask, dim=1).type_as(mask) *
                mask).long() + padding_idx

    def make_grid2d(self, height, width, num_batches=1, center_shift=None):
        # Build a 2D positional grid of shape (b, 2 x emb_dim, h, w) by
        # embedding the x- and y-axes separately and broadcasting them.
        h, w = height, width
        # if `center_shift` is not given from the outside, use
        # `self.center_shift`
        if center_shift is None:
            center_shift = self.center_shift

        h_shift = 0
        w_shift = 0
        # center shift to the input grid
        if center_shift is not None:
            # if h/w is even, the left center should be aligned with
            # center shift
            if h % 2 == 0:
                h_left_center = h // 2
                h_shift = center_shift - h_left_center
            else:
                h_center = h // 2 + 1
                h_shift = center_shift - h_center

            if w % 2 == 0:
                w_left_center = w // 2
                w_shift = center_shift - w_left_center
            else:
                w_center = w // 2 + 1
                w_shift = center_shift - w_center

        # Note that the index is started from 1 since zero will be padding idx.
        # axis -- (b, h or w)
        x_axis = torch.arange(1, w + 1).unsqueeze(0).repeat(num_batches,
                                                            1) + w_shift
        y_axis = torch.arange(1, h + 1).unsqueeze(0).repeat(num_batches,
                                                            1) + h_shift

        # emb -- (b, emb_dim, h or w); self(...) is the 1D forward above
        x_emb = self(x_axis).transpose(1, 2)
        y_emb = self(y_axis).transpose(1, 2)

        # make grid for x/y axis
        # Note that repeat will copy data. If use learned emb, expand may be
        # better.
        x_grid = x_emb.unsqueeze(2).repeat(1, 1, h, 1)
        y_grid = y_emb.unsqueeze(3).repeat(1, 1, 1, w)

        # cat grid -- (b, 2 x emb_dim, h, w)
        grid = torch.cat([x_grid, y_grid], dim=1)
        return grid.detach()

    def make_grid2d_like(self, x, center_shift=None):
        """Input tensor with shape of (b, ..., h, w) Return tensor with shape
        of (b, 2 x emb_dim, h, w)

        Note that the positional embedding highly depends on the the function,
        ``make_positions``.
        """
        h, w = x.shape[-2:]

        grid = self.make_grid2d(h, w, x.size(0), center_shift)

        return grid.to(x)
@MODULES.register_module('CSG2d')
@MODULES.register_module('CSG')
@MODULES.register_module()
class CatersianGrid(nn.Module):
    """Catersian Grid for 2d tensor.

    The Catersian Grid is a common-used positional encoding in deep learning.
    In this implementation, we follow the convention of ``grid_sample`` in
    PyTorch. In other words, ``[-1, -1]`` denotes the left-top corner while
    ``[1, 1]`` denotes the right-botton corner.
    """

    def forward(self, x, **kwargs):
        assert x.dim() == 4
        return self.make_grid2d_like(x, **kwargs)

    def make_grid2d(self, height, width, num_batches=1, requires_grad=False):
        """Build a ``(num_batches, 2, height, width)`` coordinate grid whose
        two channels hold x- and y-coordinates normalized to ``[-1, 1]``."""
        rows, cols = torch.meshgrid(
            torch.arange(0, height), torch.arange(0, width))
        # Map integer indices into [-1, 1]; the max(..., 1.) guards against
        # division by zero for single-pixel axes.
        norm_x = 2 * cols / max(float(width) - 1., 1.) - 1.
        norm_y = 2 * rows / max(float(height) - 1., 1.) - 1.
        grid = torch.stack((norm_x, norm_y), 0)
        grid.requires_grad = requires_grad
        return torch.unsqueeze(grid, 0).repeat(num_batches, 1, 1, 1)

    def make_grid2d_like(self, x, requires_grad=False):
        """Return a grid matching the batch size, spatial size, device and
        dtype of ``x`` (shape ``(b, ..., h, w)``)."""
        h, w = x.shape[-2:]
        return self.make_grid2d(
            h, w, x.size(0), requires_grad=requires_grad).to(x)
mmgen/models/architectures/singan/__init__.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
from
.generator_discriminator
import
(
SinGANMultiScaleDiscriminator
,
SinGANMultiScaleGenerator
)
from
.positional_encoding
import
SinGANMSGeneratorPE
__all__
=
[
'SinGANMultiScaleDiscriminator'
,
'SinGANMultiScaleGenerator'
,
'SinGANMSGeneratorPE'
]
mmgen/models/architectures/singan/generator_discriminator.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
from
functools
import
partial
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.runner
import
load_state_dict
from
mmcv.utils
import
print_log
from
mmgen.models.builder
import
MODULES
from
mmgen.utils
import
get_root_logger
from
.modules
import
DiscriminatorBlock
,
GeneratorBlock
@MODULES.register_module()
class SinGANMultiScaleGenerator(nn.Module):
    """Multi-Scale Generator used in SinGAN.

    More details can be found in: Singan: Learning a Generative Model from a
    Single Natural Image, ICCV'19.

    Notes:

    - In this version, we adopt the interpolation function from the official
      PyTorch APIs, which is different from the original implementation by the
      authors. However, in our experiments, this influence can be ignored.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        num_scales (int): The number of scales/stages in generator. Note
            that this number is counted from zero, which is the same as the
            original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            generator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the generator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the generator block. Defaults to 32.
        out_act_cfg (dict | None, optional): Configs for output activation
            layer. Defaults to dict(type='Tanh').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_scales,
                 kernel_size=3,
                 padding=0,
                 num_layers=5,
                 base_channels=32,
                 min_feat_channels=32,
                 out_act_cfg=dict(type='Tanh'),
                 **kwargs):
        super().__init__()
        # Total padding lost by `num_layers` unpadded convs; re-added at the
        # head of every stage so output size matches the input noise size.
        self.pad_head = int((kernel_size - 1) / 2 * num_layers)
        self.blocks = nn.ModuleList()

        self.upsample = partial(
            F.interpolate, mode='bicubic', align_corners=True)

        for scale in range(num_scales + 1):
            # Double the channel width every 4 scales, capped at 128.
            base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
                          128)
            min_feat_ch = min(
                min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)

            self.blocks.append(
                GeneratorBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    num_layers=num_layers,
                    base_channels=base_ch,
                    min_feat_channels=min_feat_ch,
                    out_act_cfg=out_act_cfg,
                    **kwargs))

        self.noise_padding_layer = nn.ZeroPad2d(self.pad_head)
        self.img_padding_layer = nn.ZeroPad2d(self.pad_head)

    def forward(self,
                input_sample,
                fixed_noises,
                noise_weights,
                rand_mode,
                curr_scale,
                num_batches=1,
                get_prev_res=False,
                return_noise=False):
        """Forward function.

        Args:
            input_sample (Tensor | None): The input for generator. In the
                original implementation, a tensor filled with zeros is adopted.
                If None is given, we will construct it from the first fixed
                noises.
            fixed_noises (list[Tensor]): List of the fixed noises in SinGAN.
            noise_weights (list[float]): List of the weights for random noises.
            rand_mode (str): Choices from ['rand', 'recon']. In ``rand`` mode,
                it will sample from random noises. Otherwise, the
                reconstruction for the single image will be returned.
            curr_scale (int): The scale for the current inference or training.
            num_batches (int, optional): The number of batches. Defaults to 1.
            get_prev_res (bool, optional): Whether to return results from
                previous stages. Defaults to False.
            return_noise (bool, optional): Whether to return noises tensor.
                Defaults to False.

        Returns:
            Tensor | dict: Generated image tensor or dictionary containing \
                more data.
        """
        if get_prev_res or return_noise:
            prev_res_list = []
            noise_list = []

        if input_sample is None:
            # Zero image sized like the coarsest fixed noise, on its device.
            input_sample = torch.zeros(
                (num_batches, 3, fixed_noises[0].shape[-2],
                 fixed_noises[0].shape[-1])).to(fixed_noises[0])

        g_res = input_sample

        for stage in range(curr_scale + 1):
            if rand_mode == 'recon':
                # reconstruction mode uses the stored fixed noise
                noise_ = fixed_noises[stage]
            else:
                noise_ = torch.randn(num_batches,
                                     *fixed_noises[stage].shape[1:]).to(g_res)
            if return_noise:
                noise_list.append(noise_)

            # add padding at head
            pad_ = (self.pad_head, ) * 4
            noise_ = F.pad(noise_, pad_)
            g_res_pad = F.pad(g_res, pad_)
            noise = noise_ * noise_weights[stage] + g_res_pad

            # detach the noisy input so gradients stay within this stage
            g_res = self.blocks[stage](noise.detach(), g_res)

            if get_prev_res and stage != curr_scale:
                prev_res_list.append(g_res)

            # upsample, here we use interpolation from PyTorch
            if stage != curr_scale:
                h_next, w_next = fixed_noises[stage + 1].shape[-2:]
                g_res = self.upsample(g_res, (h_next, w_next))

        if get_prev_res or return_noise:
            output_dict = dict(
                fake_img=g_res,
                prev_res_list=prev_res_list,
                noise_batch=noise_list)
            return output_dict

        return g_res

    def check_and_load_prev_weight(self, curr_scale):
        """Initialize the block at ``curr_scale`` from the previous scale's
        weights when both blocks have matching channel configurations."""
        if curr_scale == 0:
            return
        prev_ch = self.blocks[curr_scale - 1].base_channels
        curr_ch = self.blocks[curr_scale].base_channels

        prev_in_ch = self.blocks[curr_scale - 1].in_channels
        curr_in_ch = self.blocks[curr_scale].in_channels
        if prev_ch == curr_ch and prev_in_ch == curr_in_ch:
            load_state_dict(
                self.blocks[curr_scale],
                self.blocks[curr_scale - 1].state_dict(),
                logger=get_root_logger())
            # fixed typo: "pretrianed" -> "pretrained"
            print_log('Successfully load pretrained model from last scale.')
        else:
            print_log(
                'Cannot load pretrained model from last scale since'
                f' prev_ch({prev_ch}) != curr_ch({curr_ch})'
                f' or prev_in_ch({prev_in_ch}) != curr_in_ch({curr_in_ch})')
@MODULES.register_module()
class SinGANMultiScaleDiscriminator(nn.Module):
    """Multi-Scale Discriminator used in SinGAN.

    More details can be found in: Singan: Learning a Generative Model from a
    Single Natural Image, ICCV'19.

    Args:
        in_channels (int): Input channels.
        num_scales (int): The number of scales/stages in generator. Note
            that this number is counted from zero, which is the same as the
            original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            generator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the generator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the generator block. Defaults to 32.
    """

    def __init__(self,
                 in_channels,
                 num_scales,
                 kernel_size=3,
                 padding=0,
                 num_layers=5,
                 base_channels=32,
                 min_feat_channels=32,
                 **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList()
        for scale in range(num_scales + 1):
            # Double the channel width every 4 scales, capped at 128
            # (mirrors SinGANMultiScaleGenerator).
            base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
                          128)
            min_feat_ch = min(
                min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)

            self.blocks.append(
                DiscriminatorBlock(
                    in_channels=in_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    num_layers=num_layers,
                    base_channels=base_ch,
                    min_feat_channels=min_feat_ch,
                    **kwargs))

    def forward(self, x, curr_scale):
        """Forward function.

        Args:
            x (Tensor): Input feature map.
            curr_scale (int): Current scale for discriminator. If in testing,
                you need to set it to the last scale.

        Returns:
            Tensor: Discriminative results.
        """
        out = self.blocks[curr_scale](x)
        return out

    def check_and_load_prev_weight(self, curr_scale):
        """Initialize the block at ``curr_scale`` from the previous scale's
        weights when both blocks have the same base channels."""
        if curr_scale == 0:
            return
        prev_ch = self.blocks[curr_scale - 1].base_channels
        curr_ch = self.blocks[curr_scale].base_channels
        if prev_ch == curr_ch:
            self.blocks[curr_scale].load_state_dict(
                self.blocks[curr_scale - 1].state_dict())
            # fixed typo: "pretrianed" -> "pretrained"
            print_log('Successfully load pretrained model from last scale.')
        else:
            print_log(
                'Cannot load pretrained model from last scale since'
                f' prev_ch({prev_ch}) != curr_ch({curr_ch})')
mmgen/models/architectures/singan/modules.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
,
constant_init
,
normal_init
from
mmcv.runner
import
load_checkpoint
from
mmcv.utils.parrots_wrapper
import
_BatchNorm
from
mmgen.utils
import
get_root_logger
class GeneratorBlock(nn.Module):
    """Generator block used in SinGAN.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        num_scales (int): The number of scales/stages in generator. Note
            that this number is counted from zero, which is the same as the
            original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            generator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the generator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the generator block. Defaults to 32.
        out_act_cfg (dict | None, optional): Configs for output activation
            layer. Defaults to dict(type='Tanh').
        stride (int, optional): Same as :obj:`nn.Conv2d`. Defaults to 1.
        allow_no_residual (bool, optional): Whether to allow no residual link
            in this block. Defaults to False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 padding,
                 num_layers,
                 base_channels,
                 min_feat_channels,
                 out_act_cfg=dict(type='Tanh'),
                 stride=1,
                 allow_no_residual=False,
                 **kwargs):
        super().__init__()
        # Exposed so the multi-scale wrappers can compare configurations
        # when copying weights between adjacent scales.
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.allow_no_residual = allow_no_residual

        # head: in_channels -> base_channels, BN + LeakyReLU(0.2)
        self.head = ConvModule(
            in_channels=in_channels,
            out_channels=base_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=1,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
            **kwargs)

        self.body = nn.Sequential()

        # body: num_layers - 2 convs; channels halve each layer but never
        # drop below min_feat_channels.
        for i in range(num_layers - 2):
            feat_channels_ = int(base_channels / pow(2, (i + 1)))
            block = ConvModule(
                max(2 * feat_channels_, min_feat_channels),
                max(feat_channels_, min_feat_channels),
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                norm_cfg=dict(type='BN'),
                act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
                **kwargs)
            self.body.add_module(f'block{i + 1}', block)

        # tail: project to out_channels, no norm, optional output activation.
        # NOTE(review): reuses `feat_channels_` from the loop above, so this
        # implicitly assumes num_layers >= 3 — confirm against callers.
        self.tail = ConvModule(
            max(feat_channels_, min_feat_channels),
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=1,
            norm_cfg=None,
            act_cfg=out_act_cfg,
            **kwargs)

        self.init_weights()

    def forward(self, x, prev):
        """Forward function.

        Args:
            x (Tensor): Input feature map.
            prev (Tensor): Previous feature map.

        Returns:
            Tensor: Output feature map with the shape of (N, C, H, W).
        """
        x = self.head(x)
        x = self.body(x)
        x = self.tail(x)

        # if prev and x are not in the same shape at the channel dimension
        if self.allow_no_residual and x.shape[1] != prev.shape[1]:
            return x

        return x + prev

    def init_weights(self, pretrained=None):
        """Initialize weights.

        Args:
            pretrained (str | None, optional): Path to a checkpoint to load
                from; if None, use normal init for convs and constant init
                for norm layers. Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, 0, 0.02)
                elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None but'
                            f' got {type(pretrained)} instead.')
class DiscriminatorBlock(nn.Module):
    """Discriminator Block used in SinGAN.

    Args:
        in_channels (int): Input channels.
        base_channels (int): Base channels for this block.
        min_feat_channels (int): The minimum channels for feature map.
        kernel_size (int): Size of convolutional kernel, same as
            :obj:`nn.Conv2d`.
        padding (int): Padding for convolutional layer, same as
            :obj:`nn.Conv2d`.
        num_layers (int): The number of convolutional layers in this block.
        norm_cfg (dict | None, optional): Config for the normalization layer.
            Defaults to dict(type='BN').
        act_cfg (dict | None, optional): Config for the activation layer.
            Defaults to dict(type='LeakyReLU', negative_slope=0.2).
        stride (int, optional): The stride for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 1.
    """

    def __init__(self,
                 in_channels,
                 base_channels,
                 min_feat_channels,
                 kernel_size,
                 padding,
                 num_layers,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
                 stride=1,
                 **kwargs):
        super().__init__()
        # Exposed so SinGANMultiScaleDiscriminator can compare channels
        # when copying weights between adjacent scales.
        self.base_channels = base_channels
        self.stride = stride

        # head: in_channels -> base_channels
        self.head = ConvModule(
            in_channels,
            base_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.body = nn.Sequential()

        # body: num_layers - 2 convs; channels halve each layer but never
        # drop below min_feat_channels (same scheme as GeneratorBlock).
        for i in range(num_layers - 2):
            feat_channels_ = int(base_channels / pow(2, (i + 1)))
            block = ConvModule(
                max(2 * feat_channels_, min_feat_channels),
                max(feat_channels_, min_feat_channels),
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                conv_cfg=None,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                **kwargs)
            self.body.add_module(f'block{i + 1}', block)

        # tail: single-channel patch output, no norm or activation.
        # NOTE(review): reuses `feat_channels_` from the loop above, so this
        # implicitly assumes num_layers >= 3 — confirm against callers.
        self.tail = ConvModule(
            max(feat_channels_, min_feat_channels),
            1,
            kernel_size=kernel_size,
            padding=padding,
            stride=1,
            norm_cfg=None,
            act_cfg=None,
            **kwargs)

        self.init_weights()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        x = self.head(x)
        x = self.body(x)
        x = self.tail(x)
        return x

    # TODO: study the effects of init functions
    def init_weights(self, pretrained=None):
        """Initialize weights.

        Args:
            pretrained (str | None, optional): Path to a checkpoint to load
                from; if None, use normal init for convs and constant init
                for norm layers. Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, 0, 0.02)
                elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None but'
                            f' got {type(pretrained)} instead.')
mmgen/models/architectures/singan/positional_encoding.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
"""Implementation for Positional Encoding as Spatial Inductive Bias in GANs.
In this module, we provide necessary components to conduct experiments
mentioned in the paper: Positional Encoding as Spatial Inductive Bias in GANs.
More details can be found in: https://arxiv.org/pdf/2012.05217.pdf
"""
from
functools
import
partial
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmgen.models.builder
import
MODULES
,
build_module
from
.generator_discriminator
import
SinGANMultiScaleGenerator
from
.modules
import
GeneratorBlock
@MODULES.register_module()
class SinGANMSGeneratorPE(SinGANMultiScaleGenerator):
    """Multi-Scale Generator used in SinGAN with positional encoding.

    More details can be found in: Positional Encoding as Spatial Inductive
    Bias in GANs, CVPR'2021.

    Notes:

    - In this version, we adopt the interpolation function from the official
      PyTorch APIs, which is different from the original implementation by the
      authors. However, in our experiments, this influence can be ignored.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        num_scales (int): The number of scales/stages in generator. Note
            that this number is counted from zero, which is the same as the
            original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            generator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the generator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the generator block. Defaults to 32.
        out_act_cfg (dict | None, optional): Configs for output activation
            layer. Defaults to dict(type='Tanh').
        padding_mode (str, optional): The mode of convolutional padding, same
            as :obj:`nn.Conv2d`. Defaults to 'zero'.
        pad_at_head (bool, optional): Whether to add padding at head.
            Defaults to True.
        interp_pad (bool, optional): The padding value of interpolating feature
            maps. Defaults to False.
        noise_with_pad (bool, optional): Whether the input fixed noises are
            with explicit padding. Defaults to False.
        positional_encoding (dict | None, optional): Configs for the positional
            encoding. Defaults to None.
        first_stage_in_channels (int | None, optional): The input channel of
            the first generator block. If None, the first stage will adopt the
            same input channels as other stages. Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_scales,
                 kernel_size=3,
                 padding=0,
                 num_layers=5,
                 base_channels=32,
                 min_feat_channels=32,
                 out_act_cfg=dict(type='Tanh'),
                 padding_mode='zero',
                 pad_at_head=True,
                 interp_pad=False,
                 noise_with_pad=False,
                 positional_encoding=None,
                 first_stage_in_channels=None,
                 **kwargs):
        # Deliberately bypasses SinGANMultiScaleGenerator.__init__ (calls the
        # next class in the MRO instead) so this class can build its own
        # block list below rather than inheriting the parent's construction.
        super(SinGANMultiScaleGenerator, self).__init__()

        self.pad_at_head = pad_at_head
        self.interp_pad = interp_pad
        self.noise_with_pad = noise_with_pad

        # Positional encoding is only applied at the head (stage 0) when a
        # config is provided.
        self.with_positional_encode = positional_encoding is not None
        if self.with_positional_encode:
            self.head_position_encode = build_module(positional_encoding)

        # Total shrinkage caused by `num_layers` unpadded convs of this
        # kernel size; used as the amount of explicit padding at the head.
        self.pad_head = int((kernel_size - 1) / 2 * num_layers)
        self.blocks = nn.ModuleList()

        self.upsample = partial(
            F.interpolate, mode='bicubic', align_corners=True)

        for scale in range(num_scales + 1):
            # Channels double every 4 scales, capped at 128.
            base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
                          128)
            min_feat_ch = min(
                min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)

            # The first stage may take a different number of input channels
            # (e.g. when the head input is a positional-encoding grid).
            if scale == 0:
                in_ch = (first_stage_in_channels
                         if first_stage_in_channels else in_channels)
            else:
                in_ch = in_channels

            self.blocks.append(
                GeneratorBlock(
                    in_channels=in_ch,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    num_layers=num_layers,
                    base_channels=base_ch,
                    min_feat_channels=min_feat_ch,
                    out_act_cfg=out_act_cfg,
                    padding_mode=padding_mode,
                    **kwargs))

        # Explicit head padding layers; note that in 'zero' mode the mask is
        # still reflection-padded.
        if padding_mode == 'zero':
            self.noise_padding_layer = nn.ZeroPad2d(self.pad_head)
            self.img_padding_layer = nn.ZeroPad2d(self.pad_head)
            self.mask_padding_layer = nn.ReflectionPad2d(self.pad_head)
        elif padding_mode == 'reflect':
            self.noise_padding_layer = nn.ReflectionPad2d(self.pad_head)
            self.img_padding_layer = nn.ReflectionPad2d(self.pad_head)
            self.mask_padding_layer = nn.ReflectionPad2d(self.pad_head)
            mmcv.print_log('Using Reflection padding', 'mmgen')
        else:
            raise NotImplementedError(
                f'Padding mode {padding_mode} is not supported')

    def forward(self,
                input_sample,
                fixed_noises,
                noise_weights,
                rand_mode,
                curr_scale,
                num_batches=1,
                get_prev_res=False,
                return_noise=False):
        """Forward function.

        Args:
            input_sample (Tensor | None): The input for generator. In the
                original implementation, a tensor filled with zeros is adopted.
                If None is given, we will construct it from the first fixed
                noises.
            fixed_noises (list[Tensor]): List of the fixed noises in SinGAN.
            noise_weights (list[float]): List of the weights for random noises.
            rand_mode (str): Choices from ['rand', 'recon']. In ``rand`` mode,
                it will sample from random noises. Otherwise, the
                reconstruction for the single image will be returned.
            curr_scale (int): The scale for the current inference or training.
            num_batches (int, optional): The number of batches. Defaults to 1.
            get_prev_res (bool, optional): Whether to return results from
                previous stages. Defaults to False.
            return_noise (bool, optional): Whether to return noises tensor.
                Defaults to False.

        Returns:
            Tensor | dict: Generated image tensor or dictionary containing \
                more data.
        """
        if get_prev_res or return_noise:
            prev_res_list = []
            noise_list = []

        # Default input: zeros matching the spatial size of the first fixed
        # noise (assumes 3 output channels).
        if input_sample is None:
            input_sample = torch.zeros(
                (num_batches, 3, fixed_noises[0].shape[-2],
                 fixed_noises[0].shape[-1])).to(fixed_noises[0])

        g_res = input_sample

        for stage in range(curr_scale + 1):
            # 'recon' reuses the stored fixed noise; any other mode samples
            # fresh Gaussian noise of the same per-stage shape.
            if rand_mode == 'recon':
                noise_ = fixed_noises[stage]
            else:
                noise_ = torch.randn(num_batches,
                                     *fixed_noises[stage].shape[1:]).to(g_res)
            if return_noise:
                noise_list.append(noise_)

            # Positional encoding grid is only injected at the head stage.
            if self.with_positional_encode and stage == 0:
                head_grid = self.head_position_encode(fixed_noises[0])
                noise_ = noise_ + head_grid

            # add padding at head
            if self.pad_at_head:
                if self.interp_pad:
                    # Pad by interpolating up to the padded size instead of
                    # using explicit padding layers.
                    if self.noise_with_pad:
                        size = noise_.shape[-2:]
                    else:
                        size = (noise_.size(2) + 2 * self.pad_head,
                                noise_.size(3) + 2 * self.pad_head)
                        noise_ = self.upsample(noise_, size)
                    g_res_pad = self.upsample(g_res, size)
                else:
                    if not self.noise_with_pad:
                        noise_ = self.noise_padding_layer(noise_)
                    g_res_pad = self.img_padding_layer(g_res)
            else:
                g_res_pad = g_res

            # At the head stage with positional encoding, the previous result
            # is not mixed into the block input.
            if stage == 0 and self.with_positional_encode:
                noise = noise_ * noise_weights[stage]
            else:
                noise = noise_ * noise_weights[stage] + g_res_pad
            # detach() stops gradients flowing into the noise mixture.
            g_res = self.blocks[stage](noise.detach(), g_res)

            if get_prev_res and stage != curr_scale:
                prev_res_list.append(g_res)

            # upsample, here we use interpolation from PyTorch
            if stage != curr_scale:
                h_next, w_next = fixed_noises[stage + 1].shape[-2:]
                if self.noise_with_pad:
                    # remove the additional padding if noise with pad
                    h_next -= 2 * self.pad_head
                    w_next -= 2 * self.pad_head
                g_res = self.upsample(g_res, (h_next, w_next))

        if get_prev_res or return_noise:
            output_dict = dict(
                fake_img=g_res,
                prev_res_list=prev_res_list,
                noise_batch=noise_list)
            return output_dict

        return g_res
mmgen/models/architectures/sngan_proj/__init__.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
from
.generator_discriminator
import
ProjDiscriminator
,
SNGANGenerator
from
.modules
import
SNGANDiscHeadResBlock
,
SNGANDiscResBlock
,
SNGANGenResBlock
__all__
=
[
'ProjDiscriminator'
,
'SNGANGenerator'
,
'SNGANGenResBlock'
,
'SNGANDiscResBlock'
,
'SNGANDiscHeadResBlock'
]
mmgen/models/architectures/sngan_proj/generator_discriminator.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
(
ConvModule
,
build_activation_layer
,
constant_init
,
xavier_init
)
from
mmcv.runner
import
load_checkpoint
from
mmcv.runner.checkpoint
import
_load_checkpoint_with_prefix
from
mmcv.utils
import
is_list_of
from
torch.nn.init
import
xavier_uniform_
from
torch.nn.utils
import
spectral_norm
from
mmgen.models.builder
import
MODULES
,
build_module
from
mmgen.utils
import
check_dist_init
from
mmgen.utils.logger
import
get_root_logger
from
..common
import
get_module_device
@MODULES.register_module('SAGANGenerator')
@MODULES.register_module()
class SNGANGenerator(nn.Module):
    r"""Generator for SNGAN / Proj-GAN. The implementation refers to
    https://github.com/pfnet-research/sngan_projection/tree/master/gen_models

    In our implementation, we have two notable design. Namely,
    ``channels_cfg`` and ``blocks_cfg``.

    ``channels_cfg``: In default config of SNGAN / Proj-GAN, the number of
        ResBlocks and the channels of those blocks are corresponding to the
        resolution of the output image. Therefore, we allow user to define
        ``channels_cfg`` to try their own models. We also provide a default
        config to allow users to build the model only from the output
        resolution.

    ``block_cfg``: In reference code, the generator consists of a group of
        ResBlock. However, in our implementation, to make this model more
        generalize, we support defining ``blocks_cfg`` by users and loading
        the blocks by calling the build_module method.

    Args:
        output_scale (int): Output scale for the generated image.
        num_classes (int, optional): The number classes you would like to
            generate. This arguments would influence the structure of the
            intermedia blocks and label sampling operation in ``forward``
            (e.g. If num_classes=0, ConditionalNormalization layers would
            degrade to unconditional ones.). This arguments would be passed
            to intermedia blocks by overwrite their config. Defaults to 0.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contains channels based on this
            number. Default to 64.
        out_channels (int, optional): Channels of the output images.
            Default to 3.
        input_scale (int, optional): Input scale for the features.
            Defaults to 4.
        noise_size (int, optional): Size of the input noise vector.
            Default to 128.
        attention_cfg (dict, optional): Config for the self-attention block.
            Default to ``dict(type='SelfAttentionBlock')``.
        attention_after_nth_block (int | list[int], optional): Self attention
            block would be added after which *ConvBlock*. If ``int`` is
            passed, only one attention block would be added. If ``list`` is
            passed, self-attention blocks would be added after multiple
            ConvBlocks. To be noted that if the input is smaller than ``1``,
            self-attention corresponding to this index would be ignored.
            Default to 0.
        channels_cfg (list | dict[list], optional): Config for input channels
            of the intermedia blocks. If list is passed, each element of the
            list means the input channels of current block is how many times
            compared to the ``base_channels``. For block ``i``, the input and
            output channels should be ``channels_cfg[i]`` and
            ``channels_cfg[i+1]`` If dict is provided, the key of the dict
            should be the output scale and corresponding value should be a
            list to define channels. Default: Please refer to
            ``_default_channels_cfg``.
        blocks_cfg (dict, optional): Config for the intermedia blocks.
            Defaults to ``dict(type='SNGANGenResBlock')``
        act_cfg (dict, optional): Activation config for the final output
            layer. Defaults to ``dict(type='ReLU')``.
        use_cbn (bool, optional): Whether use conditional normalization. This
            argument would pass to norm layers. Defaults to True.
        auto_sync_bn (bool, optional): Whether convert Batch Norm to
            Synchronized ones when Distributed training is on. Defaults to
            True.
        with_spectral_norm (bool, optional): Whether use spectral norm for
            conv blocks or not. Default to False.
        with_embedding_spectral_norm (bool, optional): Whether use spectral
            norm for embedding layers in normalization blocks or not. If not
            specified (set as ``None``), ``with_embedding_spectral_norm``
            would be set as the same value as ``with_spectral_norm``.
            Defaults to None.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        norm_eps (float, optional): eps for Normalization layers (both
            conditional and non-conditional ones). Default to `1e-4`.
        sn_eps (float, optional): eps for spectral normalization operation.
            Defaults to `1e-12`.
        init_cfg (string, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
            the generator part from the whole state dict. Defaults to None.
    """

    # default channel factors
    _default_channels_cfg = {
        32: [1, 1, 1],
        64: [16, 8, 4, 2],
        128: [16, 16, 8, 4, 2]
    }

    def __init__(self,
                 output_scale,
                 num_classes=0,
                 base_channels=64,
                 out_channels=3,
                 input_scale=4,
                 noise_size=128,
                 attention_cfg=dict(type='SelfAttentionBlock'),
                 attention_after_nth_block=0,
                 channels_cfg=None,
                 blocks_cfg=dict(type='SNGANGenResBlock'),
                 act_cfg=dict(type='ReLU'),
                 use_cbn=True,
                 auto_sync_bn=True,
                 with_spectral_norm=False,
                 with_embedding_spectral_norm=None,
                 sn_style='torch',
                 norm_eps=1e-4,
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN'),
                 pretrained=None):
        super().__init__()
        self.input_scale = input_scale
        self.output_scale = output_scale
        self.noise_size = noise_size
        self.num_classes = num_classes
        self.init_type = init_cfg.get('type', None)

        # Propagate shared settings into the block config without overriding
        # values the user set explicitly (hence `setdefault`).
        self.blocks_cfg = deepcopy(blocks_cfg)
        self.blocks_cfg.setdefault('num_classes', num_classes)
        self.blocks_cfg.setdefault('act_cfg', act_cfg)
        self.blocks_cfg.setdefault('use_cbn', use_cbn)
        self.blocks_cfg.setdefault('auto_sync_bn', auto_sync_bn)
        self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)

        # set `norm_spectral_norm` as `with_spectral_norm` if not defined
        with_embedding_spectral_norm = with_embedding_spectral_norm \
            if with_embedding_spectral_norm is not None else with_spectral_norm
        self.blocks_cfg.setdefault('with_embedding_spectral_norm',
                                   with_embedding_spectral_norm)
        self.blocks_cfg.setdefault('init_cfg', init_cfg)
        self.blocks_cfg.setdefault('sn_style', sn_style)
        self.blocks_cfg.setdefault('norm_eps', norm_eps)
        self.blocks_cfg.setdefault('sn_eps', sn_eps)

        channels_cfg = deepcopy(self._default_channels_cfg) \
            if channels_cfg is None else deepcopy(channels_cfg)
        if isinstance(channels_cfg, dict):
            if output_scale not in channels_cfg:
                raise KeyError(f'`output_scale={output_scale} is not found in '
                               '`channel_cfg`, only support configs for '
                               f'{[chn for chn in channels_cfg.keys()]}')
            self.channel_factor_list = channels_cfg[output_scale]
        elif isinstance(channels_cfg, list):
            self.channel_factor_list = channels_cfg
        else:
            raise ValueError('Only support list or dict for `channel_cfg`, '
                             f'receive {type(channels_cfg)}')

        # Map the noise vector to the initial (input_scale x input_scale)
        # feature map of the first block.
        self.noise2feat = nn.Linear(
            noise_size,
            input_scale**2 * base_channels * self.channel_factor_list[0])
        if with_spectral_norm:
            self.noise2feat = spectral_norm(self.noise2feat)

        # check `attention_after_nth_block`
        if not isinstance(attention_after_nth_block, list):
            attention_after_nth_block = [attention_after_nth_block]
        if not is_list_of(attention_after_nth_block, int):
            raise ValueError('`attention_after_nth_block` only support int or '
                             'a list of int. Please check your input type.')

        self.conv_blocks = nn.ModuleList()
        self.attention_block_idx = []
        for idx in range(len(self.channel_factor_list)):
            factor_input = self.channel_factor_list[idx]
            # The last block narrows back down to `base_channels` (factor 1).
            factor_output = self.channel_factor_list[idx + 1] \
                if idx < len(self.channel_factor_list) - 1 else 1

            # get block-specific config
            block_cfg_ = deepcopy(self.blocks_cfg)
            block_cfg_['in_channels'] = factor_input * base_channels
            block_cfg_['out_channels'] = factor_output * base_channels
            self.conv_blocks.append(build_module(block_cfg_))

            # build self-attention block
            # `idx` is start from 0, add 1 to get the index
            if idx + 1 in attention_after_nth_block:
                self.attention_block_idx.append(len(self.conv_blocks))
                attn_cfg_ = deepcopy(attention_cfg)
                attn_cfg_['in_channels'] = factor_output * base_channels
                attn_cfg_['sn_style'] = sn_style
                self.conv_blocks.append(build_module(attn_cfg_))

        to_rgb_norm_cfg = dict(type='BN', eps=norm_eps)
        if check_dist_init() and auto_sync_bn:
            to_rgb_norm_cfg['type'] = 'SyncBN'

        self.to_rgb = ConvModule(
            factor_output * base_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            norm_cfg=to_rgb_norm_cfg,
            act_cfg=act_cfg,
            order=('norm', 'act', 'conv'),
            with_spectral_norm=with_spectral_norm)
        self.final_act = build_activation_layer(dict(type='Tanh'))

        self.init_weights(pretrained)

    def forward(self, noise, num_batches=0, label=None, return_noise=False):
        """Forward function.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            label (torch.Tensor | callable | None): You can directly give a
                batch of label through a ``torch.Tensor`` or offer a callable
                function to sample a batch of label data. Otherwise, the
                ``None`` indicates to use the default label sampler.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.

        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output
                image will be returned. Otherwise, a dict contains
                ``fake_image``, ``noise_batch`` and ``label_batch``
                would be returned.
        """
        if isinstance(noise, torch.Tensor):
            assert noise.shape[1] == self.noise_size
            assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
                                     f'but got {noise.shape}')
            noise_batch = noise
        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, self.noise_size))
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            noise_batch = torch.randn((num_batches, self.noise_size))

        if isinstance(label, torch.Tensor):
            assert label.ndim == 1, ('The label should be in shape of (n, ) '
                                     f'but got {label.shape}.')
            label_batch = label
        elif callable(label):
            label_generator = label
            assert num_batches > 0
            label_batch = label_generator(num_batches)
        elif self.num_classes == 0:
            # Unconditional generation: no label is needed at all.
            label_batch = None
        else:
            assert num_batches > 0
            label_batch = torch.randint(0, self.num_classes, (num_batches, ))

        # dirty code for putting data on the right device
        noise_batch = noise_batch.to(get_module_device(self))
        if label_batch is not None:
            label_batch = label_batch.to(get_module_device(self))

        x = self.noise2feat(noise_batch)
        x = x.reshape(x.size(0), -1, self.input_scale, self.input_scale)
        for idx, conv_block in enumerate(self.conv_blocks):
            # Self-attention blocks take no label; residual blocks do.
            if idx in self.attention_block_idx:
                x = conv_block(x)
            else:
                x = conv_block(x, label_batch)
        out_feat = self.to_rgb(x)
        out_img = self.final_act(out_feat)
        if return_noise:
            return dict(
                fake_img=out_img, noise_batch=noise_batch, label=label_batch)
        return out_img

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for SNGAN-Proj and SAGAN. If ``pretrained=None``,
        weight initialization would follow the ``INIT_TYPE`` in
        ``init_cfg=dict(type=INIT_TYPE)``.

        For SNGAN-Proj,
        (``INIT_TYPE.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']``),
        we follow the initialization method in the official Chainer's
        implementation (https://github.com/pfnet-research/sngan_projection).

        For SAGAN (``INIT_TYPE.upper() == 'SAGAN'``), we follow the
        initialization method in official tensorflow's implementation
        (https://github.com/brain-research/self-attention-gan).

        Besides the reimplementation of the official code's initialization, we
        provide BigGAN's and Pytorch-StudioGAN's style initialization
        (``INIT_TYPE.upper() == BIGGAN`` and ``INIT_TYPE.upper() == STUDIO``).
        Please refer to https://github.com/ajbrock/BigGAN-PyTorch and
        https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.

        Args:
            pretrained (str | dict, optional): Path for the pretrained model or
                dict containing information for pretained models whose
                necessary key is 'ckpt_path'. Besides, you can also provide
                'prefix' to load the generator part from the whole state dict.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif isinstance(pretrained, dict):
            ckpt_path = pretrained.get('ckpt_path', None)
            assert ckpt_path is not None
            prefix = pretrained.get('prefix', '')
            map_location = pretrained.get('map_location', 'cpu')
            strict = pretrained.get('strict', True)
            state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                      map_location)
            self.load_state_dict(state_dict, strict=strict)
        elif pretrained is None:
            # Use equality, not `in 'STUDIO'`: substring membership would
            # silently match init types like 'S', 'T' or 'DIO' as well.
            if self.init_type.upper() == 'STUDIO':
                # initialization method from Pytorch-StudioGAN
                #   * weight: orthogonal_init gain=1
                #   * bias  : 0
                for m in self.modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                        nn.init.orthogonal_(m.weight)
                        if hasattr(m, 'bias') and m.bias is not None:
                            m.bias.data.fill_(0.)
            elif self.init_type.upper() == 'BIGGAN':
                # initialization method from BigGAN-pytorch
                #   * weight: xavier_init gain=1
                #   * bias  : default
                for n, m in self.named_modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                        xavier_uniform_(m.weight, gain=1)
            elif self.init_type.upper() == 'SAGAN':
                # initialization method from official tensorflow code
                #   * weight          : xavier_init gain=1
                #   * bias            : 0
                #   * weight_embedding: 1
                #   * bias_embedding  : 0
                for n, m in self.named_modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear)):
                        xavier_init(m, gain=1, distribution='uniform')
                    if isinstance(m, nn.Embedding):
                        # To be noted that here we initialize the embedding
                        # layer in cBN with specific prefix. If you implement
                        # your own cBN and want to use this initialization
                        # method, please make sure the embedding layers in
                        # your implementation have the same prefix as ours.
                        if 'weight' in n:
                            constant_init(m, 1)
                        if 'bias' in n:
                            constant_init(m, 0)
            elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
                # initialization method from the official chainer code
                #   * conv.weight     : xavier_init gain=sqrt(2)
                #   * shortcut.weight : xavier_init gain=1
                #   * bias            : 0
                #   * weight_embedding: 1
                #   * bias_embedding  : 0
                for n, m in self.named_modules():
                    if isinstance(m, nn.Conv2d):
                        if 'shortcut' in n or 'to_rgb' in n:
                            xavier_init(m, gain=1, distribution='uniform')
                        else:
                            xavier_init(m, gain=np.sqrt(2),
                                        distribution='uniform')
                    if isinstance(m, nn.Linear):
                        xavier_init(m, gain=1, distribution='uniform')
                    if isinstance(m, nn.Embedding):
                        # To be noted that here we initialize the embedding
                        # layer in cBN with specific prefix. If you implement
                        # your own cBN and want to use this initialization
                        # method, please make sure the embedding layers in
                        # your implementation have the same prefix as ours.
                        if 'weight' in n:
                            constant_init(m, 1)
                        if 'bias' in n:
                            constant_init(m, 0)
            else:
                raise NotImplementedError('Unknown initialization method: '
                                          f'\'{self.init_type}\'')
        else:
            raise TypeError("'pretrained' must be a str, a dict or None. "
                            f'But receive {type(pretrained)}.')
@MODULES.register_module('SAGANDiscriminator')
@MODULES.register_module()
class ProjDiscriminator(nn.Module):
    r"""Discriminator for SNGAN / Proj-GAN. The implementation refers to
    https://github.com/pfnet-research/sngan_projection/tree/master/dis_models

    The overall structure of the projection discriminator can be split into a
    ``from_rgb`` layer, a group of ResBlocks, a linear decision layer, and a
    projection layer. To support defining custom layers, we introduce
    ``from_rgb_cfg`` and ``blocks_cfg``.

    The design of the model structure is highly corresponding to the output
    resolution. Therefore, we provide ``channels_cfg`` and ``downsample_cfg``
    to control the input channels and the downsample behavior of the
    intermediate blocks.

    ``downsample_cfg``: In default config of SNGAN / Proj-GAN, whether to apply
        downsample in each intermediate block is quite flexible and
        corresponding to the resolution of the output image. Therefore, we
        support users to define the ``downsample_cfg`` by themselves, and to
        control the structure of the discriminator.

    ``channels_cfg``: In default config of SNGAN / Proj-GAN, the number of
        ResBlocks and the channels of those blocks are corresponding to the
        resolution of the output image. Therefore, we allow users to define
        ``channels_cfg`` to try their own models. We also provide a default
        config to allow users to build the model only from the output
        resolution.

    Args:
        input_scale (int): The scale of the input image.
        num_classes (int, optional): The number of classes you would like to
            generate. If num_classes=0, no label projection would be used.
            Default to 0.
        base_channels (int, optional): The basic channel number of the
            discriminator. The other layers contains channels based on this
            number. Defaults to 128.
        input_channels (int, optional): Channels of the input image.
            Defaults to 3.
        attention_cfg (dict, optional): Config for the self-attention block.
            Default to ``dict(type='SelfAttentionBlock')``.
        attention_after_nth_block (int | list[int], optional): Self-attention
            block would be added after which *ConvBlock* (including the head
            block). If ``int`` is passed, only one attention block would be
            added. If ``list`` is passed, self-attention blocks would be added
            after multiple ConvBlocks. To be noted that if the input is
            smaller than ``1``, self-attention corresponding to this index
            would be ignored. Default to -1 (no self-attention block).
        channels_cfg (list | dict[list], optional): Config for input channels
            of the intermediate blocks. If list is passed, each element of the
            list means the input channels of current block is how many times
            compared to the ``base_channels``. For block ``i``, the input and
            output channels should be ``channels_cfg[i]`` and
            ``channels_cfg[i+1]`` If dict is provided, the key of the dict
            should be the output scale and corresponding value should be a
            list to define channels. Default: Please refer to
            ``_defualt_channels_cfg``.
        downsample_cfg (list[bool] | dict[list], optional): Config for
            downsample behavior of the intermediate layers. If a list is
            passed, ``downsample_cfg[idx] == True`` means apply downsample in
            idx-th block, and vice versa. If dict is provided, the key of the
            dict should be the input scale of the image and corresponding
            value should be a list to define the downsample behavior.
            Default: Please refer to ``_defualt_downsample_cfg``.
        from_rgb_cfg (dict, optional): Config for the first layer to convert
            rgb image to feature map. Defaults to
            ``dict(type='SNGANDiscHeadResBlock')``.
        blocks_cfg (dict, optional): Config for the intermediate blocks.
            Defaults to ``dict(type='SNGANDiscResBlock')``
        act_cfg (dict, optional): Activation config for the final output
            layer. Defaults to ``dict(type='ReLU')``.
        with_spectral_norm (bool, optional): Whether use spectral norm for
            all conv blocks or not. Default to True.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        sn_eps (float, optional): eps for spectral normalization operation.
            Defaults to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Default to ``dict(type='BigGAN')``.
        pretrained (str | dict , optional): Path for the pretrained model or
            dict containing information for pretrained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
            the generator part from the whole state dict. Defaults to None.
    """

    # default channel factors (keyed by input scale)
    _defualt_channels_cfg = {
        32: [1, 1, 1],
        64: [2, 4, 8, 16],
        128: [2, 4, 8, 16, 16],
    }

    # default downsample behavior (keyed by input scale)
    _defualt_downsample_cfg = {
        32: [True, False, False],
        64: [True, True, True, True],
        128: [True, True, True, True, False]
    }

    def __init__(self,
                 input_scale,
                 num_classes=0,
                 base_channels=128,
                 input_channels=3,
                 attention_cfg=dict(type='SelfAttentionBlock'),
                 attention_after_nth_block=-1,
                 channels_cfg=None,
                 downsample_cfg=None,
                 from_rgb_cfg=dict(type='SNGANDiscHeadResBlock'),
                 blocks_cfg=dict(type='SNGANDiscResBlock'),
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=True,
                 sn_style='torch',
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN'),
                 pretrained=None):
        super().__init__()

        self.init_type = init_cfg.get('type', None)

        # add SN options and activation function options to cfg
        self.from_rgb_cfg = deepcopy(from_rgb_cfg)
        self.from_rgb_cfg.setdefault('act_cfg', act_cfg)
        self.from_rgb_cfg.setdefault('with_spectral_norm', with_spectral_norm)
        self.from_rgb_cfg.setdefault('sn_style', sn_style)
        self.from_rgb_cfg.setdefault('init_cfg', init_cfg)

        # add SN options and activation function options to cfg
        self.blocks_cfg = deepcopy(blocks_cfg)
        self.blocks_cfg.setdefault('act_cfg', act_cfg)
        self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)
        self.blocks_cfg.setdefault('sn_style', sn_style)
        self.blocks_cfg.setdefault('sn_eps', sn_eps)
        self.blocks_cfg.setdefault('init_cfg', init_cfg)

        channels_cfg = deepcopy(self._defualt_channels_cfg) \
            if channels_cfg is None else deepcopy(channels_cfg)
        if isinstance(channels_cfg, dict):
            if input_scale not in channels_cfg:
                raise KeyError(f'`input_scale={input_scale}` is not found in '
                               '`channels_cfg`, only support configs for '
                               f'{list(channels_cfg.keys())}')
            self.channel_factor_list = channels_cfg[input_scale]
        elif isinstance(channels_cfg, list):
            self.channel_factor_list = channels_cfg
        else:
            raise ValueError('Only support list or dict for `channels_cfg`, '
                             f'receive {type(channels_cfg)}')

        downsample_cfg = deepcopy(self._defualt_downsample_cfg) \
            if downsample_cfg is None else deepcopy(downsample_cfg)
        if isinstance(downsample_cfg, dict):
            if input_scale not in downsample_cfg:
                # NOTE: the key checked here is the *input* scale; the old
                # message wrongly said `output_scale`.
                raise KeyError(f'`input_scale={input_scale}` is not found in '
                               '`downsample_cfg`, only support configs for '
                               f'{list(downsample_cfg.keys())}')
            self.downsample_list = downsample_cfg[input_scale]
        elif isinstance(downsample_cfg, list):
            self.downsample_list = downsample_cfg
        else:
            # fixed: this branch validates `downsample_cfg`, not
            # `channels_cfg`, so the message must name the right argument
            raise ValueError('Only support list or dict for `downsample_cfg`, '
                             f'receive {type(downsample_cfg)}')

        if len(self.downsample_list) != len(self.channel_factor_list):
            raise ValueError('`downsample_cfg` should have same length with '
                             '`channels_cfg`, but receive '
                             f'{len(self.downsample_list)} and '
                             f'{len(self.channel_factor_list)}.')

        # check `attention_after_nth_block`
        if not isinstance(attention_after_nth_block, list):
            attention_after_nth_block = [attention_after_nth_block]
        if not all(
                isinstance(idx, int) for idx in attention_after_nth_block):
            raise ValueError('`attention_after_nth_block` only support int or '
                             'a list of int. Please check your input type.')

        self.from_rgb = build_module(
            self.from_rgb_cfg,
            dict(in_channels=input_channels, out_channels=base_channels))

        self.conv_blocks = nn.ModuleList()
        # add self-attention block after the first block
        if 1 in attention_after_nth_block:
            attn_cfg_ = deepcopy(attention_cfg)
            attn_cfg_['in_channels'] = base_channels
            attn_cfg_['sn_style'] = sn_style
            self.conv_blocks.append(build_module(attn_cfg_))

        for idx in range(len(self.downsample_list)):
            factor_input = 1 if idx == 0 else self.channel_factor_list[idx - 1]
            factor_output = self.channel_factor_list[idx]

            # get block-specific config
            block_cfg_ = deepcopy(self.blocks_cfg)
            block_cfg_['downsample'] = self.downsample_list[idx]
            block_cfg_['in_channels'] = factor_input * base_channels
            block_cfg_['out_channels'] = factor_output * base_channels
            self.conv_blocks.append(build_module(block_cfg_))

            # build self-attention block
            # the first ConvBlock is `from_rgb` block,
            # add 2 to get the index of the ConvBlocks
            if idx + 2 in attention_after_nth_block:
                attn_cfg_ = deepcopy(attention_cfg)
                attn_cfg_['in_channels'] = factor_output * base_channels
                self.conv_blocks.append(build_module(attn_cfg_))

        # final real/fake decision on globally pooled features
        self.decision = nn.Linear(factor_output * base_channels, 1)
        if with_spectral_norm:
            self.decision = spectral_norm(self.decision)

        self.num_classes = num_classes
        # In this case, discriminator is designed for conditional synthesis.
        if num_classes > 0:
            self.proj_y = nn.Embedding(num_classes,
                                       factor_output * base_channels)
            if with_spectral_norm:
                self.proj_y = spectral_norm(self.proj_y)

        self.activate = build_activation_layer(act_cfg)
        self.init_weights(pretrained)

    def forward(self, x, label=None):
        """Forward function. If ``self.num_classes`` is larger than 0, label
        projection would be used.

        Args:
            x (torch.Tensor): Fake or real image tensor.
            label (torch.Tensor, optional): Label corresponding to the input
                image. Noted that, if ``self.num_classes`` is larger than 0,
                ``label`` should not be None. Default to None.

        Returns:
            torch.Tensor: Prediction for the reality of the input image.
        """
        h = self.from_rgb(x)
        for conv_block in self.conv_blocks:
            h = conv_block(h)
        h = self.activate(h)
        # global sum pooling over the spatial dimensions
        h = torch.sum(h, dim=[2, 3])
        out = self.decision(h)

        if self.num_classes > 0:
            # projection term: inner product between class embedding and
            # the pooled features (Miyato & Koyama, cGANs with projection)
            w_y = self.proj_y(label)
            out = out + torch.sum(w_y * h, dim=1, keepdim=True)
        return out.view(out.size(0), -1)

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for SNGAN-Proj and SAGAN. If ``pretrained=None``,
        weight initialization would follow the ``INIT_TYPE`` in
        ``init_cfg=dict(type=INIT_TYPE)``.

        For SNGAN-Proj
        (``INIT_TYPE.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']``),
        we follow the initialization method in the official Chainer's
        implementation (https://github.com/pfnet-research/sngan_projection).

        For SAGAN (``INIT_TYPE.upper() == 'SAGAN'``), we follow the
        initialization method in official tensorflow's implementation
        (https://github.com/brain-research/self-attention-gan).

        Besides the reimplementation of the official code's initialization, we
        provide BigGAN's and Pytorch-StudioGAN's style initialization
        (``INIT_TYPE.upper() == BIGGAN`` and ``INIT_TYPE.upper() == STUDIO``).
        Please refer to https://github.com/ajbrock/BigGAN-PyTorch and
        https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.

        Args:
            pretrained (str | dict, optional): Path for the pretrained model
                or dict containing information for pretrained models whose
                necessary key is 'ckpt_path'. Besides, you can also provide
                'prefix' to load the generator part from the whole state
                dict. Defaults to None.
            strict (bool, optional): Whether to load the checkpoint strictly.
                Defaults to True.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif isinstance(pretrained, dict):
            ckpt_path = pretrained.get('ckpt_path', None)
            assert ckpt_path is not None
            prefix = pretrained.get('prefix', '')
            map_location = pretrained.get('map_location', 'cpu')
            strict = pretrained.get('strict', True)
            state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                      map_location)
            self.load_state_dict(state_dict, strict=strict)
        elif pretrained is None:
            if self.init_type.upper() == 'STUDIO':
                # initialization method from Pytorch-StudioGAN
                #   * weight: orthogonal_init gain=1
                #   * bias  : 0
                for m in self.modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                        nn.init.orthogonal_(m.weight, gain=1)
                        if hasattr(m, 'bias') and m.bias is not None:
                            m.bias.data.fill_(0.)
            elif self.init_type.upper() == 'BIGGAN':
                # initialization method from BigGAN-pytorch
                #   * weight: xavier_init gain=1
                #   * bias  : default
                for m in self.modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                        xavier_uniform_(m.weight, gain=1)
            elif self.init_type.upper() == 'SAGAN':
                # initialization method from official tensorflow code
                #   * weight: xavier_init gain=1
                #   * bias  : 0
                for m in self.modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                        xavier_init(m, gain=1, distribution='uniform')
            elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
                # initialization method from the official chainer code
                #   * embedding.weight: xavier_init gain=1
                #   * conv.weight     : xavier_init gain=sqrt(2)
                #   * shortcut.weight : xavier_init gain=1
                #   * bias            : 0
                for n, m in self.named_modules():
                    if isinstance(m, nn.Conv2d):
                        if 'shortcut' in n:
                            xavier_init(m, gain=1, distribution='uniform')
                        else:
                            xavier_init(
                                m, gain=np.sqrt(2), distribution='uniform')
                    if isinstance(m, (nn.Linear, nn.Embedding)):
                        xavier_init(m, gain=1, distribution='uniform')
            else:
                raise NotImplementedError('Unknown initialization method: '
                                          f'\'{self.init_type}\'')
        else:
            # fixed: grammar ("must by") and the accepted `dict` type were
            # wrong/missing in the original message
            raise TypeError("'pretrained' must be a str, a dict or None. "
                            f'But receive {type(pretrained)}.')
mmgen/models/architectures/sngan_proj/modules.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
numpy
as
np
import
torch.nn
as
nn
from
mmcv.cnn
import
(
build_activation_layer
,
build_norm_layer
,
build_upsample_layer
,
constant_init
,
xavier_init
)
from
torch.nn.init
import
xavier_uniform_
from
torch.nn.utils
import
spectral_norm
from
mmgen.models.architectures.biggan.biggan_snmodule
import
SNEmbedding
from
mmgen.models.architectures.biggan.modules
import
SNConvModule
from
mmgen.models.builder
import
MODULES
from
mmgen.utils
import
check_dist_init
@MODULES.register_module()
class SNGANGenResBlock(nn.Module):
    """ResBlock used in Generator of SNGAN / Proj-GAN.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        hidden_channels (int, optional): Input channels of the second Conv
            layer of the block. If ``None`` is given, would be set as
            ``out_channels``. Default to None.
        num_classes (int, optional): Number of classes would like to generate.
            This argument would pass to norm layers and influence the
            structure and behavior of the normalization process. Default to 0.
        use_cbn (bool, optional): Whether use conditional normalization. This
            argument would pass to norm layers. Default to True.
        use_norm_affine (bool, optional): Whether use learnable affine
            parameters in norm operation when cbn is off. Defaults to False.
        act_cfg (dict, optional): Config for activate function. Default
            to ``dict(type='ReLU')``.
        norm_cfg (dict, optional): Config for the normalization layers.
            Defaults to ``dict(type='BN')``.
        upsample_cfg (dict, optional): Config for the upsample method.
            Default to ``dict(type='nearest', scale_factor=2)``.
        upsample (bool, optional): Whether apply upsample operation in this
            module. Default to True.
        auto_sync_bn (bool, optional): Whether convert Batch Norm to
            Synchronized ones when Distributed training is on. Default to
            True.
        conv_cfg (dict | None): Config for conv blocks of this module. If pass
            ``None``, would use ``_default_conv_cfg``. Default to ``None``.
        with_spectral_norm (bool, optional): Whether use spectral norm for
            conv blocks and norm layers. Default to True.
        with_embedding_spectral_norm (bool, optional): Whether use spectral
            norm for embedding layers in normalization blocks or not. If not
            specified (set as ``None``), ``with_embedding_spectral_norm``
            would be set as the same value as ``with_spectral_norm``.
            Default to None.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        norm_eps (float, optional): eps for Normalization layers (both
            conditional and non-conditional ones). Default to `1e-4`.
        sn_eps (float, optional): eps for spectral normalization operation.
            Default to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Default to ``dict(type='BigGAN')``.
    """

    # default config for both 3x3 convs; the activation is applied
    # explicitly in ``forward``, hence ``act_cfg=None`` here
    _default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=None,
                 num_classes=0,
                 use_cbn=True,
                 use_norm_affine=False,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN'),
                 upsample_cfg=dict(type='nearest', scale_factor=2),
                 upsample=True,
                 auto_sync_bn=True,
                 conv_cfg=None,
                 with_spectral_norm=False,
                 with_embedding_spectral_norm=None,
                 sn_style='torch',
                 norm_eps=1e-4,
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN')):
        super().__init__()
        # a learnable 1x1 shortcut is only needed when the residual path
        # changes the channel count or the spatial resolution
        self.learnable_sc = in_channels != out_channels or upsample
        self.with_upsample = upsample
        self.init_type = init_cfg.get('type', None)

        self.activate = build_activation_layer(act_cfg)
        hidden_channels = out_channels if hidden_channels is None \
            else hidden_channels

        if self.with_upsample:
            self.upsample = build_upsample_layer(upsample_cfg)

        self.conv_cfg = deepcopy(self._default_conv_cfg)
        if conv_cfg is not None:
            self.conv_cfg.update(conv_cfg)

        # set `norm_spectral_norm` as `with_spectral_norm` if not defined
        with_embedding_spectral_norm = with_embedding_spectral_norm \
            if with_embedding_spectral_norm is not None else with_spectral_norm

        sn_cfg = dict(eps=sn_eps, sn_style=sn_style)
        self.conv_1 = SNConvModule(
            in_channels,
            hidden_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)
        self.conv_2 = SNConvModule(
            hidden_channels,
            out_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)

        # positional arguments must follow SNConditionNorm.__init__ order:
        # (in_channels, num_classes, use_cbn, norm_cfg, cbn_norm_affine,
        #  auto_sync_bn, with_spectral_norm, sn_style, norm_eps, sn_eps,
        #  init_cfg)
        self.norm_1 = SNConditionNorm(in_channels, num_classes, use_cbn,
                                      norm_cfg, use_norm_affine, auto_sync_bn,
                                      with_embedding_spectral_norm, sn_style,
                                      norm_eps, sn_eps, init_cfg)
        self.norm_2 = SNConditionNorm(hidden_channels, num_classes, use_cbn,
                                      norm_cfg, use_norm_affine, auto_sync_bn,
                                      with_embedding_spectral_norm, sn_style,
                                      norm_eps, sn_eps, init_cfg)

        if self.learnable_sc:
            # use hyperparameters-fixed shortcut here
            self.shortcut = SNConvModule(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                act_cfg=None,
                with_spectral_norm=with_spectral_norm,
                spectral_norm_cfg=sn_cfg)

        self.init_weights()

    def forward(self, x, y=None):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            y (Tensor): Input label with shape (n, ).
                Default None.

        Returns:
            Tensor: Forward results.
        """
        # pre-activation residual path: norm -> act -> (upsample) -> conv
        out = self.norm_1(x, y)
        out = self.activate(out)
        if self.with_upsample:
            out = self.upsample(out)
        out = self.conv_1(out)

        out = self.norm_2(out, y)
        out = self.activate(out)
        out = self.conv_2(out)

        shortcut = self.forward_shortcut(x)
        return out + shortcut

    def forward_shortcut(self, x):
        """Forward the shortcut branch: identity, or upsample + 1x1 conv
        when the block changes resolution or channel count."""
        out = x
        if self.learnable_sc:
            if self.with_upsample:
                out = self.upsample(out)
            out = self.shortcut(out)
        return out

    def init_weights(self):
        """Initialize weights for the model.

        The branch is selected by ``init_cfg['type']``; an unknown type
        raises ``NotImplementedError``.
        """
        if self.init_type.upper() == 'STUDIO':
            # Pytorch-StudioGAN style: orthogonal weight, zero bias
            nn.init.orthogonal_(self.conv_1.conv.weight)
            nn.init.orthogonal_(self.conv_2.conv.weight)
            self.conv_1.conv.bias.data.fill_(0.)
            self.conv_2.conv.bias.data.fill_(0.)
            if self.learnable_sc:
                nn.init.orthogonal_(self.shortcut.conv.weight)
                self.shortcut.conv.bias.data.fill_(0.)
        elif self.init_type.upper() == 'BIGGAN':
            # BigGAN-pytorch style: xavier uniform, gain=1
            xavier_uniform_(self.conv_1.conv.weight, gain=1)
            xavier_uniform_(self.conv_2.conv.weight, gain=1)
            if self.learnable_sc:
                xavier_uniform_(self.shortcut.conv.weight, gain=1)
        elif self.init_type.upper() == 'SAGAN':
            # official SAGAN tensorflow style: xavier, gain=1
            xavier_init(self.conv_1, gain=1, distribution='uniform')
            xavier_init(self.conv_2, gain=1, distribution='uniform')
            if self.learnable_sc:
                xavier_init(self.shortcut, gain=1, distribution='uniform')
        elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
            # official chainer style: gain=sqrt(2) for main convs,
            # gain=1 for the shortcut
            xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
            xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
            if self.learnable_sc:
                xavier_init(self.shortcut, gain=1, distribution='uniform')
        else:
            raise NotImplementedError('Unknown initialization method: '
                                      f'\'{self.init_type}\'')
@MODULES.register_module()
class SNGANDiscResBlock(nn.Module):
    """ResBlock used in the discriminator of SNGAN / Proj-GAN.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        hidden_channels (int, optional): Input channels of the second conv
            layer of the block. If ``None`` is given, would be set as
            ``out_channels``. Defaults to None.
        downsample (bool, optional): Whether apply downsample operation in
            this module. Defaults to False.
        act_cfg (dict, optional): Config for activate function. Defaults
            to ``dict(type='ReLU')``.
        conv_cfg (dict | None): Config for conv blocks of this module. If pass
            ``None``, would use ``_default_conv_cfg``. Defaults to ``None``.
        with_spectral_norm (bool, optional): Whether use spectral norm for
            conv blocks and norm layers. Defaults to True.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        sn_eps (float, optional): eps for spectral normalization operation.
            Defaults to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
    """

    # default config for both 3x3 convs; the activation is applied
    # explicitly in ``forward``, hence ``act_cfg=None`` here
    _default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=None,
                 downsample=False,
                 act_cfg=dict(type='ReLU'),
                 conv_cfg=None,
                 with_spectral_norm=True,
                 sn_style='torch',
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN')):
        super().__init__()
        hidden_channels = out_channels if hidden_channels is None \
            else hidden_channels
        self.with_downsample = downsample
        self.init_type = init_cfg.get('type', None)

        self.conv_cfg = deepcopy(self._default_conv_cfg)
        if conv_cfg is not None:
            self.conv_cfg.update(conv_cfg)

        self.activate = build_activation_layer(act_cfg)
        sn_cfg = dict(eps=sn_eps, sn_style=sn_style)
        self.conv_1 = SNConvModule(
            in_channels,
            hidden_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)
        self.conv_2 = SNConvModule(
            hidden_channels,
            out_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)

        if self.with_downsample:
            # 2x2 average pooling halves the spatial resolution
            self.downsample = nn.AvgPool2d(2, 2)

        # a learnable 1x1 shortcut is only needed when the residual path
        # changes the channel count or the spatial resolution
        self.learnable_sc = in_channels != out_channels or downsample
        if self.learnable_sc:
            # use hyperparameters-fixed shortcut here
            self.shortcut = SNConvModule(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                act_cfg=None,
                with_spectral_norm=with_spectral_norm,
                spectral_norm_cfg=sn_cfg)

        self.init_weights()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # residual path: act -> conv -> act -> conv -> (downsample)
        out = self.activate(x)
        out = self.conv_1(out)
        out = self.activate(out)
        out = self.conv_2(out)
        if self.with_downsample:
            out = self.downsample(out)
        shortcut = self.forward_shortcut(x)
        return out + shortcut

    def forward_shortcut(self, x):
        """Forward the shortcut branch: identity, or 1x1 conv followed by
        downsample when the block changes channel count or resolution."""
        out = x
        if self.learnable_sc:
            out = self.shortcut(out)
            if self.with_downsample:
                out = self.downsample(out)
        return out

    def init_weights(self):
        """Initialize weights for the model.

        The branch is selected by ``init_cfg['type']``; an unknown type
        raises ``NotImplementedError``.
        """
        if self.init_type.upper() == 'STUDIO':
            # Pytorch-StudioGAN style: orthogonal weight, zero bias
            nn.init.orthogonal_(self.conv_1.conv.weight)
            nn.init.orthogonal_(self.conv_2.conv.weight)
            self.conv_1.conv.bias.data.fill_(0.)
            self.conv_2.conv.bias.data.fill_(0.)
            if self.learnable_sc:
                nn.init.orthogonal_(self.shortcut.conv.weight)
                self.shortcut.conv.bias.data.fill_(0.)
        elif self.init_type.upper() == 'BIGGAN':
            # BigGAN-pytorch style: xavier uniform, gain=1
            xavier_uniform_(self.conv_1.conv.weight, gain=1)
            xavier_uniform_(self.conv_2.conv.weight, gain=1)
            if self.learnable_sc:
                xavier_uniform_(self.shortcut.conv.weight, gain=1)
        elif self.init_type.upper() == 'SAGAN':
            # official SAGAN tensorflow style: xavier, gain=1
            xavier_init(self.conv_1, gain=1, distribution='uniform')
            xavier_init(self.conv_2, gain=1, distribution='uniform')
            if self.learnable_sc:
                xavier_init(self.shortcut, gain=1, distribution='uniform')
        elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
            # official chainer style: gain=sqrt(2) for main convs,
            # gain=1 for the shortcut
            xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
            xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
            if self.learnable_sc:
                xavier_init(self.shortcut, gain=1, distribution='uniform')
        else:
            raise NotImplementedError('Unknown initialization method: '
                                      f'\'{self.init_type}\'')
@MODULES.register_module()
class SNGANDiscHeadResBlock(nn.Module):
    """The first ResBlock used in the discriminator of SNGAN / Proj-GAN.

    Compared to ``SNGANDiscResBlock``, this module applies the two convs
    directly to the raw image (no leading activation), always downsamples,
    and always uses a learnable 1x1 shortcut.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        conv_cfg (dict | None): Config for conv blocks of this module. If
            ``None`` is passed, ``_default_conv_cfg`` is used.
            Defaults to ``None``.
        act_cfg (dict, optional): Config for the activation function.
            Defaults to ``dict(type='ReLU')``.
        with_spectral_norm (bool, optional): Whether to use spectral norm for
            the conv blocks. Defaults to True.
        sn_eps (float, optional): eps for the spectral normalization
            operation. Defaults to `1e-12`.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
    """

    # both 3x3 convs share this config; the activation is applied
    # explicitly in ``forward``, hence ``act_cfg=None`` here
    _default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=True,
                 sn_eps=1e-12,
                 sn_style='torch',
                 init_cfg=dict(type='BigGAN')):
        super().__init__()
        self.init_type = init_cfg.get('type', None)

        # overlay the user-provided conv config on top of the defaults
        merged_conv_cfg = deepcopy(self._default_conv_cfg)
        if conv_cfg is not None:
            merged_conv_cfg.update(conv_cfg)
        self.conv_cfg = merged_conv_cfg

        self.activate = build_activation_layer(act_cfg)
        sn_cfg = dict(eps=sn_eps, sn_style=sn_style)

        self.conv_1 = SNConvModule(
            in_channels,
            out_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)
        self.conv_2 = SNConvModule(
            out_channels,
            out_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)

        # 2x2 average pooling halves the spatial resolution
        self.downsample = nn.AvgPool2d(2, 2)

        # use hyperparameters-fixed shortcut here (1x1 conv, not affected
        # by `conv_cfg`)
        self.shortcut = SNConvModule(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            act_cfg=None,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg)

        self.init_weights()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # residual path: conv -> act -> conv -> downsample
        feat = self.conv_1(x)
        feat = self.activate(feat)
        feat = self.conv_2(feat)
        feat = self.downsample(feat)
        return feat + self.forward_shortcut(x)

    def forward_shortcut(self, x):
        """Forward the shortcut branch: downsample first, then project the
        channels with the 1x1 conv."""
        return self.shortcut(self.downsample(x))

    def init_weights(self):
        """Initialize weights for the model.

        The branch is selected by ``init_cfg['type']``; an unknown type
        raises ``NotImplementedError``.
        """
        init_type = self.init_type.upper()
        convs = (self.conv_1, self.conv_2, self.shortcut)
        if init_type == 'STUDIO':
            # Pytorch-StudioGAN style: orthogonal weight, zero bias
            for module in convs:
                nn.init.orthogonal_(module.conv.weight)
                module.conv.bias.data.fill_(0.)
        elif init_type == 'BIGGAN':
            # BigGAN-pytorch style: xavier uniform, gain=1
            for module in convs:
                xavier_uniform_(module.conv.weight, gain=1)
        elif init_type == 'SAGAN':
            # official SAGAN tensorflow style: xavier, gain=1
            for module in convs:
                xavier_init(module, gain=1, distribution='uniform')
        elif init_type in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
            # official chainer style: gain=sqrt(2) for the main convs,
            # gain=1 for the shortcut
            for module, gain in zip(convs, (np.sqrt(2), np.sqrt(2), 1)):
                xavier_init(module, gain=gain, distribution='uniform')
        else:
            raise NotImplementedError('Unknown initialization method: '
                                      f'\'{self.init_type}\'')
@MODULES.register_module()
class SNConditionNorm(nn.Module):
    """Conditional Normalization for SNGAN / Proj-GAN.

    The implementation refers to
    https://github.com/pfnet-research/sngan_projection/blob/master/source/links/conditional_batch_normalization.py  # noqa

    and

    https://github.com/POSTECH-CVLab/PyTorch-StudioGAN/blob/master/src/utils/model_ops.py  # noqa

    Args:
        in_channels (int): Number of the channels of the input feature map.
        num_classes (int): Number of the classes in the dataset. If ``use_cbn``
            is True, ``num_classes`` must be larger than 0.
        use_cbn (bool, optional): Whether to use conditional normalization. If
            ``use_cbn`` is True, two embedding layers map the class label to
            the weight and bias used in the normalization process.
        norm_cfg (dict, optional): Config for normalization method. Defaults
            to ``dict(type='BN')``.
        cbn_norm_affine (bool): Whether to set ``affine=True`` when using
            conditional batch norm. This argument only works when ``use_cbn``
            is True. Defaults to False.
        auto_sync_bn (bool, optional): Whether to convert Batch Norm to
            Synchronized ones when distributed training is on. Defaults to
            True.
        with_spectral_norm (bool, optional): Whether to use spectral norm for
            the embedding layers. Defaults to False.
        norm_eps (float, optional): eps for Normalization layers (both
            conditional and non-conditional ones). Defaults to `1e-4`.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        sn_eps (float, optional): eps for spectral normalization operation.
            Defaults to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 use_cbn=True,
                 norm_cfg=dict(type='BN'),
                 cbn_norm_affine=False,
                 auto_sync_bn=True,
                 with_spectral_norm=False,
                 sn_style='torch',
                 norm_eps=1e-4,
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN')):
        super().__init__()
        self.use_cbn = use_cbn
        # The init style also decides whether embedding outputs are treated
        # as residual gains in ``forward`` (see ``reweight_embedding``).
        self.init_type = init_cfg.get('type', None)

        # Work on a copy so the caller's config dict is never mutated.
        norm_cfg = deepcopy(norm_cfg)
        norm_type = norm_cfg['type']

        if norm_type not in ['IN', 'BN', 'SyncBN']:
            raise ValueError('Only support `IN` (InstanceNorm), '
                             '`BN` (BatcnNorm) and `SyncBN` for '
                             'Class-conditional bn. '
                             f'Receive norm_type: {norm_type}')

        if self.use_cbn:
            # With conditional normalization the affine transform comes from
            # the label embeddings, so the norm layer itself defaults to
            # non-affine unless ``cbn_norm_affine`` overrides it.
            norm_cfg.setdefault('affine', cbn_norm_affine)
        norm_cfg.setdefault('eps', norm_eps)
        if check_dist_init() and auto_sync_bn and norm_type == 'BN':
            # Automatically switch to SyncBN under distributed training.
            norm_cfg['type'] = 'SyncBN'

        _, self.norm = build_norm_layer(norm_cfg, in_channels)

        if self.use_cbn:
            if num_classes <= 0:
                raise ValueError('`num_classes` must be larger '
                                 'than 0 with `use_cbn=True`')
            # BigGAN / StudioGAN interpret the weight embedding output as a
            # residual gain (applied as ``weight + 1`` in ``forward``).
            self.reweight_embedding = (self.init_type.upper() == 'BIGGAN'
                                       or self.init_type.upper() == 'STUDIO')
            if with_spectral_norm:
                if sn_style == 'torch':
                    self.weight_embedding = spectral_norm(
                        nn.Embedding(num_classes, in_channels), eps=sn_eps)
                    self.bias_embedding = spectral_norm(
                        nn.Embedding(num_classes, in_channels), eps=sn_eps)
                elif sn_style == 'ajbrock':
                    self.weight_embedding = SNEmbedding(
                        num_classes, in_channels, eps=sn_eps)
                    self.bias_embedding = SNEmbedding(
                        num_classes, in_channels, eps=sn_eps)
                else:
                    raise NotImplementedError(
                        f'{sn_style} style spectral Norm is not '
                        'supported yet')
            else:
                self.weight_embedding = nn.Embedding(num_classes, in_channels)
                self.bias_embedding = nn.Embedding(num_classes, in_channels)

        self.init_weights()

    def forward(self, x, y=None):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            y (Tensor, optional): Input label with shape (n, ).
                Default None.

        Returns:
            Tensor: Forward results.
        """
        out = self.norm(x)
        if self.use_cbn:
            # Per-sample, per-channel scale and shift from the class label;
            # broadcast over the spatial dimensions.
            weight = self.weight_embedding(y)[:, :, None, None]
            bias = self.bias_embedding(y)[:, :, None, None]
            if self.reweight_embedding:
                # BigGAN/StudioGAN: embedding predicts a residual around 1.
                weight = weight + 1.
            out = out * weight + bias
        return out

    def init_weights(self):
        """Initialize the embedding weights according to ``self.init_type``.

        Raises:
            NotImplementedError: If ``self.init_type`` is not supported.
        """
        if self.use_cbn:
            if self.init_type.upper() == 'STUDIO':
                nn.init.orthogonal_(self.weight_embedding.weight)
                nn.init.orthogonal_(self.bias_embedding.weight)
            elif self.init_type.upper() == 'BIGGAN':
                xavier_uniform_(self.weight_embedding.weight, gain=1)
                xavier_uniform_(self.bias_embedding.weight, gain=1)
            elif self.init_type.upper() in [
                    'SNGAN', 'SNGAN-PROJ', 'GAN-PROJ', 'SAGAN'
            ]:
                # Identity init: weight=1, bias=0, i.e. normalization starts
                # as a plain (unconditioned) transform.
                constant_init(self.weight_embedding, 1)
                constant_init(self.bias_embedding, 0)
            else:
                raise NotImplementedError('Unknown initialization method: '
                                          f'\'{self.init_type}\'')
mmgen/models/architectures/stylegan/__init__.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
from
.generator_discriminator_v1
import
(
StyleGAN1Discriminator
,
StyleGANv1Generator
)
from
.generator_discriminator_v2
import
(
StyleGAN2Discriminator
,
StyleGANv2Generator
)
from
.generator_discriminator_v3
import
StyleGANv3Generator
from
.mspie
import
MSStyleGAN2Discriminator
,
MSStyleGANv2Generator
# Public API of the stylegan architecture package.
__all__ = [
    'StyleGAN2Discriminator', 'StyleGANv2Generator', 'StyleGANv1Generator',
    'StyleGAN1Discriminator', 'MSStyleGAN2Discriminator',
    'MSStyleGANv2Generator', 'StyleGANv3Generator'
]
mmgen/models/architectures/stylegan/ada/augment.py
0 → 100644
View file @
b7536f78
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import
numpy
as
np
import
scipy.signal
import
torch
from
mmgen.ops
import
conv2d_gradfix
from
.
import
grid_sample_gradfix
,
misc
,
upfirdn2d
# ----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.
# Keys are the standard wavelet family names (Daubechies `dbN`,
# Symlets `symN`, Haar); values are the 1-D low-pass filter taps.
# NOTE(review): 'haar' and 'db1' are the same filter, and 'db2'/'sym2' and
# 'db3'/'sym3' coincide — presumably intentional aliases; verify against the
# original StyleGAN2-ADA reference if these tables are ever edited.
wavelets = {
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1': [0.7071067811865476, 0.7071067811865476],
    'db2': [
        -0.12940952255092145, 0.22414386804185735, 0.836516303737469,
        0.48296291314469025
    ],
    'db3': [
        0.035226291882100656, -0.08544127388224149, -0.13501102001039084,
        0.4598775021193313, 0.8068915093133388, 0.3326705529509569
    ],
    'db4': [
        -0.010597401784997278, 0.032883011666982945, 0.030841381835986965,
        -0.18703481171888114, -0.02798376941698385, 0.6308807679295904,
        0.7148465705525415, 0.23037781330885523
    ],
    'db5': [
        0.003335725285001549, -0.012580751999015526, -0.006241490213011705,
        0.07757149384006515, -0.03224486958502952, -0.24229488706619015,
        0.13842814590110342, 0.7243085284385744, 0.6038292697974729,
        0.160102397974125
    ],
    'db6': [
        -0.00107730108499558, 0.004777257511010651, 0.0005538422009938016,
        -0.031582039318031156, 0.02752286553001629, 0.09750160558707936,
        -0.12976686756709563, -0.22626469396516913, 0.3152503517092432,
        0.7511339080215775, 0.4946238903983854, 0.11154074335008017
    ],
    'db7': [
        0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274,
        0.012550998556013784, -0.01657454163101562, -0.03802993693503463,
        0.0806126091510659, 0.07130921926705004, -0.22403618499416572,
        -0.14390600392910627, 0.4697822874053586, 0.7291320908465551,
        0.39653931948230575, 0.07785205408506236
    ],
    'db8': [
        -0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771,
        -0.00487035299301066, 0.008746094047015655, 0.013981027917015516,
        -0.04408825393106472, -0.01736930100202211, 0.128747426620186,
        0.00047248457399797254, -0.2840155429624281, -0.015829105256023893,
        0.5853546836548691, 0.6756307362980128, 0.3128715909144659,
        0.05441584224308161
    ],
    'sym2': [
        -0.12940952255092145, 0.22414386804185735, 0.836516303737469,
        0.48296291314469025
    ],
    'sym3': [
        0.035226291882100656, -0.08544127388224149, -0.13501102001039084,
        0.4598775021193313, 0.8068915093133388, 0.3326705529509569
    ],
    'sym4': [
        -0.07576571478927333, -0.02963552764599851, 0.49761866763201545,
        0.8037387518059161, 0.29785779560527736, -0.09921954357684722,
        -0.012603967262037833, 0.0322231006040427
    ],
    'sym5': [
        0.027333068345077982, 0.029519490925774643, -0.039134249302383094,
        0.1993975339773936, 0.7234076904024206, 0.6339789634582119,
        0.01660210576452232, -0.17532808990845047, -0.021101834024758855,
        0.019538882735286728
    ],
    'sym6': [
        0.015404109327027373, 0.0034907120842174702, -0.11799011114819057,
        -0.048311742585633, 0.4910559419267466, 0.787641141030194,
        0.3379294217276218, -0.07263752278646252, -0.021060292512300564,
        0.04472490177066578, 0.0017677118642428036, -0.007800708325034148
    ],
    'sym7': [
        0.002681814568257878, -0.0010473848886829163, -0.01263630340325193,
        0.03051551316596357, 0.0678926935013727, -0.049552834937127255,
        0.017441255086855827, 0.5361019170917628, 0.767764317003164,
        0.2886296317515146, -0.14004724044296152, -0.10780823770381774,
        0.004010244871533663, 0.010268176708511255
    ],
    'sym8': [
        -0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298,
        0.007607487324917605, -0.1432942383508097, -0.061273359067658524,
        0.4813596512583722, 0.7771857517005235, 0.3644418948353314,
        -0.05194583810770904, -0.027219029917056003, 0.049137179673607506,
        0.003808752013890615, -0.01495225833704823, -0.0003029205147213668,
        0.0018899503327594609
    ],
}
# ----------------------------------------------------------------------------
# Helpers for constructing transformation matrices.
def matrix(*rows, device=None):
    """Construct a transformation matrix from row specifications.

    If no entry is a ``torch.Tensor`` the result is a constant tensor built
    from the numeric rows; otherwise every entry is broadcast to the shape of
    the first tensor entry and the result is a batched matrix.

    Args:
        device (str|torch.device, optional): Matrix device. Defaults to None.

    Returns:
        ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
            format.
    """
    row_len = len(rows[0])
    assert all(len(row) == row_len for row in rows)
    flat = [entry for row in rows for entry in row]
    tensors = [entry for entry in flat if isinstance(entry, torch.Tensor)]
    if not tensors:
        # Purely numeric input: build a (cached) constant tensor.
        return misc.constant(np.asarray(rows), device=device)
    template = tensors[0]
    assert device is None or device == template.device
    # Use `x.float()` rather than dtype kwargs to keep pt1.5 support.
    broadcast = []
    for entry in flat:
        if isinstance(entry, torch.Tensor):
            broadcast.append(entry.float())
        else:
            broadcast.append(
                misc.constant(
                    entry, shape=template.shape, device=template.device))
    stacked = torch.stack(broadcast, dim=-1)
    return stacked.reshape(template.shape + (len(rows), -1))
def translate2d(tx, ty, **kwargs):
    """Construct a 2d homogeneous translation matrix.

    Args:
        tx (float): X-direction translation amount.
        ty (float): Y-direction translation amount.

    Returns:
        ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
            format.
    """
    rows = ([1, 0, tx], [0, 1, ty], [0, 0, 1])
    return matrix(*rows, **kwargs)
def translate3d(tx, ty, tz, **kwargs):
    """Construct a 3d homogeneous translation matrix.

    Args:
        tx (float): X-direction translation amount.
        ty (float): Y-direction translation amount.
        tz (float): Z-direction translation amount.

    Returns:
        ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
            format.
    """
    rows = ([1, 0, 0, tx], [0, 1, 0, ty], [0, 0, 1, tz], [0, 0, 0, 1])
    return matrix(*rows, **kwargs)
def scale2d(sx, sy, **kwargs):
    """Construct a 2d homogeneous scaling matrix.

    Args:
        sx (float): X-direction scaling coefficient.
        sy (float): Y-direction scaling coefficient.

    Returns:
        ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
            format.
    """
    rows = ([sx, 0, 0], [0, sy, 0], [0, 0, 1])
    return matrix(*rows, **kwargs)
def scale3d(sx, sy, sz, **kwargs):
    """Construct a 3d homogeneous scaling matrix.

    Args:
        sx (float): X-direction scaling coefficient.
        sy (float): Y-direction scaling coefficient.
        sz (float): Z-direction scaling coefficient.

    Returns:
        ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
            format.
    """
    rows = ([sx, 0, 0, 0], [0, sy, 0, 0], [0, 0, sz, 0], [0, 0, 0, 1])
    return matrix(*rows, **kwargs)
def rotate2d(theta, **kwargs):
    """Construct a 2d homogeneous rotation matrix.

    Args:
        theta (float): Counter-clock wise rotation angle.

    Returns:
        ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
            format.
    """
    cos_t = torch.cos(theta)
    # Keep `sin(-theta)` (rather than `-sin(theta)`) to match the original
    # evaluation exactly.
    return matrix([cos_t, torch.sin(-theta), 0],
                  [torch.sin(theta), cos_t, 0],
                  [0, 0, 1], **kwargs)
def rotate3d(v, theta, **kwargs):
    """Construct a 3d homogeneous rotation matrix around an axis.

    Args:
        v (torch.Tensor): Luma axis.
        theta (float): Rotate theta counter-clock wise with ``v`` as the axis.

    Returns:
        ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
            format.
    """
    ax = v[..., 0]
    ay = v[..., 1]
    az = v[..., 2]
    sin_t = torch.sin(theta)
    cos_t = torch.cos(theta)
    one_minus_c = 1 - cos_t
    # Axis-angle rotation matrix, embedded homogeneously in 4x4 form.
    return matrix(
        [
            ax * ax * one_minus_c + cos_t, ax * ay * one_minus_c - az * sin_t,
            ax * az * one_minus_c + ay * sin_t, 0
        ],
        [
            ay * ax * one_minus_c + az * sin_t, ay * ay * one_minus_c + cos_t,
            ay * az * one_minus_c - ax * sin_t, 0
        ],
        [
            az * ax * one_minus_c - ay * sin_t,
            az * ay * one_minus_c + ax * sin_t, az * az * one_minus_c + cos_t,
            0
        ],
        [0, 0, 0, 1], **kwargs)
def translate2d_inv(tx, ty, **kwargs):
    """Construct the inverse of a 2d translation matrix.

    Args:
        tx (float): X-direction translation amount.
        ty (float): Y-direction translation amount.

    Returns:
        ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
            format.
    """
    # The inverse of a translation is translation by the negated offsets.
    return translate2d(-tx, -ty, **kwargs)
def scale2d_inv(sx, sy, **kwargs):
    """Construct the inverse of a 2d scaling matrix.

    Args:
        sx (float): X-direction scaling coefficient.
        sy (float): Y-direction scaling coefficient.

    Returns:
        ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
            format.
    """
    # The inverse of a scaling is scaling by the reciprocals.
    return scale2d(1 / sx, 1 / sy, **kwargs)
def rotate2d_inv(theta, **kwargs):
    """Construct the inverse of a 2d rotation matrix.

    Args:
        theta (float): Counter-clock wise rotation angle.

    Returns:
        ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor
            format.
    """
    # The inverse of a rotation is rotation by the negated angle.
    return rotate2d(-theta, **kwargs)
# ----------------------------------------------------------------------------
# Versatile image augmentation pipeline from the paper
# "Training Generative Adversarial Networks with Limited Data".
#
# All augmentations are disabled by default; individual augmentations can
# be enabled by setting their probability multipliers to 1.
class
AugmentPipe
(
torch
.
nn
.
Module
):
"""Augmentation pipeline include multiple geometric and color
transformations.
Note: The meaning of arguments are written in the comments of
``__init__`` function.
"""
def
__init__
(
self
,
xflip
=
0
,
rotate90
=
0
,
xint
=
0
,
xint_max
=
0.125
,
scale
=
0
,
rotate
=
0
,
aniso
=
0
,
xfrac
=
0
,
scale_std
=
0.2
,
rotate_max
=
1
,
aniso_std
=
0.2
,
xfrac_std
=
0.125
,
brightness
=
0
,
contrast
=
0
,
lumaflip
=
0
,
hue
=
0
,
saturation
=
0
,
brightness_std
=
0.2
,
contrast_std
=
0.5
,
hue_max
=
1
,
saturation_std
=
1
,
imgfilter
=
0
,
imgfilter_bands
=
[
1
,
1
,
1
,
1
],
imgfilter_std
=
1
,
noise
=
0
,
cutout
=
0
,
noise_std
=
0.1
,
cutout_size
=
0.5
,
):
super
().
__init__
()
self
.
register_buffer
(
'p'
,
torch
.
ones
(
[]))
# Overall multiplier for augmentation probability.
# Pixel blitting.
self
.
xflip
=
float
(
xflip
)
# Probability multiplier for x-flip.
self
.
rotate90
=
float
(
rotate90
)
# Probability multiplier for 90 degree rotations.
self
.
xint
=
float
(
xint
)
# Probability multiplier for integer translation.
self
.
xint_max
=
float
(
xint_max
)
# Range of integer translation, relative to image dimensions.
# General geometric transformations.
self
.
scale
=
float
(
scale
)
# Probability multiplier for isotropic scaling.
self
.
rotate
=
float
(
rotate
)
# Probability multiplier for arbitrary rotation.
self
.
aniso
=
float
(
aniso
)
# Probability multiplier for anisotropic scaling.
self
.
xfrac
=
float
(
xfrac
)
# Probability multiplier for fractional translation.
self
.
scale_std
=
float
(
scale_std
)
# Log2 standard deviation of isotropic scaling.
self
.
rotate_max
=
float
(
rotate_max
)
# Range of arbitrary rotation, 1 = full circle.
self
.
aniso_std
=
float
(
aniso_std
)
# Log2 standard deviation of anisotropic scaling.
self
.
xfrac_std
=
float
(
xfrac_std
)
# Standard deviation of frational translation, relative to img dims.
# Color transformations.
self
.
brightness
=
float
(
brightness
)
# Probability multiplier for brightness.
self
.
contrast
=
float
(
contrast
)
# Probability multiplier for contrast.
self
.
lumaflip
=
float
(
lumaflip
)
# Probability multiplier for luma flip.
self
.
hue
=
float
(
hue
)
# Probability multiplier for hue rotation.
self
.
saturation
=
float
(
saturation
)
# Probability multiplier for saturation.
self
.
brightness_std
=
float
(
brightness_std
)
# Standard deviation of brightness.
self
.
contrast_std
=
float
(
contrast_std
)
# Log2 standard deviation of contrast.
self
.
hue_max
=
float
(
hue_max
)
# Range of hue rotation, 1 = full circle.
self
.
saturation_std
=
float
(
saturation_std
)
# Log2 standard deviation of saturation.
# Image-space filtering.
self
.
imgfilter
=
float
(
imgfilter
)
# Probability multiplier for image-space filtering.
self
.
imgfilter_bands
=
list
(
imgfilter_bands
)
# Probability multipliers for individual frequency bands.
self
.
imgfilter_std
=
float
(
imgfilter_std
)
# Log2 standard deviation of image-space filter amplification.
# Image-space corruptions.
self
.
noise
=
float
(
noise
)
# Probability multiplier for additive RGB noise.
self
.
cutout
=
float
(
cutout
)
# Probability multiplier for cutout.
self
.
noise_std
=
float
(
noise_std
)
# Standard deviation of additive RGB noise.
self
.
cutout_size
=
float
(
cutout_size
)
# Size of the cutout rectangle, relative to image dimensions.
# Setup orthogonal lowpass filter for geometric augmentations.
self
.
register_buffer
(
'Hz_geom'
,
upfirdn2d
.
setup_filter
(
wavelets
[
'sym6'
]))
# Construct filter bank for image-space filtering.
Hz_lo
=
np
.
asarray
(
wavelets
[
'sym2'
])
# H(z)
Hz_hi
=
Hz_lo
*
((
-
1
)
**
np
.
arange
(
Hz_lo
.
size
))
# H(-z)
Hz_lo2
=
np
.
convolve
(
Hz_lo
,
Hz_lo
[::
-
1
])
/
2
# H(z) * H(z^-1) / 2
Hz_hi2
=
np
.
convolve
(
Hz_hi
,
Hz_hi
[::
-
1
])
/
2
# H(-z) * H(-z^-1) / 2
Hz_fbank
=
np
.
eye
(
4
,
1
)
# Bandpass(H(z), b_i)
for
i
in
range
(
1
,
Hz_fbank
.
shape
[
0
]):
Hz_fbank
=
np
.
dstack
([
Hz_fbank
,
np
.
zeros_like
(
Hz_fbank
)
]).
reshape
(
Hz_fbank
.
shape
[
0
],
-
1
)[:,
:
-
1
]
Hz_fbank
=
scipy
.
signal
.
convolve
(
Hz_fbank
,
[
Hz_lo2
])
Hz_fbank
[
i
,
(
Hz_fbank
.
shape
[
1
]
-
Hz_hi2
.
size
)
//
2
:(
Hz_fbank
.
shape
[
1
]
+
Hz_hi2
.
size
)
//
2
]
+=
Hz_hi2
self
.
register_buffer
(
'Hz_fbank'
,
torch
.
as_tensor
(
Hz_fbank
,
dtype
=
torch
.
float32
))
def
forward
(
self
,
images
,
debug_percentile
=
None
):
assert
isinstance
(
images
,
torch
.
Tensor
)
and
images
.
ndim
==
4
batch_size
,
num_channels
,
height
,
width
=
images
.
shape
device
=
images
.
device
if
debug_percentile
is
not
None
:
debug_percentile
=
torch
.
as_tensor
(
debug_percentile
,
dtype
=
torch
.
float32
,
device
=
device
)
# -------------------------------------
# Select parameters for pixel blitting.
# -------------------------------------
# Initialize inverse homogeneous 2D transform:
# G_inv @ pixel_out ==> pixel_in
I_3
=
torch
.
eye
(
3
,
device
=
device
)
G_inv
=
I_3
# Apply x-flip with probability (xflip * strength).
if
self
.
xflip
>
0
:
i
=
torch
.
floor
(
torch
.
rand
([
batch_size
],
device
=
device
)
*
2
)
i
=
torch
.
where
(
torch
.
rand
([
batch_size
],
device
=
device
)
<
self
.
xflip
*
self
.
p
,
i
,
torch
.
zeros_like
(
i
))
if
debug_percentile
is
not
None
:
i
=
torch
.
full_like
(
i
,
torch
.
floor
(
debug_percentile
*
2
))
G_inv
=
G_inv
@
scale2d_inv
(
1
-
2
*
i
,
1
)
# Apply 90 degree rotations with probability (rotate90 * strength).
if
self
.
rotate90
>
0
:
i
=
torch
.
floor
(
torch
.
rand
([
batch_size
],
device
=
device
)
*
4
)
i
=
torch
.
where
(
torch
.
rand
([
batch_size
],
device
=
device
)
<
self
.
rotate90
*
self
.
p
,
i
,
torch
.
zeros_like
(
i
))
if
debug_percentile
is
not
None
:
i
=
torch
.
full_like
(
i
,
torch
.
floor
(
debug_percentile
*
4
))
G_inv
=
G_inv
@
rotate2d_inv
(
-
np
.
pi
/
2
*
i
)
# Apply integer translation with probability (xint * strength).
if
self
.
xint
>
0
:
t
=
(
torch
.
rand
([
batch_size
,
2
],
device
=
device
)
*
2
-
1
)
*
self
.
xint_max
t
=
torch
.
where
(
torch
.
rand
([
batch_size
,
1
],
device
=
device
)
<
self
.
xint
*
self
.
p
,
t
,
torch
.
zeros_like
(
t
))
if
debug_percentile
is
not
None
:
t
=
torch
.
full_like
(
t
,
(
debug_percentile
*
2
-
1
)
*
self
.
xint_max
)
G_inv
=
G_inv
@
translate2d_inv
(
torch
.
round
(
t
[:,
0
]
*
width
),
torch
.
round
(
t
[:,
1
]
*
height
))
# --------------------------------------------------------
# Select parameters for general geometric transformations.
# --------------------------------------------------------
# support for pt1.5 (pt1.5 does not contain exp2)
_scalor_log2
=
torch
.
log
(
torch
.
tensor
(
2.
,
device
=
images
.
device
,
dtype
=
images
.
dtype
))
# Apply isotropic scaling with probability (scale * strength).
if
self
.
scale
>
0
:
s
=
torch
.
exp
(
torch
.
randn
([
batch_size
],
device
=
device
)
*
self
.
scale_std
*
_scalor_log2
)
s
=
torch
.
where
(
torch
.
rand
([
batch_size
],
device
=
device
)
<
self
.
scale
*
self
.
p
,
s
,
torch
.
ones_like
(
s
))
if
debug_percentile
is
not
None
:
s
=
torch
.
full_like
(
s
,
torch
.
exp2
(
torch
.
erfinv
(
debug_percentile
*
2
-
1
)
*
self
.
scale_std
))
G_inv
=
G_inv
@
scale2d_inv
(
s
,
s
)
# Apply pre-rotation with probability p_rot.
p_rot
=
1
-
torch
.
sqrt
(
(
1
-
self
.
rotate
*
self
.
p
).
clamp
(
0
,
1
))
# P(pre OR post) = p
if
self
.
rotate
>
0
:
theta
=
(
torch
.
rand
([
batch_size
],
device
=
device
)
*
2
-
1
)
*
np
.
pi
*
self
.
rotate_max
theta
=
torch
.
where
(
torch
.
rand
([
batch_size
],
device
=
device
)
<
p_rot
,
theta
,
torch
.
zeros_like
(
theta
))
if
debug_percentile
is
not
None
:
theta
=
torch
.
full_like
(
theta
,
(
debug_percentile
*
2
-
1
)
*
np
.
pi
*
self
.
rotate_max
)
G_inv
=
G_inv
@
rotate2d_inv
(
-
theta
)
# Before anisotropic scaling.
# Apply anisotropic scaling with probability (aniso * strength).
if
self
.
aniso
>
0
:
s
=
torch
.
exp
(
torch
.
randn
([
batch_size
],
device
=
device
)
*
self
.
aniso_std
*
_scalor_log2
)
s
=
torch
.
where
(
torch
.
rand
([
batch_size
],
device
=
device
)
<
self
.
aniso
*
self
.
p
,
s
,
torch
.
ones_like
(
s
))
if
debug_percentile
is
not
None
:
s
=
torch
.
full_like
(
s
,
torch
.
exp2
(
torch
.
erfinv
(
debug_percentile
*
2
-
1
)
*
self
.
aniso_std
))
G_inv
=
G_inv
@
scale2d_inv
(
s
,
1
/
s
)
# Apply post-rotation with probability p_rot.
if
self
.
rotate
>
0
:
theta
=
(
torch
.
rand
([
batch_size
],
device
=
device
)
*
2
-
1
)
*
np
.
pi
*
self
.
rotate_max
theta
=
torch
.
where
(
torch
.
rand
([
batch_size
],
device
=
device
)
<
p_rot
,
theta
,
torch
.
zeros_like
(
theta
))
if
debug_percentile
is
not
None
:
theta
=
torch
.
zeros_like
(
theta
)
G_inv
=
G_inv
@
rotate2d_inv
(
-
theta
)
# After anisotropic scaling.
# Apply fractional translation with probability (xfrac * strength).
if
self
.
xfrac
>
0
:
t
=
torch
.
randn
([
batch_size
,
2
],
device
=
device
)
*
self
.
xfrac_std
t
=
torch
.
where
(
torch
.
rand
([
batch_size
,
1
],
device
=
device
)
<
self
.
xfrac
*
self
.
p
,
t
,
torch
.
zeros_like
(
t
))
if
debug_percentile
is
not
None
:
t
=
torch
.
full_like
(
t
,
torch
.
erfinv
(
debug_percentile
*
2
-
1
)
*
self
.
xfrac_std
)
G_inv
=
G_inv
@
translate2d_inv
(
t
[:,
0
]
*
width
,
t
[:,
1
]
*
height
)
# ----------------------------------
# Execute geometric transformations.
# ----------------------------------
# Execute if the transform is not identity.
if
G_inv
is
not
I_3
:
# Calculate padding.
cx
=
(
width
-
1
)
/
2
cy
=
(
height
-
1
)
/
2
cp
=
matrix
([
-
cx
,
-
cy
,
1
],
[
cx
,
-
cy
,
1
],
[
cx
,
cy
,
1
],
[
-
cx
,
cy
,
1
],
device
=
device
)
# [idx, xyz]
cp
=
G_inv
@
cp
.
t
()
# [batch, xyz, idx]
Hz_pad
=
self
.
Hz_geom
.
shape
[
0
]
//
4
margin
=
cp
[:,
:
2
,
:].
permute
(
1
,
0
,
2
).
flatten
(
1
)
# [xy, batch * idx]
margin
=
torch
.
cat
([
-
margin
,
margin
]).
max
(
dim
=
1
).
values
# [x0, y0, x1, y1]
margin
=
margin
+
misc
.
constant
(
[
Hz_pad
*
2
-
cx
,
Hz_pad
*
2
-
cy
]
*
2
,
device
=
device
)
margin
=
margin
.
max
(
misc
.
constant
([
0
,
0
]
*
2
,
device
=
device
))
margin
=
margin
.
min
(
misc
.
constant
([
width
-
1
,
height
-
1
]
*
2
,
device
=
device
))
mx0
,
my0
,
mx1
,
my1
=
margin
.
ceil
().
to
(
torch
.
int32
)
# Pad image and adjust origin.
images
=
torch
.
nn
.
functional
.
pad
(
input
=
images
,
pad
=
[
mx0
,
mx1
,
my0
,
my1
],
mode
=
'reflect'
)
G_inv
=
translate2d
(
torch
.
true_divide
(
mx0
-
mx1
,
2
),
torch
.
true_divide
(
my0
-
my1
,
2
))
@
G_inv
# Upsample.
images
=
upfirdn2d
.
upsample2d
(
x
=
images
,
f
=
self
.
Hz_geom
,
up
=
2
)
G_inv
=
scale2d
(
2
,
2
,
device
=
device
)
@
G_inv
@
scale2d_inv
(
2
,
2
,
device
=
device
)
G_inv
=
translate2d
(
-
0.5
,
-
0.5
,
device
=
device
)
@
G_inv
@
translate2d_inv
(
-
0.5
,
-
0.5
,
device
=
device
)
# Execute transformation.
shape
=
[
batch_size
,
num_channels
,
(
height
+
Hz_pad
*
2
)
*
2
,
(
width
+
Hz_pad
*
2
)
*
2
]
G_inv
=
scale2d
(
2
/
images
.
shape
[
3
],
2
/
images
.
shape
[
2
],
device
=
device
)
@
G_inv
@
scale2d_inv
(
2
/
shape
[
3
],
2
/
shape
[
2
],
device
=
device
)
grid
=
torch
.
nn
.
functional
.
affine_grid
(
theta
=
G_inv
[:,
:
2
,
:],
size
=
shape
,
align_corners
=
False
)
images
=
grid_sample_gradfix
.
grid_sample
(
images
,
grid
)
# Downsample and crop.
images
=
upfirdn2d
.
downsample2d
(
x
=
images
,
f
=
self
.
Hz_geom
,
down
=
2
,
padding
=-
Hz_pad
*
2
,
flip_filter
=
True
)
# --------------------------------------------
# Select parameters for color transformations.
# --------------------------------------------
# Initialize homogeneous 3D transformation matrix:
# C @ color_in ==> color_out
I_4
=
torch
.
eye
(
4
,
device
=
device
)
C
=
I_4
# Apply brightness with probability (brightness * strength).
if
self
.
brightness
>
0
:
b
=
torch
.
randn
([
batch_size
],
device
=
device
)
*
self
.
brightness_std
b
=
torch
.
where
(
torch
.
rand
([
batch_size
],
device
=
device
)
<
self
.
brightness
*
self
.
p
,
b
,
torch
.
zeros_like
(
b
))
if
debug_percentile
is
not
None
:
b
=
torch
.
full_like
(
b
,
torch
.
erfinv
(
debug_percentile
*
2
-
1
)
*
self
.
brightness_std
)
C
=
translate3d
(
b
,
b
,
b
)
@
C
# Apply contrast with probability (contrast * strength).
if
self
.
contrast
>
0
:
c
=
torch
.
exp
(
torch
.
randn
([
batch_size
],
device
=
device
)
*
self
.
contrast_std
*
_scalor_log2
)
c
=
torch
.
where
(
torch
.
rand
([
batch_size
],
device
=
device
)
<
self
.
contrast
*
self
.
p
,
c
,
torch
.
ones_like
(
c
))
if
debug_percentile
is
not
None
:
c
=
torch
.
full_like
(
c
,
torch
.
exp2
(
torch
.
erfinv
(
debug_percentile
*
2
-
1
)
*
self
.
contrast_std
))
C
=
scale3d
(
c
,
c
,
c
)
@
C
# Apply luma flip with probability (lumaflip * strength).
v
=
misc
.
constant
(
np
.
asarray
([
1
,
1
,
1
,
0
])
/
np
.
sqrt
(
3
),
device
=
device
)
# Luma axis.
if
self
.
lumaflip
>
0
:
i
=
torch
.
floor
(
torch
.
rand
([
batch_size
,
1
,
1
],
device
=
device
)
*
2
)
i
=
torch
.
where
(
torch
.
rand
([
batch_size
,
1
,
1
],
device
=
device
)
<
self
.
lumaflip
*
self
.
p
,
i
,
torch
.
zeros_like
(
i
))
if
debug_percentile
is
not
None
:
i
=
torch
.
full_like
(
i
,
torch
.
floor
(
debug_percentile
*
2
))
C
=
(
I_4
-
2
*
v
.
ger
(
v
)
*
i
)
@
C
# Householder reflection.
# Apply hue rotation with probability (hue * strength).
if
self
.
hue
>
0
and
num_channels
>
1
:
theta
=
(
torch
.
rand
([
batch_size
],
device
=
device
)
*
2
-
1
)
*
np
.
pi
*
self
.
hue_max
theta
=
torch
.
where
(
torch
.
rand
([
batch_size
],
device
=
device
)
<
self
.
hue
*
self
.
p
,
theta
,
torch
.
zeros_like
(
theta
))
if
debug_percentile
is
not
None
:
theta
=
torch
.
full_like
(
theta
,
(
debug_percentile
*
2
-
1
)
*
np
.
pi
*
self
.
hue_max
)
C
=
rotate3d
(
v
,
theta
)
@
C
# Rotate around v.
# Apply saturation with probability (saturation * strength).
if
self
.
saturation
>
0
and
num_channels
>
1
:
s
=
torch
.
exp
(
torch
.
randn
([
batch_size
,
1
,
1
],
device
=
device
)
*
self
.
saturation_std
*
_scalor_log2
)
s
=
torch
.
where
(
torch
.
rand
([
batch_size
,
1
,
1
],
device
=
device
)
<
self
.
saturation
*
self
.
p
,
s
,
torch
.
ones_like
(
s
))
if
debug_percentile
is
not
None
:
s
=
torch
.
full_like
(
s
,
torch
.
exp2
(
torch
.
erfinv
(
debug_percentile
*
2
-
1
)
*
self
.
saturation_std
))
C
=
(
v
.
ger
(
v
)
+
(
I_4
-
v
.
ger
(
v
))
*
s
)
@
C
# ------------------------------
# Execute color transformations.
# ------------------------------
# Execute if the transform is not identity.
if
C
is
not
I_4
:
images
=
images
.
reshape
([
batch_size
,
num_channels
,
height
*
width
])
if
num_channels
==
3
:
images
=
C
[:,
:
3
,
:
3
]
@
images
+
C
[:,
:
3
,
3
:]
elif
num_channels
==
1
:
C
=
C
[:,
:
3
,
:].
mean
(
dim
=
1
,
keepdims
=
True
)
images
=
images
*
C
[:,
:,
:
3
].
sum
(
dim
=
2
,
keepdims
=
True
)
+
C
[:,
:,
3
:]
else
:
raise
ValueError
(
'Image must be RGB (3 channels) or L (1 channel)'
)
images
=
images
.
reshape
([
batch_size
,
num_channels
,
height
,
width
])
# ----------------------
# Image-space filtering.
# ----------------------
if
self
.
imgfilter
>
0
:
num_bands
=
self
.
Hz_fbank
.
shape
[
0
]
assert
len
(
self
.
imgfilter_bands
)
==
num_bands
expected_power
=
misc
.
constant
(
np
.
array
([
10
,
1
,
1
,
1
])
/
13
,
device
=
device
)
# Expected power spectrum (1/f).
# Apply amplification for each band with probability
# (imgfilter * strength * band_strength).
g
=
torch
.
ones
([
batch_size
,
num_bands
],
device
=
device
)
# Global gain vector (identity).
for
i
,
band_strength
in
enumerate
(
self
.
imgfilter_bands
):
t_i
=
torch
.
exp
(
torch
.
randn
([
batch_size
],
device
=
device
)
*
self
.
imgfilter_std
*
_scalor_log2
)
t_i
=
torch
.
where
(
torch
.
rand
([
batch_size
],
device
=
device
)
<
self
.
imgfilter
*
self
.
p
*
band_strength
,
t_i
,
torch
.
ones_like
(
t_i
))
if
debug_percentile
is
not
None
:
t_i
=
torch
.
full_like
(
t_i
,
torch
.
exp2
(
torch
.
erfinv
(
debug_percentile
*
2
-
1
)
*
self
.
imgfilter_std
)
)
if
band_strength
>
0
else
torch
.
ones_like
(
t_i
)
t
=
torch
.
ones
([
batch_size
,
num_bands
],
device
=
device
)
# Temporary gain vector.
t
[:,
i
]
=
t_i
# Replace i'th element.
t
=
t
/
(
expected_power
*
t
.
square
()).
sum
(
dim
=-
1
,
keepdims
=
True
).
sqrt
()
# Normalize power.
g
=
g
*
t
# Accumulate into global gain.
# Construct combined amplification filter.
Hz_prime
=
g
@
self
.
Hz_fbank
# [batch, tap]
Hz_prime
=
Hz_prime
.
unsqueeze
(
1
).
repeat
(
[
1
,
num_channels
,
1
])
# [batch, channels, tap]
Hz_prime
=
Hz_prime
.
reshape
([
batch_size
*
num_channels
,
1
,
-
1
])
# [batch * channels, 1, tap]
# Apply filter.
p
=
self
.
Hz_fbank
.
shape
[
1
]
//
2
images
=
images
.
reshape
(
[
1
,
batch_size
*
num_channels
,
height
,
width
])
images
=
torch
.
nn
.
functional
.
pad
(
input
=
images
,
pad
=
[
p
,
p
,
p
,
p
],
mode
=
'reflect'
)
images
=
conv2d_gradfix
.
conv2d
(
input
=
images
,
weight
=
Hz_prime
.
unsqueeze
(
2
),
groups
=
batch_size
*
num_channels
)
images
=
conv2d_gradfix
.
conv2d
(
input
=
images
,
weight
=
Hz_prime
.
unsqueeze
(
3
),
groups
=
batch_size
*
num_channels
)
images
=
images
.
reshape
([
batch_size
,
num_channels
,
height
,
width
])
# ------------------------
# Image-space corruptions.
# ------------------------
# Apply additive RGB noise with probability (noise * strength).
if
self
.
noise
>
0
:
sigma
=
torch
.
randn
([
batch_size
,
1
,
1
,
1
],
device
=
device
).
abs
()
*
self
.
noise_std
sigma
=
torch
.
where
(
torch
.
rand
([
batch_size
,
1
,
1
,
1
],
device
=
device
)
<
self
.
noise
*
self
.
p
,
sigma
,
torch
.
zeros_like
(
sigma
))
if
debug_percentile
is
not
None
:
sigma
=
torch
.
full_like
(
sigma
,
torch
.
erfinv
(
debug_percentile
)
*
self
.
noise_std
)
images
=
images
+
torch
.
randn
(
[
batch_size
,
num_channels
,
height
,
width
],
device
=
device
)
*
sigma
# Apply cutout with probability (cutout * strength).
if
self
.
cutout
>
0
:
size
=
torch
.
full
([
batch_size
,
2
,
1
,
1
,
1
],
self
.
cutout_size
,
device
=
device
)
size
=
torch
.
where
(
torch
.
rand
([
batch_size
,
1
,
1
,
1
,
1
],
device
=
device
)
<
self
.
cutout
*
self
.
p
,
size
,
torch
.
zeros_like
(
size
))
center
=
torch
.
rand
([
batch_size
,
2
,
1
,
1
,
1
],
device
=
device
)
if
debug_percentile
is
not
None
:
size
=
torch
.
full_like
(
size
,
self
.
cutout_size
)
center
=
torch
.
full_like
(
center
,
debug_percentile
)
coord_x
=
torch
.
arange
(
width
,
device
=
device
).
reshape
([
1
,
1
,
1
,
-
1
])
coord_y
=
torch
.
arange
(
height
,
device
=
device
).
reshape
([
1
,
1
,
-
1
,
1
])
mask_x
=
(((
coord_x
+
0.5
)
/
width
-
center
[:,
0
]).
abs
()
>=
size
[:,
0
]
/
2
)
mask_y
=
(((
coord_y
+
0.5
)
/
height
-
center
[:,
1
]).
abs
()
>=
size
[:,
1
]
/
2
)
mask
=
torch
.
logical_or
(
mask_x
,
mask_y
).
to
(
torch
.
float32
)
images
=
images
*
mask
return
images
mmgen/models/architectures/stylegan/ada/grid_sample_gradfix.py
0 → 100644
View file @
b7536f78
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Custom replacement for `torch.nn.functional.grid_sample` that supports
arbitrarily high order gradients between the input and output.
Only works on 2D images and assumes `mode='bilinear'`, `padding_mode='zeros'`,
`align_corners=False`.
"""
import
warnings
import
torch
# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access
# ----------------------------------------------------------------------------
# Global switch for this module: when False, grid_sample() always falls
# back to the PyTorch built-in implementation.
enabled = True  # Enable the custom op by setting this to true.
# ----------------------------------------------------------------------------
def grid_sample(input, grid):
    """Bilinear 2D grid sample with support for high-order gradients.

    Dispatches to the custom autograd op when the running PyTorch version
    supports it, otherwise falls back to
    ``torch.nn.functional.grid_sample`` with fixed ``mode='bilinear'``,
    ``padding_mode='zeros'`` and ``align_corners=False``.
    """
    if not _should_use_custom_op():
        return torch.nn.functional.grid_sample(
            input=input,
            grid=grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
    return _GridSample2dForward.apply(input, grid)
# ----------------------------------------------------------------------------
def _should_use_custom_op():
    """Return True when the custom grid-sample op should be used.

    The custom op is only known to work on the PyTorch releases listed
    below; on any other version a warning is emitted and the caller is
    expected to fall back to ``torch.nn.functional.grid_sample``.
    """
    if not enabled:
        return False
    supported_prefixes = ('1.5.', '1.6.', '1.7.', '1.8.', '1.9.', '1.10.')
    if torch.__version__.startswith(supported_prefixes):
        return True
    warnings.warn(
        f'grid_sample_gradfix not supported on PyTorch {torch.__version__}.'
        ' Falling back to torch.nn.functional.grid_sample().')
    return False
# ----------------------------------------------------------------------------
class _GridSample2dForward(torch.autograd.Function):
    """Forward pass of 2D grid sample with a differentiable backward.

    ``backward`` delegates to ``_GridSample2dBackward`` so that the
    first-order gradient is itself a graph node, which enables
    second-order gradients (e.g. for gradient-penalty losses).
    """

    @staticmethod
    def forward(ctx, input, grid):
        # Only NCHW inputs and 4D grids are supported.
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(
            input=input,
            grid=grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        # Route the gradient through a custom Function instead of the
        # built-in backward so it can be differentiated again.
        grad_input, grad_grid = _GridSample2dBackward.apply(
            grad_output, input, grid)
        return grad_input, grad_grid
# ----------------------------------------------------------------------------
class _GridSample2dBackward(torch.autograd.Function):
    """Backward pass of 2D grid sample, exposed as its own Function.

    Calling the raw ATen backward op directly (instead of relying on
    autograd's built-in backward) makes the first-order gradient a
    differentiable node, which is what enables double backward in
    ``_GridSample2dForward``.
    """

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        # NOTE(review): the argument list (grad_output, input, grid, 0, 0,
        # False) matches aten::grid_sampler_2d_backward on PyTorch <= 1.10
        # (interpolation_mode=0 -> bilinear, padding_mode=0 -> zeros,
        # align_corners=False). Newer PyTorch releases add an output_mask
        # argument and change _jit_get_operation's return value — confirm
        # before enabling on versions outside _should_use_custom_op's list.
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid  # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            # d(grad_input)/d(grad_output) is grid_sample itself.
            grad2_grad_output = _GridSample2dForward.apply(
                grad2_grad_input, grid)

        # Second-order gradients w.r.t. the grid are not implemented.
        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid
mmgen/models/architectures/stylegan/ada/misc.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
torch
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.
_constant_cache = dict()


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a (possibly cached) constant tensor built from ``value``.

    Args:
        value: Scalar or array-like seed for the tensor.
        shape (list | tuple | None): Optional target shape; the value is
            broadcast to it. Defaults to None (keep natural shape).
        dtype (torch.dtype | None): Target dtype. Defaults to the global
            default dtype.
        device (torch.device | None): Target device. Defaults to CPU.
        memory_format: Target memory format. Defaults to
            ``torch.contiguous_format``.

    Returns:
        torch.Tensor: The cached constant tensor. Callers must not mutate
        the returned tensor, since it is shared through the cache.
    """
    value = np.asarray(value)
    shape = tuple(shape) if shape is not None else None
    dtype = dtype if dtype is not None else torch.get_default_dtype()
    device = device if device is not None else torch.device('cpu')
    memory_format = (memory_format if memory_format is not None else
                     torch.contiguous_format)

    # Raw bytes participate in the key so distinct values with identical
    # shape/dtype never collide.
    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device,
           memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is not None:
        return tensor

    tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
    if shape is not None:
        tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
    tensor = tensor.contiguous(memory_format=memory_format)
    _constant_cache[key] = tensor
    return tensor
mmgen/models/architectures/stylegan/ada/upfirdn2d.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
torch
from
mmcv.ops.upfirdn2d
import
upfirdn2d
def
_parse_scaling
(
scaling
):
if
isinstance
(
scaling
,
int
):
scaling
=
[
scaling
,
scaling
]
assert
isinstance
(
scaling
,
(
list
,
tuple
))
assert
all
(
isinstance
(
x
,
int
)
for
x
in
scaling
)
sx
,
sy
=
scaling
assert
sx
>=
1
and
sy
>=
1
return
sx
,
sy
def
_parse_padding
(
padding
):
if
isinstance
(
padding
,
int
):
padding
=
[
padding
,
padding
]
assert
isinstance
(
padding
,
(
list
,
tuple
))
assert
all
(
isinstance
(
x
,
int
)
for
x
in
padding
)
if
len
(
padding
)
==
2
:
padx
,
pady
=
padding
padding
=
[
padx
,
padx
,
pady
,
pady
]
padx0
,
padx1
,
pady0
,
pady1
=
padding
return
padx0
,
padx1
,
pady0
,
pady1
def
_get_filter_size
(
f
):
if
f
is
None
:
return
1
,
1
assert
isinstance
(
f
,
torch
.
Tensor
)
and
f
.
ndim
in
[
1
,
2
]
fw
=
f
.
shape
[
-
1
]
fh
=
f
.
shape
[
0
]
fw
=
int
(
fw
)
fh
=
int
(
fh
)
assert
fw
>=
1
and
fh
>=
1
return
fw
,
fh
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Upsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a multiple of the
    input. User-specified padding is applied on top of that, with negative
    values indicating cropping. Pixels outside the image are assumed to be
    zero.

    Args:
        x: Float input tensor of shape
            `[batch_size, num_channels, in_height, in_width]`.
        f: Float32 FIR filter of shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable), or `None` (identity).
        up: Integer upsampling factor, or a list/tuple `[x, y]`
            (default: 2).
        padding: Padding with respect to the output: single number,
            `[x, y]`, or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        flip_filter: False = convolution, True = correlation
            (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        impl: Kept for API compatibility; the mmcv ``upfirdn2d`` backend is
            always used.

    Returns:
        Tensor of shape `[batch_size, num_channels, out_height, out_width]`.
    """
    upx, upy = _parse_scaling(up)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    # Pad so the output stays aligned with the (upsampled) input grid.
    p = [
        padx0 + (fw + upx - 1) // 2,
        padx1 + (fw - upx) // 2,
        pady0 + (fh + upy - 1) // 2,
        pady1 + (fh - upy) // 2,
    ]
    # Compensate for the zeros inserted by nearest-zero upsampling.
    gain = gain * upx * upy
    if f is None:
        # Fix: the docstring allows f=None (identity), but the original
        # implementation crashed on it at ``f * ...``.
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    # gain**(ndim/2): a separable filter is applied twice (x then y), so
    # each pass carries sqrt(gain); a 2D filter is applied once with gain.
    f = f * (gain**(f.ndim / 2))
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    if f.ndim == 1:
        # Separable filter: filter along x, then along y.
        x = upfirdn2d(
            x, f.unsqueeze(0), up=(upx, 1), pad=(p[0], p[1], 0, 0))
        x = upfirdn2d(
            x, f.unsqueeze(1), up=(1, upy), pad=(0, 0, p[2], p[3]))
    else:
        # Fix: non-separable 2D filters previously fell through both calls
        # above and the input was returned unfiltered AND un-upsampled.
        x = upfirdn2d(x, f, up=(upx, upy), pad=(p[0], p[1], p[2], p[3]))
    return x
def setup_filter(f,
                 device=torch.device('cpu'),
                 normalize=True,
                 flip_filter=False,
                 gain=1,
                 separable=None):
    r"""Convenience function to set up a 2D FIR filter for `upfirdn2d()`.

    Args:
        f: Torch tensor, numpy array, or python list of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable), `[]` (impulse), or
            `None` (identity).
        device: Result device (default: cpu).
        normalize: Normalize the filter so that it retains the magnitude
            for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        separable: Return a separable filter?
            (default: select automatically).

    Returns:
        Float32 tensor of the shape `[filter_height, filter_width]`
        (non-separable) or `[filter_taps]` (separable).
    """
    # Validate and canonicalize to a non-empty float32 tensor of ndim >= 1.
    f = torch.as_tensor(1 if f is None else f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Decide separability: only "long" 1D filters (>= 8 taps) stay
    # separable by default; short 1D filters get expanded via outer product.
    if separable is None:
        separable = (f.ndim == 1 and f.numel() >= 8)
    if f.ndim == 1 and not separable:
        f = f.ger(f)
    assert f.ndim == (1 if separable else 2)

    # Normalize, flip, apply gain, then move to the target device.
    if normalize:
        # NOTE(review): this division is in place, so a caller-supplied
        # float32 tensor (aliased by torch.as_tensor) is modified.
        f /= f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    f = f * (gain**(f.ndim / 2))
    return f.to(device=device)
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1,
                 impl='cuda'):
    r"""Downsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a fraction of the
    input. User-specified padding is applied on top of that, with negative
    values indicating cropping. Pixels outside the image are assumed to be
    zero.

    Args:
        x: Float input tensor of shape
            `[batch_size, num_channels, in_height, in_width]`.
        f: Float32 FIR filter of shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable), or `None` (identity).
        down: Integer downsampling factor, or a list/tuple `[x, y]`
            (default: 2).
        padding: Padding with respect to the input: single number,
            `[x, y]`, or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        flip_filter: False = convolution, True = correlation
            (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        impl: Kept for API compatibility; the mmcv ``upfirdn2d`` backend is
            always used.

    Returns:
        Tensor of shape `[batch_size, num_channels, out_height, out_width]`.
    """
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    # Pad so the output stays aligned with the (downsampled) input grid.
    p = [
        padx0 + (fw - downx + 1) // 2,
        padx1 + (fw - downx) // 2,
        pady0 + (fh - downy + 1) // 2,
        pady1 + (fh - downy) // 2,
    ]
    if f is None:
        # Fix: the docstring allows f=None (identity), but the original
        # implementation crashed on it when flipping/filtering.
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    # Fix: ``gain`` was accepted but never applied in the original code.
    # gain**(ndim/2): a separable filter is applied twice (x then y), so
    # each pass carries sqrt(gain); default gain=1 keeps old behavior.
    f = f * (gain**(f.ndim / 2))
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    if f.ndim == 1:
        # Separable filter: filter along x, then along y.
        x = upfirdn2d(
            x, f.unsqueeze(0), down=(downx, 1), pad=(p[0], p[1], 0, 0))
        x = upfirdn2d(
            x, f.unsqueeze(1), down=(1, downy), pad=(0, 0, p[2], p[3]))
    else:
        # Fix: non-separable 2D filters previously fell through both calls
        # above and the input was returned unfiltered and un-downsampled.
        x = upfirdn2d(x, f, down=(downx, downy), pad=(p[0], p[1], p[2], p[3]))
    return x
mmgen/models/architectures/stylegan/generator_discriminator_v1.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
random
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmgen.models.architectures
import
PixelNorm
from
mmgen.models.architectures.common
import
get_module_device
from
mmgen.models.architectures.pggan
import
(
EqualizedLRConvDownModule
,
EqualizedLRConvModule
)
from
mmgen.models.architectures.stylegan.modules
import
Blur
from
mmgen.models.builder
import
MODULES
from
..
import
MiniBatchStddevLayer
from
.modules.styleganv1_modules
import
StyleConv
from
.modules.styleganv2_modules
import
EqualLinearActModule
from
.utils
import
get_mean_latent
,
style_mixing
@MODULES.register_module()
class StyleGANv1Generator(nn.Module):
    """StyleGAN1 Generator.

    In StyleGAN1, we use a progressive growing architecture composing of a
    style mapping module and number of convolutional style blocks. More
    details can be found in: A Style-Based Generator Architecture for
    Generative Adversarial Networks CVPR2019.

    Args:
        out_size (int): The output size of the StyleGAN1 generator.
        style_channels (int): The number of channels for style code.
        num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 2, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we defaultly adopt mixing style mode. However, in the
            evaluation, we use 'single' style mode. `['mix', 'single']` are
            currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The evaluation mode of style mixing.
            Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in range of [0, 1]. Defaults to 0.9.
    """

    def __init__(self,
                 out_size,
                 style_channels,
                 num_mlps=8,
                 blur_kernel=[1, 2, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9):
        super().__init__()
        self.out_size = out_size
        self.style_channels = style_channels
        self.num_mlps = num_mlps
        self.lr_mlp = lr_mlp
        # NOTE(review): blur_kernel is accepted but not referenced anywhere
        # in this class.
        self._default_style_mode = default_style_mode
        self.default_style_mode = default_style_mode
        self.eval_style_mode = eval_style_mode
        self.mix_prob = mix_prob

        # define style mapping layers: PixelNorm followed by num_mlps
        # equalized-lr linear layers with LeakyReLU.
        mapping_layers = [PixelNorm()]
        for _ in range(num_mlps):
            mapping_layers.append(
                EqualLinearActModule(
                    style_channels,
                    style_channels,
                    equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
                    act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))
        self.style_mapping = nn.Sequential(*mapping_layers)

        # channel number per output resolution
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16,
        }

        # generator backbone (8x8 --> higher resolutions)
        self.log_size = int(np.log2(self.out_size))

        self.convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        in_channels_ = self.channels[4]
        # one StyleConv + to-RGB head per resolution from 4x4 to out_size
        for i in range(2, self.log_size + 1):
            out_channels_ = self.channels[2**i]
            self.convs.append(
                StyleConv(
                    in_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    initial=(i == 2),  # first block starts from a constant
                    upsample=True,
                    fused=True))
            self.to_rgbs.append(
                EqualizedLRConvModule(out_channels_, 3, 1, act_cfg=None))
            in_channels_ = out_channels_

        # two latents / noise maps are consumed per resolution block
        self.num_latents = self.log_size * 2 - 2
        self.num_injected_noises = self.num_latents

        # register buffer for injected noises
        for layer_idx in range(self.num_injected_noises):
            res = (layer_idx + 4) // 2
            shape = [1, 1, 2**res, 2**res]
            self.register_buffer(f'injected_noise_{layer_idx}',
                                 torch.randn(*shape))

    def train(self, mode=True):
        """Set train/eval mode and switch the style-mixing mode with it.

        Args:
            mode (bool, optional): Whether to set training mode (``True``)
                or evaluation mode (``False``). Defaults to ``True``.
        """
        if mode:
            if self.default_style_mode != self._default_style_mode:
                mmcv.print_log(
                    f'Switch to train style mode: {self._default_style_mode}',
                    'mmgen')
            self.default_style_mode = self._default_style_mode

        else:
            if self.default_style_mode != self.eval_style_mode:
                mmcv.print_log(
                    f'Switch to evaluation style mode: {self.eval_style_mode}',
                    'mmgen')
            self.default_style_mode = self.eval_style_mode

        return super(StyleGANv1Generator, self).train(mode)

    def make_injected_noise(self):
        """make noises that will be injected into feature maps.

        Returns:
            list[Tensor]: List of layer-wise noise tensor.
        """
        device = get_module_device(self)

        # noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]
        noises = []
        # two noise maps per resolution, matching num_injected_noises
        for i in range(2, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_mean_latent(self, num_samples=4096, **kwargs):
        """Get mean latent of W space in this generator.

        Args:
            num_samples (int, optional): Number of sample times. Defaults
                to 4096.

        Returns:
            Tensor: Mean latent of this generator.
        """
        return get_mean_latent(self, num_samples, **kwargs)

    def style_mixing(self,
                     n_source,
                     n_target,
                     inject_index=1,
                     truncation_latent=None,
                     truncation=0.7,
                     curr_scale=-1,
                     transition_weight=1):
        """Generate style-mixed images; thin wrapper around the shared
        ``style_mixing`` utility imported by this module."""
        return style_mixing(
            self,
            n_source=n_source,
            n_target=n_target,
            inject_index=inject_index,
            truncation=truncation,
            truncation_latent=truncation_latent,
            style_channels=self.style_channels,
            curr_scale=curr_scale,
            transition_weight=transition_weight)

    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True,
                transition_weight=1.,
                curr_scale=-1):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN1, you can provide noise tensor or latent tensor. Given
                a list containing more than one noise or latent tensors, style
                mixing trick will be used in training. Of course, You can
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, the ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. Give value less
                than 1., the truncation trick will be adopted. Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensor injected to the style conv
                block. Defaults to True.
            transition_weight (float, optional): The weight used in resolution
                transition. Defaults to 1..
            curr_scale (int, optional): The resolution scale of generated image
                tensor. -1 means the max resolution scale of the StyleGAN1.
                Defaults to -1.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                # sample two noises for style mixing
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        # map noise z to latent w unless the input is already a latent
        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                # `None` entries let StyleConv draw fresh noise each call
                injected_noise = [None] * self.num_injected_noises
            else:
                # reuse the fixed noise buffers registered in __init__
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]
        # use truncation trick
        if truncation < 1:
            style_t = []
            # calculate truncation latent on the fly
            if truncation_latent is None and not hasattr(
                    self, 'truncation_latent'):
                self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            elif truncation_latent is None and hasattr(
                    self, 'truncation_latent'):
                truncation_latent = self.truncation_latent

            for style in styles:
                # interpolate towards the mean latent
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t
        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        curr_log_size = self.log_size if curr_scale < 0 else int(
            np.log2(curr_scale))
        step = curr_log_size - 2

        _index = 0

        # NOTE(review): the latent tensor is fed as the "feature" input of
        # the first StyleConv; since that block is built with initial=True,
        # it presumably only uses it for batch size/device — confirm
        # against StyleConv.
        out = latent

        # 4x4 ---> higher resolutions
        for i, (conv, to_rgb) in enumerate(zip(self.convs, self.to_rgbs)):
            if i > 0 and step > 0:
                # keep the previous feature map for the transition branch
                out_prev = out

            # each block consumes two latents and two noise maps
            out = conv(
                out,
                latent[:, _index],
                latent[:, _index + 1],
                noise1=injected_noise[2 * i],
                noise2=injected_noise[2 * i + 1])

            if i == step:
                out = to_rgb(out)

                # blend with the upsampled lower-resolution RGB during
                # progressive-growing transition
                if i > 0 and 0 <= transition_weight < 1:
                    skip_rgb = self.to_rgbs[i - 1](out_prev)
                    skip_rgb = F.interpolate(
                        skip_rgb, scale_factor=2, mode='nearest')
                    out = (1 - transition_weight
                           ) * skip_rgb + transition_weight * out
                break

            _index += 2

        img = out

        if return_latents or return_noise:
            output_dict = dict(
                fake_img=img,
                latent=latent,
                inject_index=inject_index,
                noise_batch=noise_batch)
            return output_dict

        return img
@MODULES.register_module()
class StyleGAN1Discriminator(nn.Module):
    """StyleGAN1 Discriminator.

    The architecture of this discriminator is proposed in StyleGAN1. More
    details can be found in: A Style-Based Generator Architecture for
    Generative Adversarial Networks CVPR2019.

    Args:
        in_size (int): The input size of images.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 2, 1].
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4).
    """

    def __init__(self,
                 in_size,
                 blur_kernel=[1, 2, 1],
                 mbstd_cfg=dict(group_size=4)):
        super().__init__()
        self.with_mbstd = mbstd_cfg is not None

        # channel number per input resolution
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16,
        }

        log_size = int(np.log2(in_size))
        self.log_size = log_size
        in_channels = channels[in_size]

        self.convs = nn.ModuleList()
        self.from_rgb = nn.ModuleList()

        # downsampling blocks from in_size down to 8x8, each paired with a
        # from-RGB entry head for progressive growing
        for i in range(log_size, 2, -1):
            out_channel = channels[2**(i - 1)]

            self.from_rgb.append(
                EqualizedLRConvModule(
                    3,
                    in_channels,
                    kernel_size=3,
                    padding=1,
                    act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))
            self.convs.append(
                nn.Sequential(
                    EqualizedLRConvModule(
                        in_channels,
                        out_channel,
                        kernel_size=3,
                        padding=1,
                        bias=True,
                        norm_cfg=None,
                        act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
                    Blur(blur_kernel, pad=(1, 1)),
                    # stride-2 conv performs the 2x downsampling
                    EqualizedLRConvDownModule(
                        out_channel,
                        out_channel,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        act_cfg=None),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True)))

            in_channels = out_channel

        # final 4x4 from-RGB head
        self.from_rgb.append(
            EqualizedLRConvModule(
                3,
                in_channels,
                kernel_size=3,
                padding=0,
                act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))
        # final conv block: +1 input channel for the minibatch-stddev
        # feature, then a 4x4 conv collapsing the spatial dims
        self.convs.append(
            nn.Sequential(
                EqualizedLRConvModule(
                    in_channels + 1,
                    512,
                    kernel_size=3,
                    padding=1,
                    bias=True,
                    norm_cfg=None,
                    act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
                EqualizedLRConvModule(
                    512,
                    512,
                    kernel_size=4,
                    padding=0,
                    bias=True,
                    norm_cfg=None,
                    act_cfg=None),
            ))

        if self.with_mbstd:
            self.mbstd_layer = MiniBatchStddevLayer(**mbstd_cfg)

        self.final_linear = nn.Sequential(EqualLinearActModule(
            channels[4], 1))

        self.n_layer = len(self.convs)

    def forward(self, input, transition_weight=1., curr_scale=-1):
        """Forward function.

        Args:
            input (torch.Tensor): Input image tensor.
            transition_weight (float, optional): The weight used in resolution
                transition. Defaults to 1..
            curr_scale (int, optional): The resolution scale of image tensor.
                -1 means the max resolution scale of the StyleGAN1.
                Defaults to -1.

        Returns:
            torch.Tensor: Predict score for the input image.
        """
        curr_log_size = self.log_size if curr_scale < 0 else int(
            np.log2(curr_scale))
        step = curr_log_size - 2

        # run blocks from the current resolution down to 4x4
        for i in range(step, -1, -1):
            index = self.n_layer - i - 1

            if i == step:  # enter through the matching from-RGB head
                out = self.from_rgb[index](input)

            # minibatch standard deviation
            if i == 0:
                # NOTE(review): mbstd_layer only exists when mbstd_cfg is
                # not None (see __init__); with mbstd_cfg=None this line
                # would raise AttributeError — confirm intended usage.
                out = self.mbstd_layer(out)

            out = self.convs[index](out)

            if i > 0:
                # blend with the downsampled input during the
                # progressive-growing transition phase
                if i == step and 0 <= transition_weight < 1:
                    skip_rgb = F.avg_pool2d(input, 2)
                    skip_rgb = self.from_rgb[index + 1](skip_rgb)
                    out = (1 - transition_weight
                           ) * skip_rgb + transition_weight * out

        out = out.view(out.shape[0], -1)
        out = self.final_linear(out)

        return out
mmgen/models/architectures/stylegan/generator_discriminator_v2.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
random
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.runner.checkpoint
import
_load_checkpoint_with_prefix
from
mmgen.core.runners.fp16_utils
import
auto_fp16
from
mmgen.models.architectures
import
PixelNorm
from
mmgen.models.architectures.common
import
get_module_device
from
mmgen.models.builder
import
MODULES
,
build_module
from
.ada.augment
import
AugmentPipe
from
.ada.misc
import
constant
from
.modules.styleganv2_modules
import
(
ConstantInput
,
ConvDownLayer
,
EqualLinearActModule
,
ModMBStddevLayer
,
ModulatedStyleConv
,
ModulatedToRGB
,
ResBlock
)
from
.utils
import
get_mean_latent
,
style_mixing
@MODULES.register_module()
class StyleGANv2Generator(nn.Module):
    r"""StyleGAN2 Generator.

    In StyleGAN2, we use a static architecture composing of a style mapping
    module and number of convolutional style blocks. More details can be found
    in: Analyzing and Improving the Image Quality of StyleGAN CVPR2020.

    You can load pretrained model through passing information into
    ``pretrained`` argument. We have already offered official weights as
    follows:

    - stylegan2-ffhq-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth  # noqa
    - stylegan2-horse-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth  # noqa
    - stylegan2-car-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth  # noqa
    - stylegan2-cat-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth  # noqa
    - stylegan2-church-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth  # noqa

    If you want to load the ema model, you can just use following codes:

    .. code-block:: python

        # ckpt_http is one of the valid path from http source
        generator = StyleGANv2Generator(1024, 512,
                                        pretrained=dict(
                                            ckpt_path=ckpt_http,
                                            prefix='generator_ema'))

    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` with local path. If you just want to load the original
    generator (not the ema model), please set the prefix with 'generator'.

    Note that our implementation allows to generate BGR image, while the
    original StyleGAN2 outputs RGB images by default. Thus, we provide
    ``bgr2rgb`` argument to convert the image space.

    Args:
        out_size (int): The output size of the StyleGAN2 generator.
        style_channels (int): The number of channels for style code.
        num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we defaultly adopt mixing style mode. However, in the
            evaluation, we use 'single' style mode. `['mix', 'single']` are
            currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The evaluation mode of style mixing.
            Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in range of [0, 1]. Defaults to ``0.9``.
        num_fp16_scales (int, optional): The number of resolutions to use auto
            fp16 training. Different from ``fp16_enabled``, this argument
            allows users to adopt FP16 training only in several blocks.
            This behaviour is much more similar to the official implementation
            by Tero. Defaults to 0.
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. If this flag is `True`, the whole module will be wrapped
            with ``auto_fp16``. Defaults to False.
        pretrained (dict | None, optional): Information for pretrained models.
            The necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the generator part from the whole state dict.
            Defaults to None.
    """

    def __init__(self,
                 out_size,
                 style_channels,
                 num_mlps=8,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9,
                 num_fp16_scales=0,
                 fp16_enabled=False,
                 pretrained=None):
        super().__init__()
        self.out_size = out_size
        self.style_channels = style_channels
        self.num_mlps = num_mlps
        self.channel_multiplier = channel_multiplier
        self.lr_mlp = lr_mlp
        # `_default_style_mode` keeps the training-time mode so that
        # `train()` can switch back after evaluation.
        self._default_style_mode = default_style_mode
        self.default_style_mode = default_style_mode
        self.eval_style_mode = eval_style_mode
        self.mix_prob = mix_prob
        self.num_fp16_scales = num_fp16_scales
        self.fp16_enabled = fp16_enabled

        # define style mapping layers: PixelNorm followed by `num_mlps`
        # equalized-lr fully connected layers.
        mapping_layers = [PixelNorm()]

        for _ in range(num_mlps):
            mapping_layers.append(
                EqualLinearActModule(
                    style_channels,
                    style_channels,
                    equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
                    act_cfg=dict(type='fused_bias')))

        self.style_mapping = nn.Sequential(*mapping_layers)

        # channel number for each output resolution (key: resolution)
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # constant input layer
        self.constant_input = ConstantInput(self.channels[4])
        # 4x4 stage
        self.conv1 = ModulatedStyleConv(
            self.channels[4],
            self.channels[4],
            kernel_size=3,
            style_channels=style_channels,
            blur_kernel=blur_kernel)
        self.to_rgb1 = ModulatedToRGB(
            self.channels[4],
            style_channels,
            upsample=False,
            fp16_enabled=fp16_enabled)

        # generator backbone (8x8 --> higher resolutions)
        self.log_size = int(np.log2(self.out_size))

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        in_channels_ = self.channels[4]

        for i in range(3, self.log_size + 1):
            out_channels_ = self.channels[2**i]

            # If `fp16_enabled` is True, all of layers will be run in auto
            # FP16. In the case of `num_fp16_sacles` > 0, only partial
            # layers will be run in fp16.
            _use_fp16 = (self.log_size -
                         i) < num_fp16_scales or fp16_enabled

            # each resolution has one upsampling conv and one plain conv
            self.convs.append(
                ModulatedStyleConv(
                    in_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=True,
                    blur_kernel=blur_kernel,
                    fp16_enabled=_use_fp16))
            self.convs.append(
                ModulatedStyleConv(
                    out_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=False,
                    blur_kernel=blur_kernel,
                    fp16_enabled=_use_fp16))
            self.to_rgbs.append(
                ModulatedToRGB(
                    out_channels_,
                    style_channels,
                    upsample=True,
                    fp16_enabled=_use_fp16))  # set to global fp16

            in_channels_ = out_channels_

        # number of style latents consumed per sample and number of
        # per-layer injected noise maps (one fewer than latents)
        self.num_latents = self.log_size * 2 - 2
        self.num_injected_noises = self.num_latents - 1

        # register buffer for injected noises
        for layer_idx in range(self.num_injected_noises):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.register_buffer(f'injected_noise_{layer_idx}',
                                 torch.randn(*shape))

        if pretrained is not None:
            self._load_pretrained_model(**pretrained)

    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        """Load a (possibly prefixed) sub-dict of a checkpoint into self.

        Args:
            ckpt_path (str): Path (local or http) of the checkpoint.
            prefix (str, optional): Prefix of the keys to extract from the
                whole state dict, e.g. 'generator_ema'. Defaults to ''.
            map_location (str, optional): Device mapping passed to the
                checkpoint loader. Defaults to 'cpu'.
            strict (bool, optional): Whether to strictly enforce matching
                keys in ``load_state_dict``. Defaults to True.
        """
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')

    def train(self, mode=True):
        """Toggle train/eval mode and switch the style-mixing mode to match.

        Training uses ``_default_style_mode`` (usually 'mix'); evaluation
        uses ``eval_style_mode`` (usually 'single').
        """
        if mode:
            if self.default_style_mode != self._default_style_mode:
                mmcv.print_log(
                    f'Switch to train style mode: {self._default_style_mode}',
                    'mmgen')
            self.default_style_mode = self._default_style_mode

        else:
            if self.default_style_mode != self.eval_style_mode:
                mmcv.print_log(
                    f'Switch to evaluation style mode: {self.eval_style_mode}',
                    'mmgen')
            self.default_style_mode = self.eval_style_mode

        return super(StyleGANv2Generator, self).train(mode)

    def make_injected_noise(self):
        """make noises that will be injected into feature maps.

        Returns:
            list[Tensor]: List of layer-wise noise tensor.
        """
        device = get_module_device(self)

        # 4x4 stage uses a single noise map; every higher resolution adds two
        noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_mean_latent(self, num_samples=4096, **kwargs):
        """Get mean latent of W space in this generator.

        Args:
            num_samples (int, optional): Number of sample times. Defaults
                to 4096.

        Returns:
            Tensor: Mean latent of this generator.
        """
        return get_mean_latent(self, num_samples, **kwargs)

    def style_mixing(self,
                     n_source,
                     n_target,
                     inject_index=1,
                     truncation_latent=None,
                     truncation=0.7):
        """Generate style-mixed images by delegating to the shared helper.

        See ``mmgen.models.architectures.stylegan.utils.style_mixing`` for
        argument semantics.
        """
        return style_mixing(
            self,
            n_source=n_source,
            n_target=n_target,
            inject_index=inject_index,
            truncation=truncation,
            truncation_latent=truncation_latent,
            style_channels=self.style_channels)

    @auto_fp16()
    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide noise tensor or latent tensor. Given
                a list containing more than one noise or latent tensors, style
                mixing trick will be used in training. Of course, You can
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, the ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. Give value less
                than 1., the truncation trick will be adopted. Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensor injected to the style conv
                block. Defaults to True.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            # two style codes are sampled when style mixing is triggered
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        # map noise z to latent w unless the input is already a latent
        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                # `None` lets each style conv sample fresh noise
                injected_noise = [None] * self.num_injected_noises
            else:
                # reuse the fixed noise maps registered as buffers
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]
        # use truncation trick
        if truncation < 1:
            style_t = []
            # calculate truncation latent on the fly
            if truncation_latent is None and not hasattr(
                    self, 'truncation_latent'):
                self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            elif truncation_latent is None and hasattr(
                    self, 'truncation_latent'):
                truncation_latent = self.truncation_latent

            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t
        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                # broadcast the single latent to all layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            # first latent feeds layers [0, inject_index), second the rest
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        # 4x4 stage
        out = self.constant_input(latent)
        out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        _index = 1

        # 8x8 ---> higher resolutions
        # each iteration consumes one up-conv, one conv, two noise maps and
        # one to-rgb; `skip` accumulates the RGB output across resolutions
        for up_conv, conv, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], injected_noise[1::2],
                injected_noise[2::2], self.to_rgbs):
            out = up_conv(out, latent[:, _index], noise=noise1)
            out = conv(out, latent[:, _index + 1], noise=noise2)
            skip = to_rgb(out, latent[:, _index + 2], skip)
            _index += 2

        # make sure the output image is torch.float32 to avoid RunTime Error
        # in other modules
        img = skip.to(torch.float32)

        if return_latents or return_noise:
            output_dict = dict(
                fake_img=img,
                latent=latent,
                inject_index=inject_index,
                noise_batch=noise_batch)
            return output_dict

        return img
@MODULES.register_module()
class StyleGAN2Discriminator(nn.Module):
    """StyleGAN2 Discriminator.

    The architecture of this discriminator is proposed in StyleGAN2. More
    details can be found in: Analyzing and Improving the Image Quality of
    StyleGAN CVPR2020.

    You can load pretrained model through passing information into
    ``pretrained`` argument. We have already offered official weights as
    follows:

    - stylegan2-ffhq-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth  # noqa
    - stylegan2-horse-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth  # noqa
    - stylegan2-car-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth  # noqa
    - stylegan2-cat-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth  # noqa
    - stylegan2-church-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth  # noqa

    If you want to load the ema model, you can just use following codes:

    .. code-block:: python

        # ckpt_http is one of the valid path from http source
        discriminator = StyleGAN2Discriminator(1024, 512,
                                               pretrained=dict(
                                                   ckpt_path=ckpt_http,
                                                   prefix='discriminator'))

    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` with local path.

    Note that our implementation adopts BGR image as input, while the
    original StyleGAN2 provides RGB images to the discriminator. Thus, we
    provide ``bgr2rgb`` argument to convert the image space. If your images
    follow the RGB order, please set it to ``True`` accordingly.

    Args:
        in_size (int): The input size of images.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4, channel_groups=1).
        num_fp16_scales (int, optional): The number of resolutions to use auto
            fp16 training. Defaults to 0.
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        out_fp32 (bool, optional): Whether to convert the output feature map to
            `torch.float32`. Defaults to `True`.
        convert_input_fp32 (bool, optional): Whether to convert input type to
            fp32 if not `fp16_enabled`. This argument is designed to deal with
            the cases where some modules are run in FP16 and others in FP32.
            Defaults to True.
        input_bgr2rgb (bool, optional): Whether to reformat the input channels
            with order `rgb`. Since we provide several converted weights,
            whose input order is `rgb`. You can set this argument to True if
            you want to finetune on converted weights. Defaults to False.
        pretrained (dict | None, optional): Information for pretrained models.
            The necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the generator part from the whole state dict.
            Defaults to None.
    """

    def __init__(self,
                 in_size,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 mbstd_cfg=dict(group_size=4, channel_groups=1),
                 num_fp16_scales=0,
                 fp16_enabled=False,
                 out_fp32=True,
                 convert_input_fp32=True,
                 input_bgr2rgb=False,
                 pretrained=None):
        super().__init__()
        # NOTE(review): attribute name drops the trailing 's' of the argument
        # (`num_fp16_scale` vs `num_fp16_scales`) — kept for state-dict /
        # caller compatibility.
        self.num_fp16_scale = num_fp16_scales
        self.fp16_enabled = fp16_enabled
        self.convert_input_fp32 = convert_input_fp32
        self.out_fp32 = out_fp32

        # channel number for each input resolution (key: resolution)
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        log_size = int(np.log2(in_size))

        in_channels = channels[in_size]

        # the stem (from-RGB) layer runs in fp16 whenever partial fp16
        # training is requested
        _use_fp16 = num_fp16_scales > 0
        convs = [
            ConvDownLayer(
                3, channels[in_size], 1, fp16_enabled=_use_fp16)
        ]

        # downsampling residual blocks from `in_size` down to 8x8
        for i in range(log_size, 2, -1):
            out_channel = channels[2**(i - 1)]

            # add fp16 training for higher resolutions
            _use_fp16 = (log_size - i) < num_fp16_scales or fp16_enabled

            convs.append(
                ResBlock(
                    in_channels,
                    out_channel,
                    blur_kernel,
                    fp16_enabled=_use_fp16,
                    convert_input_fp32=convert_input_fp32))

            in_channels = out_channel

        self.convs = nn.Sequential(*convs)

        self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)

        # +1 input channel for the stddev feature appended by mbstd_layer
        self.final_conv = ConvDownLayer(in_channels + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinearActModule(
                channels[4] * 4 * 4,
                channels[4],
                act_cfg=dict(type='fused_bias')),
            EqualLinearActModule(channels[4], 1),
        )

        self.input_bgr2rgb = input_bgr2rgb
        if pretrained is not None:
            self._load_pretrained_model(**pretrained)

    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        """Load a (possibly prefixed) sub-dict of a checkpoint into self.

        Args:
            ckpt_path (str): Path (local or http) of the checkpoint.
            prefix (str, optional): Prefix of the keys to extract from the
                whole state dict, e.g. 'discriminator'. Defaults to ''.
            map_location (str, optional): Device mapping passed to the
                checkpoint loader. Defaults to 'cpu'.
            strict (bool, optional): Whether to strictly enforce matching
                keys in ``load_state_dict``. Defaults to True.
        """
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')

    @auto_fp16()
    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input image tensor.

        Returns:
            torch.Tensor: Predict score for the input image.
        """
        # This setting was used to finetune on converted weights
        if self.input_bgr2rgb:
            # swap the channel order (BGR <-> RGB)
            x = x[:, [2, 1, 0], ...]

        x = self.convs(x)

        x = self.mbstd_layer(x)

        # cast back to fp32 before the (fp32) final conv when mixed
        # precision was used in the backbone
        if not self.final_conv.fp16_enabled and self.convert_input_fp32:
            x = x.to(torch.float32)
        x = self.final_conv(x)
        x = x.view(x.shape[0], -1)
        x = self.final_linear(x)

        return x
@MODULES.register_module()
class ADAStyleGAN2Discriminator(StyleGAN2Discriminator):

    def __init__(self, in_size, *args, data_aug=None, **kwargs):
        """StyleGANv2 Discriminator with adaptive augmentation.

        Args:
            in_size (int): The input size of images.
            data_aug (dict, optional): Config for data
                augmentation. Defaults to None.
        """
        super().__init__(in_size, *args, **kwargs)
        self.log_size = int(np.log2(in_size))
        # Augmentation is active only when a config is supplied.
        self.with_ada = data_aug is not None
        if self.with_ada:
            aug_module = build_module(data_aug)
            # The augmentation pipeline itself is not trained.
            aug_module.requires_grad = False
            self.ada_aug = aug_module

    def forward(self, x):
        """Forward function."""
        inp = self.ada_aug.aug_pipeline(x) if self.with_ada else x
        return super().forward(inp)
@MODULES.register_module()
class ADAAug(nn.Module):
    """Data Augmentation Module for Adaptive Discriminator augmentation.

    Args:
        aug_pipeline (dict, optional): Config for augmentation pipeline.
            Defaults to None.
        update_interval (int, optional): Interval for updating
            augmentation probability. Defaults to 4.
        augment_initial_p (float, optional): Initial augmentation
            probability. Defaults to 0..
        ada_target (float, optional): ADA target. Defaults to 0.6.
        ada_kimg (int, optional): ADA training duration. Defaults to 500.
    """

    def __init__(self,
                 aug_pipeline=None,
                 update_interval=4,
                 augment_initial_p=0.,
                 ada_target=0.6,
                 ada_kimg=500):
        super().__init__()
        # NOTE(review): the default ``aug_pipeline=None`` would make
        # ``AugmentPipe(**None)`` raise; callers apparently always pass a
        # dict — confirm before relying on the default.
        self.aug_pipeline = AugmentPipe(**aug_pipeline)
        self.update_interval = update_interval
        self.ada_kimg = ada_kimg
        self.ada_target = ada_target

        # initialize the augmentation probability in-place on the
        # pipeline's `p` buffer
        self.aug_pipeline.p.copy_(torch.tensor(augment_initial_p))

        # this log buffer stores two numbers: num_scalars, sum_scalars.
        self.register_buffer('log_buffer', torch.zeros((2, )))

    def update(self, iteration=0, num_batches=0):
        """Update Augment probability.

        Args:
            iteration (int, optional): Training iteration.
                Defaults to 0.
            num_batches (int, optional): The number of reals batches.
                Defaults to 0.
        """
        # only adjust p every `update_interval` iterations
        if (iteration + 1) % self.update_interval == 0:

            # step size scaled so that p moves from 0 to 1 over
            # `ada_kimg` thousand real images
            adjust_step = float(num_batches * self.update_interval) / float(
                self.ada_kimg * 1000.)

            # get the mean value as the ada heuristic
            # NOTE(review): if no scalars were logged, log_buffer[0] is 0 and
            # this division yields nan/inf — presumably callers log before
            # calling update; verify.
            ada_heuristic = self.log_buffer[1] / self.log_buffer[0]
            # move p towards the target: up if the heuristic overshoots,
            # down otherwise
            adjust = np.sign(ada_heuristic.item() -
                             self.ada_target) * adjust_step
            # update the augment p
            # Note that p may be bigger than 1.0
            self.aug_pipeline.p.copy_((self.aug_pipeline.p + adjust).max(
                constant(0, device=self.log_buffer.device)))

            # reset the accumulated statistics for the next window
            self.log_buffer = self.log_buffer * 0.
mmgen/models/architectures/stylegan/generator_discriminator_v3.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
mmcv
import
torch
import
torch.nn
as
nn
from
mmcv.runner.checkpoint
import
_load_checkpoint_with_prefix
from
mmgen.models.architectures.common
import
get_module_device
from
mmgen.models.builder
import
MODULES
,
build_module
from
.utils
import
get_mean_latent
@
MODULES
.
register_module
()
class
StyleGANv3Generator
(
nn
.
Module
):
"""StyleGAN3 Generator.
In StyleGAN3, we make several changes to StyleGANv2's generator which
include transformed fourier features, filtered nonlinearities and
non-critical sampling, etc. More details can be found in: Alias-Free
Generative Adversarial Networks NeurIPS'2021.
Ref: https://github.com/NVlabs/stylegan3
Args:
out_size (int): The output size of the StyleGAN3 generator.
style_channels (int): The number of channels for style code.
img_channels (int): The number of output's channels.
noise_size (int, optional): Size of the input noise vector.
Defaults to 512.
rgb2bgr (bool, optional): Whether to reformat the output channels
with order `bgr`. We provide several pre-trained StyleGAN3
weights whose output channels order is `rgb`. You can set
this argument to True to use the weights.
pretrained (str | dict, optional): Path for the pretrained model or
dict containing information for pretained models whose necessary
key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
the generator part from the whole state dict. Defaults to None.
synthesis_cfg (dict, optional): Config for synthesis network. Defaults
to dict(type='SynthesisNetwork').
mapping_cfg (dict, optional): Config for mapping network. Defaults to
dict(type='MappingNetwork').
"""
    def __init__(self,
                 out_size,
                 style_channels,
                 img_channels,
                 noise_size=512,
                 rgb2bgr=False,
                 pretrained=None,
                 synthesis_cfg=dict(type='SynthesisNetwork'),
                 mapping_cfg=dict(type='MappingNetwork')):
        """Initialize the StyleGAN3 generator.

        Builds the synthesis network first (its ``num_ws`` is needed to
        configure the mapping network), then the style mapping network, and
        finally loads pretrained weights if requested. See the class
        docstring for argument semantics.
        """
        super().__init__()
        self.noise_size = noise_size
        self.style_channels = style_channels
        self.out_size = out_size
        self.img_channels = img_channels
        self.rgb2bgr = rgb2bgr

        # deep-copy the user config before filling defaults so the caller's
        # dict (and the mutable argument default) is never mutated
        self._synthesis_cfg = deepcopy(synthesis_cfg)
        self._synthesis_cfg.setdefault('style_channels', style_channels)
        self._synthesis_cfg.setdefault('out_size', out_size)
        self._synthesis_cfg.setdefault('img_channels', img_channels)
        self.synthesis = build_module(self._synthesis_cfg)

        # number of intermediate latents (w codes) the synthesis network
        # consumes; forwarded to the mapping network below
        self.num_ws = self.synthesis.num_ws
        self._mapping_cfg = deepcopy(mapping_cfg)
        self._mapping_cfg.setdefault('noise_size', noise_size)
        self._mapping_cfg.setdefault('style_channels', style_channels)
        self._mapping_cfg.setdefault('num_ws', self.num_ws)
        self.style_mapping = build_module(self._mapping_cfg)

        if pretrained is not None:
            self._load_pretrained_model(**pretrained)
    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        """Load a (possibly prefixed) sub-dict of a checkpoint into self.

        Args:
            ckpt_path (str): Path (local or http) of the checkpoint.
            prefix (str, optional): Prefix of the keys to extract from the
                whole state dict, e.g. 'generator_ema'. Defaults to ''.
            map_location (str, optional): Device mapping passed to the
                checkpoint loader. Defaults to 'cpu'.
            strict (bool, optional): Whether to strictly enforce matching
                keys in ``load_state_dict``. Defaults to True.
        """
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')
def forward(self,
            noise,
            num_batches=0,
            input_is_latent=False,
            truncation=1,
            num_truncation_layer=None,
            update_emas=False,
            force_fp32=True,
            return_noise=False,
            return_latents=False):
    """Forward function for the StyleGAN3 generator.

    Args:
        noise (torch.Tensor | callable | None): A batch of noise, a
            callable that samples one when given a ``(n, c)`` shape, or
            ``None`` to use the default normal sampler.
        num_batches (int, optional): Batch size used when ``noise`` is
            not a tensor. Defaults to 0.
        input_is_latent (bool, optional): If ``True``, ``noise`` already
            lives in W space and the mapping network is skipped.
            Defaults to False.
        truncation (float, optional): Truncation factor; values below 1.
            enable the truncation trick. Defaults to 1.
        num_truncation_layer (int, optional): Number of layers that use
            the truncated latent. Defaults to None.
        update_emas (bool, optional): Whether to update the moving
            average of the mean latent. Defaults to False.
        force_fp32 (bool, optional): Force fp32 regardless of the
            weights. Defaults to True.
        return_noise (bool, optional): If True, return ``noise_batch``
            alongside ``fake_img`` in a dict. Defaults to False.
        return_latents (bool, optional): If True, return ``latent``
            alongside ``fake_img`` in a dict. Defaults to False.

    Returns:
        torch.Tensor | dict: Generated image tensor, or a dictionary
            containing more data.
    """
    # Latent inputs live in W space (style_channels); otherwise the
    # sampler works in Z space (noise_size).
    expected_dim = (
        self.style_channels if input_is_latent else self.noise_size)

    if isinstance(noise, torch.Tensor):
        assert noise.shape[1] == expected_dim
        assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
                                 f'but got {noise.shape}')
        noise_batch = noise
    elif callable(noise):
        # A user-supplied sampler; it needs an explicit batch size.
        assert num_batches > 0
        noise_batch = noise((num_batches, expected_dim))
    else:
        # Default sampler: standard normal noise.
        assert num_batches > 0
        noise_batch = torch.randn((num_batches, expected_dim))

    noise_batch = noise_batch.to(get_module_device(self))

    if input_is_latent:
        # Broadcast the single W code to every synthesis layer.
        ws = noise_batch.unsqueeze(1).repeat([1, self.num_ws, 1])
    else:
        ws = self.style_mapping(
            noise_batch,
            truncation=truncation,
            num_truncation_layer=num_truncation_layer,
            update_emas=update_emas)

    out_img = self.synthesis(
        ws, update_emas=update_emas, force_fp32=force_fp32)

    if self.rgb2bgr:
        out_img = out_img[:, [2, 1, 0], ...]

    if return_noise or return_latents:
        return dict(fake_img=out_img, noise_batch=noise_batch, latent=ws)
    return out_img
def get_mean_latent(self, num_samples=4096, **kwargs):
    """Get mean latent of W space in this generator.

    Prefers the running average maintained by the mapping network; falls
    back to estimating it by sampling.

    Args:
        num_samples (int, optional): Number of sample times. Defaults
            to 4096.

    Returns:
        Tensor: Mean latent of this generator.
    """
    mapping = self.style_mapping
    if hasattr(mapping, 'w_avg'):
        # Reuse the EMA of W tracked during training.
        return mapping.w_avg
    return get_mean_latent(self, num_samples, **kwargs)
def get_training_kwargs(self, phase):
    """Get training kwargs. In StyleGANv3, fp16 is enabled and the
    magnitude ema is updated during the discriminator phase. This function
    is used to pass the related arguments.

    Args:
        phase (str): Current training phase.

    Returns:
        dict: Training kwargs.
    """
    # Built fresh on every call so callers may mutate the result safely.
    phase_kwargs = {
        'disc': dict(update_emas=True, force_fp32=False),
        'gen': dict(force_fp32=False),
    }
    return phase_kwargs.get(phase, {})
mmgen/models/architectures/stylegan/modules/__init__.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
from
.styleganv2_modules
import
(
Blur
,
ConstantInput
,
ModulatedConv2d
,
ModulatedStyleConv
,
ModulatedToRGB
,
NoiseInjection
)
from
.styleganv3_modules
import
(
MappingNetwork
,
SynthesisInput
,
SynthesisLayer
,
SynthesisNetwork
)
__all__
=
[
'Blur'
,
'ModulatedStyleConv'
,
'ModulatedToRGB'
,
'NoiseInjection'
,
'ConstantInput'
,
'MappingNetwork'
,
'SynthesisInput'
,
'SynthesisLayer'
,
'SynthesisNetwork'
,
'ModulatedConv2d'
]
mmgen/models/architectures/stylegan/modules/styleganv1_modules.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
torch.nn
as
nn
from
mmgen.models.architectures.pggan
import
(
EqualizedLRConvModule
,
EqualizedLRConvUpModule
,
EqualizedLRLinearModule
)
from
mmgen.models.architectures.stylegan.modules
import
(
Blur
,
ConstantInput
,
NoiseInjection
)
class AdaptiveInstanceNorm(nn.Module):
    r"""Adaptive Instance Normalization Module.

    Ref: https://github.com/rosinality/style-based-gan-pytorch/blob/master/model.py  # noqa

    Args:
        in_channel (int): The number of input's channel.
        style_dim (int): Style latent dimension.
    """

    def __init__(self, in_channel, style_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)
        # A single affine layer produces per-channel (gamma, beta) pairs.
        self.affine = EqualizedLRLinearModule(style_dim, in_channel * 2)
        # Initialize to identity: gamma = 1, beta = 0.
        self.affine.bias.data[:in_channel] = 1
        self.affine.bias.data[in_channel:] = 0

    def forward(self, input, style):
        """Apply style-conditioned instance normalization.

        Args:
            input (Tensor): Input tensor with shape (n, c, h, w).
            style (Tensor): Input style tensor with shape (n, c).

        Returns:
            Tensor: Forward results.
        """
        factors = self.affine(style)[:, :, None, None]
        gamma, beta = factors.chunk(2, dim=1)
        normalized = self.norm(input)
        return gamma * normalized + beta
class StyleConv(nn.Module):
    """Convolutional style block for StyleGANv1.

    Two conv stages, each followed by noise injection, LeakyReLU
    activation and AdaIN modulation.

    Args:
        in_channels (int): The channel number of the input tensor.
        out_channels (int): The channel number of the output tensor.
        kernel_size (int): The kernel size of convolution layers.
        style_channels (int): The number of channels for style code.
        padding (int, optional): Padding of convolution layers.
            Defaults to 1.
        initial (bool, optional): Whether this is the first StyleConv of
            StyleGAN's generator. Defaults to False.
        blur_kernel (list, optional): The blurry kernel.
            Defaults to [1, 2, 1].
        upsample (bool, optional): Whether perform upsampling.
            Defaults to False.
        fused (bool, optional): Whether use fused upconv.
            Defaults to False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 padding=1,
                 initial=False,
                 blur_kernel=[1, 2, 1],
                 upsample=False,
                 fused=False):
        super().__init__()
        if initial:
            # The very first block replaces the conv with a learned
            # constant input.
            self.conv1 = ConstantInput(in_channels)
        elif upsample and fused:
            # Fused transposed-conv upsampling followed by a blur.
            self.conv1 = nn.Sequential(
                EqualizedLRConvUpModule(
                    in_channels,
                    out_channels,
                    kernel_size,
                    padding=padding,
                    act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
                Blur(blur_kernel, pad=(1, 1)),
            )
        elif upsample:
            # Nearest-neighbor upsampling + conv + blur.
            self.conv1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                EqualizedLRConvModule(
                    in_channels,
                    out_channels,
                    kernel_size,
                    padding=padding,
                    act_cfg=None),
                Blur(blur_kernel, pad=(1, 1)))
        else:
            self.conv1 = EqualizedLRConvModule(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                act_cfg=None)

        self.noise_injector1 = NoiseInjection()
        self.activate1 = nn.LeakyReLU(0.2)
        self.adain1 = AdaptiveInstanceNorm(out_channels, style_channels)

        self.conv2 = EqualizedLRConvModule(
            out_channels,
            out_channels,
            kernel_size,
            padding=padding,
            act_cfg=None)
        self.noise_injector2 = NoiseInjection()
        self.activate2 = nn.LeakyReLU(0.2)
        self.adain2 = AdaptiveInstanceNorm(out_channels, style_channels)

    def forward(self,
                x,
                style1,
                style2,
                noise1=None,
                noise2=None,
                return_noise=False):
        """Forward function.

        Args:
            x (Tensor): Input tensor.
            style1 (Tensor): Input style tensor with shape (n, c).
            style2 (Tensor): Input style tensor with shape (n, c).
            noise1 (Tensor, optional): Noise tensor with shape
                (n, c, h, w). Defaults to None.
            noise2 (Tensor, optional): Noise tensor with shape
                (n, c, h, w). Defaults to None.
            return_noise (bool, optional): If True, ``noise1`` and
                ``noise2`` will be returned with ``out``. Defaults to
                False.

        Returns:
            Tensor | tuple[Tensor]: Forward results.
        """
        # Stage 1: conv -> noise -> activation -> AdaIN.
        out = self.conv1(x)
        if return_noise:
            out, noise1 = self.noise_injector1(
                out, noise=noise1, return_noise=True)
        else:
            out = self.noise_injector1(out, noise=noise1, return_noise=False)
        out = self.adain1(self.activate1(out), style1)

        # Stage 2: identical structure with the second style code.
        out = self.conv2(out)
        if return_noise:
            out, noise2 = self.noise_injector2(
                out, noise=noise2, return_noise=True)
        else:
            out = self.noise_injector2(out, noise=noise2, return_noise=False)
        out = self.adain2(self.activate2(out), style2)

        if return_noise:
            return out, noise1, noise2
        return out
mmgen/models/architectures/stylegan/modules/styleganv2_modules.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
from
functools
import
partial
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn.bricks.activation
import
build_activation_layer
from
mmcv.ops.fused_bias_leakyrelu
import
(
FusedBiasLeakyReLU
,
fused_bias_leakyrelu
)
from
mmcv.ops.upfirdn2d
import
upfirdn2d
from
mmcv.runner.dist_utils
import
get_dist_info
from
mmgen.core.runners.fp16_utils
import
auto_fp16
from
mmgen.models.architectures.pggan
import
(
EqualizedLRConvModule
,
EqualizedLRLinearModule
,
equalized_lr
)
from
mmgen.models.common
import
AllGatherLayer
from
mmgen.ops
import
conv2d
,
conv_transpose2d
class _FusedBiasLeakyReLU(FusedBiasLeakyReLU):
    """Wrap FusedBiasLeakyReLU to support FP16 training."""

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, ...).

        Returns:
            Tensor: Output feature map.
        """
        # Cast the bias to the input dtype so fp16 inputs also work.
        bias = self.bias.to(x.dtype)
        return fused_bias_leakyrelu(x, bias, self.negative_slope, self.scale)
class EqualLinearActModule(nn.Module):
    """Equalized LR Linear Module with Activation Layer.

    Modified from ``EqualizedLRLinearModule`` defined in PGGAN; the major
    addition is support for the activation layers used in StyleGAN2.

    Args:
        equalized_lr_cfg (dict | None, optional): Config for equalized lr.
            Defaults to dict(gain=1., lr_mul=1.).
        bias (bool, optional): Whether to use bias item. Defaults to True.
        bias_init (float, optional): The value for bias initialization.
            Defaults to ``0.``.
        act_cfg (dict | None, optional): Config for activation layer.
            Defaults to None.
    """

    def __init__(self,
                 *args,
                 equalized_lr_cfg=dict(gain=1., lr_mul=1.),
                 bias=True,
                 bias_init=0.,
                 act_cfg=None,
                 **kwargs):
        super().__init__()
        self.with_activation = act_cfg is not None
        # The bias lives outside the linear layer so it can be scaled by
        # ``lr_mul`` and handed to the fused activation op.
        self.linear = EqualizedLRLinearModule(
            *args, bias=False, equalized_lr_cfg=equalized_lr_cfg, **kwargs)
        if equalized_lr_cfg is None:
            self.lr_mul = 1.
        else:
            self.lr_mul = equalized_lr_cfg.get('lr_mul', 1.)

        if bias:
            init_bias = torch.zeros(self.linear.out_features).fill_(bias_init)
            self.bias = nn.Parameter(init_bias)
        else:
            self.bias = None

        if not self.with_activation:
            self.act_type = None
            return
        act_cfg = deepcopy(act_cfg)
        if act_cfg['type'] == 'fused_bias':
            # The fused op applies bias + leaky-relu in one kernel.
            self.act_type = act_cfg.pop('type')
            assert self.bias is not None
            self.activate = partial(fused_bias_leakyrelu, **act_cfg)
        else:
            self.act_type = 'normal'
            self.activate = build_activation_layer(act_cfg)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, ...).

        Returns:
            Tensor: Output feature map.
        """
        # Flatten any trailing dims so the linear layer sees (N, C).
        if x.ndim >= 3:
            x = x.reshape(x.size(0), -1)
        x = self.linear(x)

        scaled_bias = None if self.bias is None else self.bias * self.lr_mul
        if self.with_activation and self.act_type == 'fused_bias':
            return self.activate(x, scaled_bias)
        if scaled_bias is not None and self.with_activation:
            return self.activate(x + scaled_bias)
        if scaled_bias is not None:
            return x + scaled_bias
        if self.with_activation:
            return self.activate(x)
        return x
def
_make_kernel
(
k
):
k
=
torch
.
tensor
(
k
,
dtype
=
torch
.
float32
)
if
k
.
ndim
==
1
:
k
=
k
[
None
,
:]
*
k
[:,
None
]
k
/=
k
.
sum
()
return
k
class UpsampleUpFIRDn(nn.Module):
    """UpFIRDn for Upsampling.

    Used in the ``to_rgb`` layers in StyleGAN2 for upsampling the images.

    Args:
        kernel (Array): Blur kernel/filter used in UpFIRDn.
        factor (int, optional): Upsampling factor. Defaults to 2.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        # Scale the kernel so the upsampled signal keeps its magnitude.
        self.register_buffer('kernel', _make_kernel(kernel) * (factor**2))
        pad_total = self.kernel.shape[0] - factor
        self.pad = ((pad_total + 1) // 2 + factor - 1, pad_total // 2)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        return upfirdn2d(
            x, self.kernel.to(x.dtype), up=self.factor, down=1, pad=self.pad)
class DownsampleUpFIRDn(nn.Module):
    """UpFIRDn for Downsampling.

    This module is mentioned in StyleGAN2 for downsampling the feature
    maps.

    Args:
        kernel (Array): Blur kernel/filter used in UpFIRDn.
        factor (int, optional): Downsampling factor. Defaults to 2.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        self.register_buffer('kernel', _make_kernel(kernel))
        pad_total = self.kernel.shape[0] - factor
        self.pad = ((pad_total + 1) // 2, pad_total // 2)

    def forward(self, input):
        """Forward function.

        Args:
            input (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        return upfirdn2d(
            input,
            self.kernel.to(input.dtype),
            up=1,
            down=self.factor,
            pad=self.pad)
class Blur(nn.Module):
    """Blur module.

    This module is adopted right after the upsampling operation in
    StyleGAN2.

    Args:
        kernel (Array): Blur kernel/filter used in UpFIRDn.
        pad (list[int]): Padding for features.
        upsample_factor (int, optional): Upsampling factor. Defaults to 1.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        fir = _make_kernel(kernel)
        if upsample_factor > 1:
            # Compensate the magnitude loss introduced by upsampling.
            fir = fir * (upsample_factor**2)
        self.register_buffer('kernel', fir)
        self.pad = pad

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        # In Tero's implementation, fp32 is used; here we follow the
        # input dtype instead.
        return upfirdn2d(x, self.kernel.to(x.dtype), pad=self.pad)
class ModulatedConv2d(nn.Module):
    r"""Modulated Conv2d in StyleGANv2.

    This module implements the modulated convolution layers proposed in
    StyleGAN2. Details can be found in Analyzing and Improving the Image
    Quality of StyleGAN, CVPR2020.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Con2d`.
        style_channels (int): Channels for the style codes.
        demodulate (bool, optional): Whether to adopt demodulation.
            Defaults to True.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to False.
        downsample (bool, optional): Whether to adopt downsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
            Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to 0..
        padding (int | None, optional): Self-defined padding for the conv;
            falls back to ``kernel_size // 2`` when None.
        eps (float, optional): Epsilon value to avoid computation error.
            Defaults to 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 demodulate=True,
                 upsample=False,
                 downsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 equalized_lr_cfg=dict(mode='fan_in', lr_mul=1., gain=1.),
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 padding=None,  # self define padding
                 eps=1e-8):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.style_channels = style_channels
        self.demodulate = demodulate
        # sanity check for kernel size: must be a positive odd int so the
        # default padding keeps spatial size.
        assert isinstance(self.kernel_size, int) and (
            self.kernel_size >= 1 and self.kernel_size % 2 == 1)
        self.upsample = upsample
        self.downsample = downsample
        self.style_bias = style_bias
        self.eps = eps

        # build style modulation module: maps style code to per-input-channel
        # scales.
        style_mod_cfg = dict() if style_mod_cfg is None else style_mod_cfg

        self.style_modulation = EqualLinearActModule(style_channels,
                                                     in_channels,
                                                     **style_mod_cfg)
        # set lr_mul for conv weight
        lr_mul_ = 1.
        if equalized_lr_cfg is not None:
            lr_mul_ = equalized_lr_cfg.get('lr_mul', 1.)
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size,
                        kernel_size).div_(lr_mul_))

        # build blurry layer for upsampling
        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
            self.blur = Blur(blur_kernel, (pad0, pad1), upsample_factor=factor)

        # build blurry layer for downsampling
        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        # add equalized_lr hook for conv weight
        if equalized_lr_cfg is not None:
            equalized_lr(self, **equalized_lr_cfg)

        self.padding = padding if padding else (kernel_size // 2)

    def forward(self, x, style, input_gain=None):
        """Forward function.

        Modulates the shared conv weight with the per-sample style code and
        runs the batch as one grouped convolution (groups=n).

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            input_gain (Tensor, optional): Per-input-channel gain applied to
                the modulated weight. Defaults to None.

        Returns:
            Tensor: Output features with shape of (N, C_out, H', W').
        """
        n, c, h, w = x.shape
        weight = self.weight
        # Pre-normalize inputs to avoid FP16 overflow.
        if x.dtype == torch.float16 and self.demodulate:
            weight = weight * (
                1 / np.sqrt(
                    self.in_channels * self.kernel_size * self.kernel_size) /
                weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True)
            )  # max_Ikk
            style = style / style.norm(
                float('inf'), dim=1, keepdim=True)  # max_I

        # process style code
        style = self.style_modulation(style).view(n, 1, c, 1,
                                                  1) + self.style_bias
        # combine weight and style
        weight = weight * style
        if self.demodulate:
            # normalize each output filter to unit norm (StyleGAN2 demod).
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(n, self.out_channels, 1, 1, 1)

        if input_gain is not None:
            # input_gain shape [batch, in_ch]
            input_gain = input_gain.expand(n, self.in_channels)
            # weight shape [batch, out_ch, in_ch, kernel_size, kernel_size]
            weight = weight * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4)

        # fold the batch dim into the filter dim for the grouped conv.
        weight = weight.view(n * self.out_channels, c, self.kernel_size,
                             self.kernel_size)

        weight = weight.to(x.dtype)

        if self.upsample:
            x = x.reshape(1, n * c, h, w)
            weight = weight.view(n, self.out_channels, c, self.kernel_size,
                                 self.kernel_size)
            # conv_transpose2d expects (in, out, k, k) filters per group.
            weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
                                                    self.kernel_size,
                                                    self.kernel_size)
            x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
            x = x.reshape(n, self.out_channels, *x.shape[-2:])
            x = self.blur(x)
        elif self.downsample:
            x = self.blur(x)
            x = x.view(1, n * self.in_channels, *x.shape[-2:])
            x = conv2d(x, weight, stride=2, padding=0, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])
        else:
            x = x.reshape(1, n * c, h, w)
            x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])

        return x
class NoiseInjection(nn.Module):
    """Noise Injection Module.

    In StyleGAN2, this module is adopted to inject a spatial random noise
    map into the generator features.

    Args:
        noise_weight_init (float, optional): Initialization weight for
            noise injection. Defaults to ``0.``.
    """

    def __init__(self, noise_weight_init=0.):
        super().__init__()
        # A single learnable scalar scales the injected noise.
        self.weight = nn.Parameter(torch.zeros(1).fill_(noise_weight_init))

    def forward(self, image, noise=None, return_noise=False):
        """Forward Function.

        Args:
            image (Tensor): Spatial features with a shape of (N, C, H, W).
            noise (Tensor, optional): Noises from the outside.
                Defaults to None.
            return_noise (bool, optional): Whether to return noise tensor.
                Defaults to False.

        Returns:
            Tensor: Output features.
        """
        if noise is None:
            # Sample one noise map shared across channels.
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()
        noise = noise.to(image.dtype)

        out = image + self.weight.to(image.dtype) * noise
        if return_noise:
            return out, noise
        return out
class ConstantInput(nn.Module):
    """Constant Input.

    In StyleGAN2, the original head noise input is substituted with this
    learned constant input module.

    Args:
        channel (int): Channels for the constant input tensor.
        size (int, optional): Spatial size for the constant input.
            Defaults to 4.
    """

    def __init__(self, channel, size=4):
        super().__init__()
        if isinstance(size, int):
            size = [size, size]
        elif mmcv.is_seq_of(size, int):
            assert len(
                size
            ) == 2, f'The length of size should be 2 but got {len(size)}'
        else:
            raise ValueError(f'Got invalid value in size, {size}')

        self.input = nn.Parameter(torch.randn(1, channel, *size))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, ...).

        Returns:
            Tensor: Output feature map.
        """
        # Only the batch size of ``x`` is used; the content is ignored.
        batch_size = x.shape[0]
        return self.input.repeat(batch_size, 1, 1, 1)
class ModulatedPEConv2d(nn.Module):
    r"""Modulated Conv2d in StyleGANv2 with Positional Encoding (PE).

    This module is modified from the ``ModulatedConv2d`` in StyleGAN2 to
    support the experiments in: Positional Encoding as Spatial Inductive Bias
    in GANs, CVPR'2021.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Con2d`.
        style_channels (int): Channels for the style codes.
        demodulate (bool, optional): Whether to adopt demodulation.
            Defaults to True.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to False.
        downsample (bool, optional): Whether to adopt downsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
            Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to 0..
        eps (float, optional): Epsilon value to avoid computation error.
            Defaults to 1e-8.
        no_pad (bool, optional): Whether to removing the padding in
            convolution. Defaults to False.
        deconv2conv (bool, optional): Whether to substitute the transposed conv
            with (conv2d, upsampling). Defaults to False.
        interp_pad (int | None, optional): The padding number of interpolation
            pad. Defaults to None.
        up_config (dict, optional): Upsampling config.
            Defaults to dict(scale_factor=2, mode='nearest').
        up_after_conv (bool, optional): Whether to adopt upsampling after
            convolution. Defaults to False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 demodulate=True,
                 upsample=False,
                 downsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 equalized_lr_cfg=dict(mode='fan_in', lr_mul=1., gain=1.),
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 eps=1e-8,
                 no_pad=False,
                 deconv2conv=False,
                 interp_pad=None,
                 up_config=dict(scale_factor=2, mode='nearest'),
                 up_after_conv=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.style_channels = style_channels
        self.demodulate = demodulate
        # sanity check for kernel size: must be a positive odd int.
        assert isinstance(self.kernel_size, int) and (
            self.kernel_size >= 1 and self.kernel_size % 2 == 1)
        self.upsample = upsample
        self.downsample = downsample
        self.style_bias = style_bias
        self.eps = eps
        self.no_pad = no_pad
        self.deconv2conv = deconv2conv
        self.interp_pad = interp_pad
        self.with_interp_pad = interp_pad is not None
        self.up_config = deepcopy(up_config)
        self.up_after_conv = up_after_conv

        # build style modulation module
        style_mod_cfg = dict() if style_mod_cfg is None else style_mod_cfg

        self.style_modulation = EqualLinearActModule(style_channels,
                                                     in_channels,
                                                     **style_mod_cfg)
        # set lr_mul for conv weight
        lr_mul_ = 1.
        if equalized_lr_cfg is not None:
            lr_mul_ = equalized_lr_cfg.get('lr_mul', 1.)
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size,
                        kernel_size).div_(lr_mul_))

        # build blurry layer for upsampling (only for the transposed-conv
        # path; the deconv2conv path uses interpolation instead).
        if upsample and not self.deconv2conv:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
            self.blur = Blur(blur_kernel, (pad0, pad1), upsample_factor=factor)

        # build blurry layer for downsampling
        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        # add equalized_lr hook for conv weight
        if equalized_lr_cfg is not None:
            equalized_lr(self, **equalized_lr_cfg)

        # if `no_pad`, remove all of the padding in conv
        self.padding = kernel_size // 2 if not no_pad else 0

    def forward(self, x, style):
        """Forward function.

        Args:
            x ([Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).

        Returns:
            Tensor: Output feature with shape of (N, C, H, W).
        """
        n, c, h, w = x.shape
        # process style code
        style = self.style_modulation(style).view(n, 1, c, 1,
                                                  1) + self.style_bias
        # combine weight and style
        weight = self.weight * style
        if self.demodulate:
            # normalize each output filter to unit norm (StyleGAN2 demod).
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(n, self.out_channels, 1, 1, 1)

        # fold the batch dim into the filter dim for the grouped conv.
        weight = weight.view(n * self.out_channels, c, self.kernel_size,
                             self.kernel_size)

        if self.upsample and not self.deconv2conv:
            # transposed-conv upsampling (original StyleGAN2 path).
            x = x.reshape(1, n * c, h, w)
            weight = weight.view(n, self.out_channels, c, self.kernel_size,
                                 self.kernel_size)
            # conv_transpose2d expects (in, out, k, k) filters per group.
            weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
                                                    self.kernel_size,
                                                    self.kernel_size)
            x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
            x = x.reshape(n, self.out_channels, *x.shape[-2:])
            x = self.blur(x)
        elif self.upsample and self.deconv2conv:
            # (conv2d, interpolation) substitute for the transposed conv;
            # ``up_after_conv`` picks whether the conv runs before or after
            # the interpolation.
            if self.up_after_conv:
                x = x.reshape(1, n * c, h, w)
                x = conv2d(x, weight, padding=self.padding, groups=n)
                x = x.view(n, self.out_channels, *x.shape[2:4])

            if self.with_interp_pad:
                h_, w_ = x.shape[-2:]
                up_cfg_ = deepcopy(self.up_config)
                up_scale = up_cfg_.pop('scale_factor')
                # enlarge the target size by ``interp_pad`` extra pixels.
                size_ = (h_ * up_scale + self.interp_pad,
                         w_ * up_scale + self.interp_pad)
                x = F.interpolate(x, size=size_, **up_cfg_)
            else:
                x = F.interpolate(x, **self.up_config)

            if not self.up_after_conv:
                h_, w_ = x.shape[-2:]
                x = x.view(1, n * c, h_, w_)
                x = conv2d(x, weight, padding=self.padding, groups=n)
                x = x.view(n, self.out_channels, *x.shape[2:4])
        elif self.downsample:
            x = self.blur(x)
            x = x.view(1, n * self.in_channels, *x.shape[-2:])
            x = conv2d(x, weight, stride=2, padding=0, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])
        else:
            x = x.view(1, n * c, h, w)
            x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])

        return x
class ModulatedStyleConv(nn.Module):
    """Modulated Style Convolution.

    Integrates the modulated conv2d, the noise injector and the fused
    activation layer into one block.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Con2d`.
        style_channels (int): Channels for the style codes.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        demodulate (bool, optional): Whether to adopt demodulation.
            Defaults to True.
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to ``0.``.
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        conv_clamp (float, optional): Clamp the convolutional layer results
            to avoid gradient overflow. Defaults to `256.0`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 upsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 demodulate=True,
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 fp16_enabled=False,
                 conv_clamp=256):
        super().__init__()
        # add support for fp16
        self.fp16_enabled = fp16_enabled
        self.conv_clamp = float(conv_clamp)
        self.conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            style_channels,
            demodulate=demodulate,
            upsample=upsample,
            blur_kernel=blur_kernel,
            style_mod_cfg=style_mod_cfg,
            style_bias=style_bias)

        self.noise_injector = NoiseInjection()
        self.activate = _FusedBiasLeakyReLU(out_channels)

    @auto_fp16(apply_to=('x', 'noise'))
    def forward(self, x, style, noise=None, return_noise=False):
        """Forward Function.

        Args:
            x ([Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            noise (Tensor, optional): Noise for injection. Defaults to None.
            return_noise (bool, optional): Whether to return noise tensors.
                Defaults to False.

        Returns:
            Tensor: Output features with shape of (N, C, H, W)
        """
        feat = self.conv(x, style)

        if return_noise:
            feat, noise = self.noise_injector(
                feat, noise=noise, return_noise=True)
        else:
            feat = self.noise_injector(feat, noise=noise, return_noise=False)

        # TODO: FP16 in activate layers
        feat = self.activate(feat)
        if self.fp16_enabled:
            # Keep activations bounded so later fp16 layers do not overflow.
            feat = torch.clamp(
                feat, min=-self.conv_clamp, max=self.conv_clamp)

        if return_noise:
            return feat, noise
        return feat
class ModulatedPEStyleConv(nn.Module):
    """Modulated Style Convolution with Positional Encoding.

    This module is modified from the ``ModulatedStyleConv`` in StyleGAN2 to
    support the experiments in: Positional Encoding as Spatial Inductive Bias
    in GANs, CVPR'2021. The inner convolution is ``ModulatedPEConv2d`` and
    extra keyword arguments are forwarded to it.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
        style_channels (int): Channels for the style codes.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        demodulate (bool, optional): Whether to adopt demodulation.
            Defaults to True.
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to 0..
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 upsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 demodulate=True,
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 **kwargs):
        super().__init__()
        self.conv = ModulatedPEConv2d(
            in_channels,
            out_channels,
            kernel_size,
            style_channels,
            demodulate=demodulate,
            upsample=upsample,
            blur_kernel=blur_kernel,
            style_mod_cfg=style_mod_cfg,
            style_bias=style_bias,
            **kwargs)
        self.noise_injector = NoiseInjection()
        self.activate = _FusedBiasLeakyReLU(out_channels)

    def forward(self, x, style, noise=None, return_noise=False):
        """Forward Function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            noise (Tensor, optional): Noise for injection. Defaults to None.
            return_noise (bool, optional): Whether to return noise tensors.
                Defaults to False.

        Returns:
            Tensor: Output features with shape of (N, C, H, W).
        """
        out = self.conv(x, style)

        # Only unpack the noise tensor when the caller asked for it.
        injected = self.noise_injector(
            out, noise=noise, return_noise=return_noise)
        if return_noise:
            out, noise = injected
        else:
            out = injected

        out = self.activate(out)

        return (out, noise) if return_noise else out
class ModulatedToRGB(nn.Module):
    """To RGB layer.

    This module is designed to output image tensor in StyleGAN2.

    Args:
        in_channels (int): Input channels.
        style_channels (int): Channels for the style codes.
        out_channels (int, optional): Output channels. Defaults to 3.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to True.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to 0..
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        conv_clamp (float, optional): Clamp the convolutional layer results to
            avoid gradient overflow. Defaults to `256.0`.
        out_fp32 (bool, optional): Whether to convert the output feature map to
            `torch.float32`. Defaults to `True`.
    """

    def __init__(self,
                 in_channels,
                 style_channels,
                 out_channels=3,
                 upsample=True,
                 blur_kernel=[1, 3, 3, 1],
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 fp16_enabled=False,
                 conv_clamp=256,
                 out_fp32=True):
        super().__init__()

        if upsample:
            self.upsample = UpsampleUpFIRDn(blur_kernel)

        # add support for fp16
        self.fp16_enabled = fp16_enabled
        self.conv_clamp = float(conv_clamp)

        self.conv = ModulatedConv2d(
            in_channels,
            out_channels=out_channels,
            kernel_size=1,
            style_channels=style_channels,
            demodulate=False,
            style_mod_cfg=style_mod_cfg,
            style_bias=style_bias)

        # Fix: the bias must match ``out_channels`` instead of a hard-coded 3;
        # otherwise any ``out_channels != 3`` fails to broadcast in forward.
        # Behavior is unchanged for the default ``out_channels=3``.
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

        # enforece the output to be fp32 (follow Tero's implementation)
        self.out_fp32 = out_fp32

    @auto_fp16(apply_to=('x', 'style'))
    def forward(self, x, style, skip=None):
        """Forward Function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            skip (Tensor, optional): Tensor for skip link. Defaults to None.

        Returns:
            Tensor: Output features with shape of (N, C, H, W).
        """
        out = self.conv(x, style)
        out = out + self.bias.to(x.dtype)

        if self.fp16_enabled:
            out = torch.clamp(
                out, min=-self.conv_clamp, max=self.conv_clamp)

        # Here, Tero adopts FP16 at `skip`.
        if skip is not None:
            skip = self.upsample(skip)
            out = out + skip
        return out
class ConvDownLayer(nn.Sequential):
    """Convolution and Downsampling layer.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
        downsample (bool, optional): Whether to adopt downsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        bias (bool, optional): Whether to use bias parameter. Defaults to True.
        act_cfg (dict, optional): Activation configs.
            Defaults to dict(type='fused_bias').
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        conv_clamp (float, optional): Clamp the convolutional layer results to
            avoid gradient overflow. Defaults to `256.0`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 downsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 bias=True,
                 act_cfg=dict(type='fused_bias'),
                 fp16_enabled=False,
                 conv_clamp=256.):
        self.fp16_enabled = fp16_enabled
        self.conv_clamp = float(conv_clamp)

        layers = []
        if downsample:
            # Blur before the stride-2 convolution; pad so the spatial size
            # is exactly halved.
            factor = 2
            pad_total = (len(blur_kernel) - factor) + (kernel_size - 1)
            layers.append(
                Blur(
                    blur_kernel,
                    pad=((pad_total + 1) // 2, pad_total // 2)))
            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2

        # With a fused-bias activation, the conv itself runs without bias and
        # activation; the fused layer supplies both afterwards.
        self.with_fused_bias = (
            act_cfg is not None and act_cfg.get('type') == 'fused_bias')
        conv_act_cfg = None if self.with_fused_bias else act_cfg

        layers.append(
            EqualizedLRConvModule(
                in_channels,
                out_channels,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not self.with_fused_bias,
                norm_cfg=None,
                act_cfg=conv_act_cfg,
                equalized_lr_cfg=dict(mode='fan_in', gain=1.)))
        if self.with_fused_bias:
            layers.append(_FusedBiasLeakyReLU(out_channels))

        super().__init__(*layers)

    @auto_fp16(apply_to=('x', ))
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        x = super().forward(x)
        if self.fp16_enabled:
            x = torch.clamp(x, min=-self.conv_clamp, max=self.conv_clamp)
        return x
class ResBlock(nn.Module):
    """Residual block used in the discriminator of StyleGAN2.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        convert_input_fp32 (bool, optional): Whether to convert input type to
            fp32 if not `fp16_enabled`. This argument is designed to deal with
            the cases where some modules are run in FP16 and others in FP32.
            Defaults to True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 blur_kernel=[1, 3, 3, 1],
                 fp16_enabled=False,
                 convert_input_fp32=True):
        super().__init__()
        self.fp16_enabled = fp16_enabled
        self.convert_input_fp32 = convert_input_fp32

        # Main path: 3x3 conv at the same resolution, then 3x3 conv with
        # downsampling.
        self.conv1 = ConvDownLayer(
            in_channels, in_channels, 3, blur_kernel=blur_kernel)
        self.conv2 = ConvDownLayer(
            in_channels,
            out_channels,
            3,
            downsample=True,
            blur_kernel=blur_kernel)

        # Skip path: bias-free 1x1 conv with downsampling, no activation.
        self.skip = ConvDownLayer(
            in_channels,
            out_channels,
            1,
            downsample=True,
            act_cfg=None,
            bias=False,
            blur_kernel=blur_kernel)

    @auto_fp16()
    def forward(self, input):
        """Forward function.

        Args:
            input (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        # TODO: study whether this explicit datatype transfer will harm the
        # apex training speed
        if not self.fp16_enabled and self.convert_input_fp32:
            input = input.to(torch.float32)

        residual = self.conv2(self.conv1(input))
        shortcut = self.skip(input)

        # Scale by 1/sqrt(2) to keep the variance of the sum stable.
        return (residual + shortcut) / np.sqrt(2)
class ModMBStddevLayer(nn.Module):
    """Modified MiniBatch Stddev Layer.

    This layer is modified from ``MiniBatchStddevLayer`` used in PGGAN. In
    StyleGAN2, the authors add a new feature, `channel_groups`, into this
    layer.

    Note that to accelerate the training procedure, we also add a new feature
    of ``sync_std`` to achieve multi-nodes/machine training. This feature is
    still in beta version and we have tested it on 256 scales.

    Args:
        group_size (int, optional): The size of groups in batch dimension.
            Defaults to 4.
        channel_groups (int, optional): The size of groups in channel
            dimension. Defaults to 1.
        sync_std (bool, optional): Whether to use synchronized std feature.
            Defaults to False.
        sync_groups (int | None, optional): The size of groups in node
            dimension. Defaults to None.
        eps (float, optional): Epsilon value to avoid computation error.
            Defaults to 1e-8.
    """

    def __init__(self,
                 group_size=4,
                 channel_groups=1,
                 sync_std=False,
                 sync_groups=None,
                 eps=1e-8):
        super().__init__()
        self.group_size = group_size
        self.eps = eps
        self.channel_groups = channel_groups
        self.sync_std = sync_std
        # Fall back to ``group_size`` when no node-level group size is given.
        self.sync_groups = group_size if sync_groups is None else sync_groups

        if self.sync_std:
            assert torch.distributed.is_initialized(
            ), 'Only in distributed training can the sync_std be activated.'
            mmcv.print_log('Adopt synced minibatch stddev layer', 'mmgen')

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map with shape of (N, C+1, H, W).
        """
        if self.sync_std:
            # concatenate all features
            all_features = torch.cat(AllGatherLayer.apply(x), dim=0)
            # get the exact features we need in calculating std-dev
            rank, ws = get_dist_info()
            local_bs = all_features.shape[0] // ws
            start_idx = local_bs * rank
            # avoid the case where start idx near the tail of features
            if start_idx + self.sync_groups > all_features.shape[0]:
                start_idx = all_features.shape[0] - self.sync_groups
            end_idx = min(local_bs * rank + self.sync_groups,
                          all_features.shape[0])
            x = all_features[start_idx:end_idx]

        # batch size should be smaller than or equal to group size. Otherwise,
        # batch size should be divisible by the group size.
        assert x.shape[0] <= self.group_size or x.shape[
            0] % self.group_size == 0, (
                'Batch size be smaller than or equal '
                'to group size. Otherwise,'
                ' batch size should be divisible by the group size.'
                f'But got batch size {x.shape[0]},'
                f' group size {self.group_size}')
        assert x.shape[1] % self.channel_groups == 0, (
            '"channel_groups" must be divided by the feature channels. '
            f'channel_groups: {self.channel_groups}, '
            f'feature channels: {x.shape[1]}')

        n, c, h, w = x.shape
        group_size = min(n, self.group_size)

        # [G, M, Gc, C', H, W]: split the batch into groups of ``group_size``
        # and the channels into ``channel_groups`` groups.
        grouped = torch.reshape(
            x, (group_size, -1, self.channel_groups,
                c // self.channel_groups, h, w))

        # Biased variance over the group dimension, then stddev.
        stddev = torch.sqrt(
            torch.var(grouped, dim=0, unbiased=False) + self.eps)

        # [M, 1, 1, 1]: average over channel-chunk and spatial dims.
        stddev = stddev.mean(dim=(2, 3, 4), keepdim=True).squeeze(2)

        # Broadcast the per-group statistic back to every sample/pixel and
        # append it as one extra feature channel.
        stddev = stddev.repeat(group_size, 1, h, w)
        return torch.cat([x, stddev], dim=1)
Prev
1
2
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment