Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
stylegan2_mmcv
Commits
1401de15
Commit
1401de15
authored
Jun 28, 2024
by
dongchy920
Browse files
stylegan2_mmcv
parents
Pipeline
#1274
canceled with stages
Changes
463
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4438 additions
and
0 deletions
+4438
-0
build/lib/mmgen/models/architectures/lpips/networks_basic.py
build/lib/mmgen/models/architectures/lpips/networks_basic.py
+213
-0
build/lib/mmgen/models/architectures/lpips/perceptual_loss.py
...d/lib/mmgen/models/architectures/lpips/perceptual_loss.py
+62
-0
build/lib/mmgen/models/architectures/lpips/pretrained_networks.py
...b/mmgen/models/architectures/lpips/pretrained_networks.py
+54
-0
build/lib/mmgen/models/architectures/lsgan/__init__.py
build/lib/mmgen/models/architectures/lsgan/__init__.py
+4
-0
build/lib/mmgen/models/architectures/lsgan/generator_discriminator.py
...gen/models/architectures/lsgan/generator_discriminator.py
+301
-0
build/lib/mmgen/models/architectures/pggan/__init__.py
build/lib/mmgen/models/architectures/pggan/__init__.py
+13
-0
build/lib/mmgen/models/architectures/pggan/generator_discriminator.py
...gen/models/architectures/pggan/generator_discriminator.py
+456
-0
build/lib/mmgen/models/architectures/pggan/modules.py
build/lib/mmgen/models/architectures/pggan/modules.py
+567
-0
build/lib/mmgen/models/architectures/pix2pix/__init__.py
build/lib/mmgen/models/architectures/pix2pix/__init__.py
+8
-0
build/lib/mmgen/models/architectures/pix2pix/generator_discriminator.py
...n/models/architectures/pix2pix/generator_discriminator.py
+252
-0
build/lib/mmgen/models/architectures/pix2pix/modules.py
build/lib/mmgen/models/architectures/pix2pix/modules.py
+172
-0
build/lib/mmgen/models/architectures/positional_encoding.py
build/lib/mmgen/models/architectures/positional_encoding.py
+211
-0
build/lib/mmgen/models/architectures/singan/__init__.py
build/lib/mmgen/models/architectures/singan/__init__.py
+9
-0
build/lib/mmgen/models/architectures/singan/generator_discriminator.py
...en/models/architectures/singan/generator_discriminator.py
+262
-0
build/lib/mmgen/models/architectures/singan/modules.py
build/lib/mmgen/models/architectures/singan/modules.py
+230
-0
build/lib/mmgen/models/architectures/singan/positional_encoding.py
.../mmgen/models/architectures/singan/positional_encoding.py
+237
-0
build/lib/mmgen/models/architectures/sngan_proj/__init__.py
build/lib/mmgen/models/architectures/sngan_proj/__init__.py
+8
-0
build/lib/mmgen/models/architectures/sngan_proj/generator_discriminator.py
...odels/architectures/sngan_proj/generator_discriminator.py
+756
-0
build/lib/mmgen/models/architectures/sngan_proj/modules.py
build/lib/mmgen/models/architectures/sngan_proj/modules.py
+610
-0
build/lib/mmgen/models/architectures/stylegan/__init__.py
build/lib/mmgen/models/architectures/stylegan/__init__.py
+13
-0
No files found.
Too many changes to show.
To preserve performance only
463 of 463+
files are displayed.
Plain diff
Email patch
build/lib/mmgen/models/architectures/lpips/networks_basic.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.nn
as
nn
from
.pretrained_networks
import
vgg16
def normalize_tensor(in_feat, eps=1e-10):
    """L2 normalization.

    Args:
        in_feat (Tensor): Tensor with shape [N, C, H, W].
        eps (float, optional): Epsilon value to avoid computation error.
            Defaults to 1e-10.

    Returns:
        Tensor: Tensor after L2 normalization per-instance.
    """
    # Channel-wise L2 norm, kept as [N, 1, H, W] so it broadcasts below.
    channel_norm = in_feat.pow(2).sum(dim=1, keepdim=True).sqrt()
    return in_feat / (channel_norm + eps)
def spatial_average(in_tens, keepdim=True):
    """Returns the mean value of each row of the input tensor in the spatial
    dimension.

    Args:
        in_tens (Tensor): Tensor with shape [N, C, H, W].
        keepdim (bool, optional): If keepdim is True, the output tensor is of
            the shape [N, C, 1, 1]. Otherwise, the output will have shape
            [N, C]. Defaults to True.

    Returns:
        Tensor: Tensor after average pooling to 1x1 with shape [N, C, 1, 1] or
            [N, C].
    """
    # Average over both spatial axes (H and W) in one call.
    return in_tens.mean(dim=[2, 3], keepdim=keepdim)
def upsample(in_tens, out_H=64):  # assumes scale factor is same for H and W
    """Upsamples the input to the given size.

    Args:
        in_tens (Tensor): Tensor with shape [N, C, H, W].
        out_H (int, optional): Output spatial size. Defaults to 64.

    Returns:
        Tensor: Output Tensor.
    """
    src_h = in_tens.shape[2]
    # Float ratio between the requested and the current height.
    ratio = 1. * out_H / src_h
    resizer = nn.Upsample(
        scale_factor=ratio, mode='bilinear', align_corners=False)
    return resizer(in_tens)
# Learned perceptual metric
class PNetLin(nn.Module):
    r"""LPIPS distance computed on VGG16 features with learned linear weights.

    Ref: https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py # noqa
    """

    def __init__(self,
                 pnet_rand=False,
                 pnet_tune=False,
                 use_dropout=True,
                 spatial=False,
                 version='0.1',
                 lpips=True):
        super().__init__()
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()
        # Channel widths of the five VGG16 feature stages.
        self.channels = [64, 128, 256, 512, 512]
        self.L = len(self.channels)
        self.net = vgg16(
            pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
        # One learned 1x1-conv weighting per feature stage.
        self.lin0 = NetLinLayer(self.channels[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.channels[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.channels[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.channels[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.channels[4], use_dropout=use_dropout)
        self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]

    def forward(self, in0, in1, retPerLayer=False):
        # v0.0 - original release had a bug, where input was not scaled
        if self.version == '0.1':
            in0_input, in1_input = (self.scaling_layer(in0),
                                    self.scaling_layer(in1))
        else:
            in0_input, in1_input = in0, in1
        outs0 = self.net.forward(in0_input)
        outs1 = self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        # Squared difference of channel-normalized features per stage.
        for kk in range(self.L):
            feats0[kk] = normalize_tensor(outs0[kk])
            feats1[kk] = normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk])**2
        if self.lpips:
            # Apply the learned per-channel weights before pooling.
            if self.spatial:
                res = [
                    upsample(self.lins[kk].model(diffs[kk]),
                             out_H=in0.shape[2]) for kk in range(self.L)
                ]
            else:
                res = [
                    spatial_average(
                        self.lins[kk].model(diffs[kk]), keepdim=True)
                    for kk in range(self.L)
                ]
        else:
            # Unweighted variant: plain channel sum of the differences.
            if self.spatial:
                res = [
                    upsample(
                        diffs[kk].sum(dim=1, keepdim=True),
                        out_H=in0.shape[2]) for kk in range(self.L)
                ]
            else:
                res = [
                    spatial_average(
                        diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
                    for kk in range(self.L)
                ]
        val = sum(res)
        if retPerLayer:
            return (val, res)
        return val
class ScalingLayer(nn.Module):
    """Shift-and-scale input normalization used before the VGG backbone.

    The shift/scale constants come from the reference LPIPS implementation.
    """

    def __init__(self):
        super().__init__()
        # Registered as buffers so they move with the module but are not
        # trained.
        self.register_buffer(
            'shift',
            torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1))
        self.register_buffer(
            'scale',
            torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1))

    def forward(self, inp):
        """Normalize ``inp`` channel-wise: ``(inp - shift) / scale``."""
        return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv."""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        modules = []
        if use_dropout:
            # Optional dropout precedes the 1x1 convolution.
            modules.append(nn.Dropout())
        modules.append(
            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*modules)
class Dist2LogitLayer(nn.Module):
    """takes 2 distances, puts through fc layers, spits out value between [0,
    1] (if use_sigmoid is True)"""

    def __init__(self, chn_mid=32, use_sigmoid=True):
        super().__init__()
        # Three 1x1 convolutions with LeakyReLU in between; the 5 input
        # channels match the feature vector built in ``forward``.
        blocks = [
            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
        ]
        if use_sigmoid:
            blocks.append(nn.Sigmoid())
        self.model = nn.Sequential(*blocks)

    def forward(self, d0, d1, eps=0.1):
        """Map two distance maps to a logit/probability map.

        The input is the concatenation of both distances, their difference
        and both eps-stabilized ratios (5 channels total).
        """
        feats = torch.cat(
            (d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)
        return self.model.forward(feats)
class BCERankingLoss(nn.Module):
    """BCE loss on the logit produced from a pair of distances.

    ``judge`` in ``forward`` is expected in [-1, 1] and is mapped to [0, 1]
    before being used as the BCE target.
    """

    def __init__(self, chn_mid=32):
        super().__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        # self.parameters = list(self.net.parameters())
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        # Map the human judgement from [-1, 1] to a [0, 1] probability.
        per = (judge + 1.) / 2.
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, per)
build/lib/mmgen/models/architectures/lpips/perceptual_loss.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
from
torch.utils.model_zoo
import
load_url
from
.networks_basic
import
PNetLin
LPIPS_WEIGHTS_URL
=
'https://download.openmmlab.com/mmgen/evaluation/lpips/weights/v0.1/vgg.pth'
# noqa
class PerceptualLoss(torch.nn.Module):
    r"""LPIPS metric with VGG using our perceptually-learned weights.

    Ref: https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/__init__.py # noqa

    Args:
        spatial (bool, optional): Whether to keep the spatial map instead of
            averaging it. Defaults to False.
        use_gpu (bool, optional): Whether to move the network to GPU and wrap
            it with ``DataParallel``. Defaults to True.
        gpu_ids (list[int] | None, optional): Device ids used when
            ``use_gpu`` is True. ``None`` means ``[0]``. Defaults to None.
        pretrained (bool, optional): Whether to load the released LPIPS
            linear weights. Defaults to True.
    """

    def __init__(self, spatial=False, use_gpu=True, gpu_ids=None,
                 pretrained=True):
        super().__init__()
        print('Setting up Perceptual loss...')
        self.use_gpu = use_gpu
        self.spatial = spatial
        # Fix: the original used a mutable default argument (``gpu_ids=[0]``);
        # use a ``None`` sentinel with the same effective default.
        self.gpu_ids = [0] if gpu_ids is None else gpu_ids
        print('...[pnet-lin, vgg16] initializing')
        self.init_net(pretrained=pretrained)
        print('...Done')

    def forward(self, pred, target, normalize=False):
        """Compute the LPIPS distance between ``pred`` and ``target``.

        Args:
            pred (Tensor): Predicted image batch.
            target (Tensor): Reference image batch.
            normalize (bool, optional): If True, inputs are assumed to be in
                [0, 1] and are rescaled to [-1, 1] first. Defaults to False.

        Returns:
            Tensor: LPIPS distance (scalar per sample, or a spatial map when
            ``self.spatial`` is True).
        """
        if normalize:
            # Rescale from [0, 1] to the [-1, 1] range the network expects.
            target = 2 * target - 1
            pred = 2 * pred - 1
        return self.net(target, pred)

    def init_net(self,
                 pnet_rand=False,
                 pnet_tune=False,
                 pretrained=True,
                 version='0.1'):
        """Build the PNetLin backbone and optionally load LPIPS weights."""
        self.net = PNetLin(
            pnet_rand=pnet_rand,
            pnet_tune=pnet_tune,
            use_dropout=True,
            spatial=self.spatial,
            version=version,
            lpips=True)
        if pretrained:
            print('Loading model from: %s' % LPIPS_WEIGHTS_URL)
            self.net.load_state_dict(
                load_url(
                    LPIPS_WEIGHTS_URL, map_location='cpu', progress=True),
                strict=False)
        self.parameters = list(self.net.parameters())
        # Evaluation-only metric: freeze in eval mode.
        self.net.eval()
        if self.use_gpu:
            self.net.to(self.gpu_ids[0])
            self.net = torch.nn.DataParallel(
                self.net, device_ids=self.gpu_ids)
build/lib/mmgen/models/architectures/lpips/pretrained_networks.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
collections
import
namedtuple
import
torch
from
torchvision
import
models
as
tv
class vgg16(torch.nn.Module):
    r"""VGG16 feature extractor for LPIPS metric.

    Ref : https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py # noqa
    """

    def __init__(self, requires_grad=False, pretrained=True):
        super().__init__()
        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
        # Split the VGG feature stack into five stages, cut after each of the
        # relu1_2 / relu2_2 / relu3_3 / relu4_3 / relu5_3 activations.
        stage_bounds = [(0, 4), (4, 9), (9, 16), (16, 23), (23, 30)]
        for idx, (start, stop) in enumerate(stage_bounds, start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                stage.add_module(
                    str(layer_idx), vgg_pretrained_features[layer_idx])
            setattr(self, f'slice{idx}', stage)
        self.N_slices = 5
        if not requires_grad:
            # Feature extractor is frozen by default.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the five intermediate relu activations as a namedtuple."""
        stages = (self.slice1, self.slice2, self.slice3, self.slice4,
                  self.slice5)
        feats = []
        h = X
        for stage in stages:
            h = stage(h)
            feats.append(h)
        vgg_outputs = namedtuple(
            'VggOutputs',
            ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        return vgg_outputs(*feats)
build/lib/mmgen/models/architectures/lsgan/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.generator_discriminator
import
LSGANDiscriminator
,
LSGANGenerator
__all__
=
[
'LSGANDiscriminator'
,
'LSGANGenerator'
]
build/lib/mmgen/models/architectures/lsgan/generator_discriminator.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
from
mmcv.cnn.bricks
import
build_activation_layer
from
mmgen.models.builder
import
MODULES
from
..common
import
get_module_device
@MODULES.register_module()
class LSGANGenerator(nn.Module):
    """Generator for LSGAN.

    Implementation Details for LSGAN architecture:

    #. Adopt transposed convolution in the generator;
    #. Use batchnorm in the generator except for the final output layer;
    #. Use ReLU in the generator in addition to the final output layer;
    #. Keep channels of feature maps unchanged in the convolution backbone;
    #. Use one more 3x3 conv every upsampling in the convolution backbone.

    We follow the implementation details of the origin paper:
    Least Squares Generative Adversarial Networks
    https://arxiv.org/pdf/1611.04076.pdf

    Args:
        output_scale (int, optional): Output scale for the generated image.
            Defaults to 128.
        out_channels (int, optional): The channel number of the output feature.
            Defaults to 3.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contains channels based on this number.
            Defaults to 256.
        input_scale (int, optional): The scale of the input 2D feature map.
            Defaults to 8.
        noise_size (int, optional): Size of the input noise
            vector. Defaults to 1024.
        conv_cfg (dict, optional): Config for the convolution module used in
            this generator. Defaults to dict(type='ConvTranspose2d').
        default_norm_cfg (dict, optional): Norm config for all of layers
            except for the final output layer. Defaults to dict(type='BN').
        default_act_cfg (dict, optional): Activation config for all of layers
            except for the final output layer. Defaults to dict(type='ReLU').
        out_act_cfg (dict, optional): Activation config for the final output
            layer. Defaults to dict(type='Tanh').
    """

    def __init__(self,
                 output_scale=128,
                 out_channels=3,
                 base_channels=256,
                 input_scale=8,
                 noise_size=1024,
                 conv_cfg=dict(type='ConvTranspose2d'),
                 default_norm_cfg=dict(type='BN'),
                 default_act_cfg=dict(type='ReLU'),
                 out_act_cfg=dict(type='Tanh')):
        super().__init__()
        assert output_scale % input_scale == 0
        assert output_scale // input_scale >= 4
        self.output_scale = output_scale
        self.base_channels = base_channels
        self.input_scale = input_scale
        self.noise_size = noise_size
        # Project the noise vector to an `input_scale x input_scale` feature.
        self.noise2feat_head = nn.Sequential(
            nn.Linear(noise_size, input_scale * input_scale * base_channels))
        self.noise2feat_tail = nn.Sequential(nn.BatchNorm2d(base_channels))
        if default_act_cfg is not None:
            self.noise2feat_tail.add_module(
                'act', build_activation_layer(default_act_cfg))
        # the number of times for upsampling
        self.num_upsamples = int(np.log2(output_scale // input_scale)) - 2
        # build up convolution backbone (excluding the output layer)
        self.conv_blocks = nn.ModuleList()
        for _ in range(self.num_upsamples):
            self.conv_blocks.append(
                ConvModule(
                    base_channels,
                    base_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    conv_cfg=dict(conv_cfg, output_padding=1),
                    norm_cfg=default_norm_cfg,
                    act_cfg=default_act_cfg))
            self.conv_blocks.append(
                ConvModule(
                    base_channels,
                    base_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=default_norm_cfg,
                    act_cfg=default_act_cfg))
        # output blocks
        self.conv_blocks.append(
            ConvModule(
                base_channels,
                int(base_channels // 2),
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=dict(conv_cfg, output_padding=1),
                norm_cfg=default_norm_cfg,
                act_cfg=default_act_cfg))
        self.conv_blocks.append(
            ConvModule(
                int(base_channels // 2),
                int(base_channels // 4),
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=dict(conv_cfg, output_padding=1),
                norm_cfg=default_norm_cfg,
                act_cfg=default_act_cfg))
        self.conv_blocks.append(
            ConvModule(
                int(base_channels // 4),
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=None,
                act_cfg=out_act_cfg))

    def forward(self, noise, num_batches=0, return_noise=False):
        """Forward function.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.

        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output image
                will be returned. Otherwise, a dict contains ``fake_img`` and
                ``noise_batch`` will be returned.
        """
        # receive noise and conduct sanity check.
        if isinstance(noise, torch.Tensor):
            assert noise.shape[1] == self.noise_size
            if noise.ndim == 2:
                noise_batch = noise
            else:
                # Fix: original message concatenated '(n, c)' and 'but got'
                # with no separator; now matches the PGGAN generator message.
                raise ValueError('The noise should be in shape of (n, c), '
                                 f'but got {noise.shape}')
        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, self.noise_size))
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            noise_batch = torch.randn((num_batches, self.noise_size))
        # dirty code for putting data on the right device
        noise_batch = noise_batch.to(get_module_device(self))
        # noise2feat
        x = self.noise2feat_head(noise_batch)
        x = x.reshape(
            (-1, self.base_channels, self.input_scale, self.input_scale))
        x = self.noise2feat_tail(x)
        # conv module
        for conv in self.conv_blocks:
            x = conv(x)
        if return_noise:
            return dict(fake_img=x, noise_batch=noise_batch)
        return x
@MODULES.register_module()
class LSGANDiscriminator(nn.Module):
    """Discriminator for LSGAN.

    Implementation Details for LSGAN architecture:

    #. Adopt convolution in the discriminator;
    #. Use batchnorm in the discriminator except for the input and final \
       output layer;
    #. Use LeakyReLU in the discriminator in addition to the output layer;
    #. Use fully connected layer in the output layer;
    #. Use 5x5 conv rather than 4x4 conv in DCGAN.

    Args:
        input_scale (int, optional): The scale of the input image. Defaults to
            128.
        output_scale (int, optional): The final scale of the convolutional
            feature. Defaults to 8.
        out_channels (int, optional): The channel number of the final output
            layer. Defaults to 1.
        in_channels (int, optional): The channel number of the input image.
            Defaults to 3.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contains channels based on this number.
            Defaults to 64.
        conv_cfg (dict, optional): Config for the convolution module used in
            this discriminator. Defaults to dict(type='Conv2d').
        default_norm_cfg (dict, optional): Norm config for all of layers
            except for the final output layer. Defaults to ``dict(type='BN')``.
        default_act_cfg (dict, optional): Activation config for all of layers
            except for the final output layer. Defaults to
            ``dict(type='LeakyReLU', negative_slope=0.2)``.
        out_act_cfg (dict, optional): Activation config for the final output
            layer. Defaults to None.
    """

    def __init__(self,
                 input_scale=128,
                 output_scale=8,
                 out_channels=1,
                 in_channels=3,
                 base_channels=64,
                 conv_cfg=dict(type='Conv2d'),
                 default_norm_cfg=dict(type='BN'),
                 default_act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
                 out_act_cfg=None):
        super().__init__()
        assert input_scale % output_scale == 0
        assert input_scale // output_scale >= 2
        self.input_scale = input_scale
        self.output_scale = output_scale
        self.out_channels = out_channels
        self.base_channels = base_channels
        self.with_out_activation = out_act_cfg is not None
        self.conv_blocks = nn.ModuleList()
        # Input layer: no normalization per the LSGAN design.
        self.conv_blocks.append(
            ConvModule(
                in_channels,
                base_channels,
                kernel_size=5,
                stride=2,
                padding=2,
                conv_cfg=conv_cfg,
                norm_cfg=None,
                act_cfg=default_act_cfg))
        # the number of times for downsampling
        self.num_downsamples = int(np.log2(input_scale // output_scale)) - 1
        # build up downsampling backbone (excluding the output layer)
        curr_channels = base_channels
        for _ in range(self.num_downsamples):
            self.conv_blocks.append(
                ConvModule(
                    curr_channels,
                    curr_channels * 2,
                    kernel_size=5,
                    stride=2,
                    padding=2,
                    conv_cfg=conv_cfg,
                    norm_cfg=default_norm_cfg,
                    act_cfg=default_act_cfg))
            curr_channels = curr_channels * 2
        # output layer
        self.decision = nn.Sequential(
            nn.Linear(output_scale * output_scale * curr_channels,
                      out_channels))
        if self.with_out_activation:
            self.out_activation = build_activation_layer(out_act_cfg)

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Fake or real image tensor.

        Returns:
            torch.Tensor: Prediction for the reality of the input image.
        """
        n = x.shape[0]
        for conv in self.conv_blocks:
            x = conv(x)
        x = x.reshape(n, -1)
        x = self.decision(x)
        if self.with_out_activation:
            x = self.out_activation(x)
        return x
build/lib/mmgen/models/architectures/pggan/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.generator_discriminator
import
PGGANDiscriminator
,
PGGANGenerator
from
.modules
import
(
EqualizedLR
,
EqualizedLRConvDownModule
,
EqualizedLRConvModule
,
EqualizedLRConvUpModule
,
EqualizedLRLinearModule
,
MiniBatchStddevLayer
,
PGGANNoiseTo2DFeat
,
PixelNorm
,
equalized_lr
)
__all__
=
[
'EqualizedLR'
,
'equalized_lr'
,
'EqualizedLRConvModule'
,
'EqualizedLRLinearModule'
,
'EqualizedLRConvUpModule'
,
'EqualizedLRConvDownModule'
,
'PixelNorm'
,
'MiniBatchStddevLayer'
,
'PGGANNoiseTo2DFeat'
,
'PGGANGenerator'
,
'PGGANDiscriminator'
]
build/lib/mmgen/models/architectures/pggan/generator_discriminator.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
from
functools
import
partial
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn.bricks.upsample
import
build_upsample_layer
from
mmgen.models.builder
import
MODULES
from
..common
import
get_module_device
from
.modules
import
(
EqualizedLRConvDownModule
,
EqualizedLRConvModule
,
EqualizedLRConvUpModule
,
MiniBatchStddevLayer
,
PGGANDecisionHead
,
PGGANNoiseTo2DFeat
)
@MODULES.register_module()
class PGGANGenerator(nn.Module):
    """Generator for PGGAN.

    Args:
        noise_size (int): Size of the input noise vector.
        out_scale (int): Output scale for the generated image.
        label_size (int, optional): Size of the label vector.
            Defaults to 0.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contains channels based on this
            number. Defaults to 8192.
        channel_decay (float, optional): Decay for channels of feature maps.
            Defaults to 1.0.
        max_channels (int, optional): Maximum channels for the feature
            maps in the generator block. Defaults to 512.
        fused_upconv (bool, optional): Whether use fused upconv.
            Defaults to True.
        conv_module_cfg (dict, optional): Config for the convolution
            module used in this generator. Defaults to None.
        fused_upconv_cfg (dict, optional): Config for the fused upconv
            module used in this generator. Defaults to None.
        upsample_cfg (dict, optional): Config for the upsampling operation.
            Defaults to None.
    """

    _default_fused_upconv_cfg = dict(
        conv_cfg=dict(type='deconv'),
        kernel_size=3,
        stride=2,
        padding=1,
        bias=True,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
        norm_cfg=dict(type='PixelNorm'),
        order=('conv', 'act', 'norm'))

    _default_conv_module_cfg = dict(
        conv_cfg=None,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
        norm_cfg=dict(type='PixelNorm'),
        order=('conv', 'act', 'norm'))

    _default_upsample_cfg = dict(type='nearest', scale_factor=2)

    def __init__(self,
                 noise_size,
                 out_scale,
                 label_size=0,
                 base_channels=8192,
                 channel_decay=1.,
                 max_channels=512,
                 fused_upconv=True,
                 conv_module_cfg=None,
                 fused_upconv_cfg=None,
                 upsample_cfg=None):
        super().__init__()
        self.noise_size = noise_size if noise_size else min(
            base_channels, max_channels)
        self.out_scale = out_scale
        self.out_log2_scale = int(np.log2(out_scale))
        # sanity check for the output scale
        assert out_scale == 2**self.out_log2_scale and out_scale >= 4
        self.label_size = label_size
        self.base_channels = base_channels
        self.channel_decay = channel_decay
        self.max_channels = max_channels
        self.fused_upconv = fused_upconv
        # set conv cfg, starting from the class-level defaults
        self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
        # update with customized config
        if conv_module_cfg:
            self.conv_module_cfg.update(conv_module_cfg)
        if self.fused_upconv:
            self.fused_upconv_cfg = deepcopy(self._default_fused_upconv_cfg)
            # update with customized config
            if fused_upconv_cfg:
                self.fused_upconv_cfg.update(fused_upconv_cfg)
        self.upsample_cfg = deepcopy(self._default_upsample_cfg)
        if upsample_cfg is not None:
            self.upsample_cfg.update(upsample_cfg)
        self.noise2feat = PGGANNoiseTo2DFeat(noise_size + label_size,
                                             self._num_out_channels(1))
        self.torgb_layers = nn.ModuleList()
        self.conv_blocks = nn.ModuleList()
        for s in range(2, self.out_log2_scale + 1):
            in_ch = self._num_out_channels(
                s - 1) if s == 2 else self._num_out_channels(s - 2)
            # setup torgb layers
            self.torgb_layers.append(
                self._get_torgb_layer(self._num_out_channels(s - 1)))
            # setup upconv or conv blocks
            self.conv_blocks.extend(self._get_upconv_block(in_ch, s))
        # build upsample layer for residual path
        self.upsample_layer = build_upsample_layer(self.upsample_cfg)

    def _get_torgb_layer(self, in_channels):
        """1x1 conv projecting a feature map to a 3-channel RGB image."""
        return EqualizedLRConvModule(
            in_channels,
            3,
            kernel_size=1,
            stride=1,
            equalized_lr_cfg=dict(gain=1),
            bias=True,
            norm_cfg=None,
            act_cfg=None)

    def _num_out_channels(self, log_scale):
        """Channel count at ``log_scale``, decayed and capped."""
        return min(
            int(self.base_channels / (2.0**(log_scale * self.channel_decay))),
            self.max_channels)

    def _get_upconv_block(self, in_channels, log_scale):
        """Conv block(s) for one resolution stage of the generator."""
        modules = []
        # start 4x4 scale
        if log_scale == 2:
            modules.append(
                EqualizedLRConvModule(in_channels,
                                      self._num_out_channels(log_scale - 1),
                                      **self.conv_module_cfg))
        # 8x8 --> 1024x1024 scales
        else:
            if self.fused_upconv:
                cfg_ = dict(upsample=dict(type='fused_nn'))
                cfg_.update(self.fused_upconv_cfg)
            else:
                cfg_ = dict(upsample=self.upsample_cfg)
                cfg_.update(self.conv_module_cfg)
            # up + conv
            modules.append(
                EqualizedLRConvUpModule(in_channels,
                                        self._num_out_channels(log_scale - 1),
                                        **cfg_))
            # refine conv
            modules.append(
                EqualizedLRConvModule(self._num_out_channels(log_scale - 1),
                                      self._num_out_channels(log_scale - 1),
                                      **self.conv_module_cfg))
        return modules

    def forward(self,
                noise,
                label=None,
                num_batches=0,
                return_noise=False,
                transition_weight=1.,
                curr_scale=-1):
        """Forward function.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            label (Tensor, optional): Label vector with shape [N, C]. Defaults
                to None.
            num_batches (int, optional): The number of batch size. Defaults to
                0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            transition_weight (float, optional): The weight used in resolution
                transition. Defaults to 1.0.
            curr_scale (int, optional): The scale for the current inference or
                training. Defaults to -1.

        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output image
                will be returned. Otherwise, a dict contains ``fake_img`` and
                ``noise_batch`` will be returned.
        """
        # receive noise and conduct sanity check.
        if isinstance(noise, torch.Tensor):
            assert noise.shape[1] == self.noise_size
            assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
                                     f'but got {noise.shape}')
            noise_batch = noise
        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, self.noise_size))
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            # TODO: check pggan default noise type
            noise_batch = torch.randn((num_batches, self.noise_size))
        # dirty code for putting data on the right device
        noise_batch = noise_batch.to(get_module_device(self))
        if label is not None:
            noise_batch = torch.cat([noise_batch, label.to(noise_batch)],
                                    dim=1)
        # noise vector to 2D feature
        x = self.noise2feat(noise_batch)
        # build current computational graph
        curr_log2_scale = self.out_log2_scale if curr_scale < 0 else int(
            np.log2(curr_scale))
        # 4x4 scale
        x = self.conv_blocks[0](x)
        if curr_log2_scale <= 3:
            out_img = last_img = self.torgb_layers[0](x)
        # 8x8 and larger scales
        for s in range(3, curr_log2_scale + 1):
            x = self.conv_blocks[2 * s - 5](x)
            x = self.conv_blocks[2 * s - 4](x)
            if s + 1 == curr_log2_scale:
                # keep the previous-scale image for the residual path
                last_img = self.torgb_layers[s - 2](x)
            elif s == curr_log2_scale:
                out_img = self.torgb_layers[s - 2](x)
                residual_img = self.upsample_layer(last_img)
                # blend between the upsampled previous scale and the new one
                out_img = residual_img + transition_weight * (
                    out_img - residual_img)
        if return_noise:
            output = dict(
                fake_img=out_img, noise_batch=noise_batch, label=label)
            return output
        return out_img
@MODULES.register_module()
class PGGANDiscriminator(nn.Module):
    """Discriminator for PGGAN.

    Args:
        in_scale (int): The scale of the input image.
        label_size (int, optional): Size of the label vector. Defaults to
            0.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contains channels based on this
            number. Defaults to 8192.
        max_channels (int, optional): Maximum channels for the feature
            maps in the discriminator block. Defaults to 512.
        in_channels (int, optional): Number of channels in input images.
            Defaults to 3.
        channel_decay (float, optional): Decay for channels of feature
            maps. Defaults to 1.0.
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4).
        fused_convdown (bool, optional): Whether use fused downconv.
            Defaults to True.
        conv_module_cfg (dict, optional): Config for the convolution
            module used in this generator. Defaults to None.
        fused_convdown_cfg (dict, optional): Config for the fused downconv
            module used in this discriminator. Defaults to None.
        fromrgb_layer_cfg (dict, optional): Config for the fromrgb layer.
            Defaults to None.
        downsample_cfg (dict, optional): Config for the downsampling
            operation. Defaults to None.
    """

    # default config for the 1x1 conv that maps RGB images to features
    _default_fromrgb_cfg = dict(
        conv_cfg=None,
        kernel_size=1,
        stride=1,
        padding=0,
        bias=True,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
        norm_cfg=None,
        order=('conv', 'act', 'norm'))

    # default config for the stride-1 3x3 conv in each block
    _default_conv_module_cfg = dict(
        kernel_size=3,
        padding=1,
        stride=1,
        norm_cfg=None,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2))

    # default config for the fused conv+downsample (stride-2) module
    _default_convdown_cfg = dict(
        kernel_size=3,
        padding=1,
        stride=2,
        norm_cfg=None,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2))

    def __init__(self,
                 in_scale,
                 label_size=0,
                 base_channels=8192,
                 max_channels=512,
                 in_channels=3,
                 channel_decay=1.0,
                 mbstd_cfg=dict(group_size=4),
                 fused_convdown=True,
                 conv_module_cfg=None,
                 fused_convdown_cfg=None,
                 fromrgb_layer_cfg=None,
                 downsample_cfg=None):
        super().__init__()
        self.in_scale = in_scale
        # log2 of the input resolution, e.g. 8 for 256x256 input
        self.in_log2_scale = int(np.log2(self.in_scale))
        self.label_size = label_size
        self.base_channels = base_channels
        self.max_channels = max_channels
        self.in_channels = in_channels
        self.channel_decay = channel_decay
        self.with_mbstd = mbstd_cfg is not None
        self.fused_convdown = fused_convdown

        # user configs only override defaults, they do not replace them
        self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
        if conv_module_cfg is not None:
            self.conv_module_cfg.update(conv_module_cfg)

        if self.fused_convdown:
            self.fused_convdown_cfg = deepcopy(self._default_convdown_cfg)
            if fused_convdown_cfg is not None:
                self.fused_convdown_cfg.update(fused_convdown_cfg)

        self.fromrgb_layer_cfg = deepcopy(self._default_fromrgb_cfg)
        if fromrgb_layer_cfg:
            self.fromrgb_layer_cfg.update(fromrgb_layer_cfg)

        # setup conv blocks
        self.conv_blocks = nn.ModuleList()
        self.fromrgb_layers = nn.ModuleList()

        # one fromrgb layer per resolution (log2 scale 2 .. in_log2_scale);
        # each scale > 2 contributes two conv modules to ``conv_blocks``
        for s in range(2, self.in_log2_scale + 1):
            self.fromrgb_layers.append(
                self._get_fromrgb_layer(self.in_channels, s))

            self.conv_blocks.extend(
                self._get_convdown_block(self._num_out_channels(s - 1), s))

        # setup downsample layer
        self.downsample_cfg = deepcopy(downsample_cfg)
        if self.downsample_cfg is None or self.downsample_cfg.get(
                'type', None) == 'avgpool':
            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2)
        elif self.downsample_cfg.get('type', None) in ['nearest', 'bilinear']:
            # interpolation-based downsampling; remaining cfg keys (e.g.
            # scale_factor) are forwarded to F.interpolate
            self.downsample = partial(
                F.interpolate,
                mode=self.downsample_cfg.pop('type'),
                **self.downsample_cfg)
        else:
            raise NotImplementedError(
                'We have not supported the downsampling with type'
                f' {downsample_cfg}.')

        # setup minibatch stddev layer
        if self.with_mbstd:
            self.mbstd_layer = MiniBatchStddevLayer(**mbstd_cfg)
            # minibatch stddev layer will concatenate an additional feature map
            # in channel dimension. (* 16 flattens the 4x4 spatial grid; the
            # extra + 16 accounts for the concatenated stddev channel)
            decision_in_channels = self._num_out_channels(1) * 16 + 16
        else:
            decision_in_channels = self._num_out_channels(1) * 16

        # setup decision layer
        self.decision = PGGANDecisionHead(decision_in_channels,
                                          self._num_out_channels(0),
                                          1 + self.label_size)

    def _num_out_channels(self, log_scale):
        # channel count halves (scaled by ``channel_decay``) as resolution
        # grows, clamped at ``max_channels``
        return min(
            int(self.base_channels / (2.0**(log_scale * self.channel_decay))),
            self.max_channels)

    def _get_fromrgb_layer(self, in_channels, log2_scale):
        # 1x1 conv mapping an RGB image at this scale into feature space
        return EqualizedLRConvModule(in_channels,
                                     self._num_out_channels(log2_scale - 1),
                                     **self.fromrgb_layer_cfg)

    def _get_convdown_block(self, in_channels, log2_scale):
        # returns a list of modules: a single conv at the 4x4 scale, or a
        # conv followed by a conv+downsample pair at larger scales
        modules = []
        if log2_scale == 2:
            modules.append(
                EqualizedLRConvModule(in_channels,
                                      self._num_out_channels(log2_scale - 1),
                                      **self.conv_module_cfg))
        else:
            modules.append(
                EqualizedLRConvModule(in_channels,
                                      self._num_out_channels(log2_scale - 1),
                                      **self.conv_module_cfg))
            if self.fused_convdown:
                cfg_ = dict(downsample=dict(type='fused_pool'))
                cfg_.update(self.fused_convdown_cfg)
            else:
                cfg_ = dict(downsample=self.downsample)
                cfg_.update(self.conv_module_cfg)
            modules.append(
                EqualizedLRConvDownModule(
                    self._num_out_channels(log2_scale - 1),
                    self._num_out_channels(log2_scale - 2), **cfg_))

        return modules

    def forward(self, x, transition_weight=1., curr_scale=-1):
        """Forward function.

        Args:
            x (torch.Tensor): Input image tensor.
            transition_weight (float, optional): The weight used in resolution
                transition. Defaults to 1.0.
            curr_scale (int, optional): The scale for the current inference or
                training. Defaults to -1.

        Returns:
            Tensor: Predict score for the input image.
        """
        # scales below the minimum 4x4 (including the default -1) fall back
        # to the full configured input scale
        curr_log2_scale = self.in_log2_scale if curr_scale < 4 else int(
            np.log2(curr_scale))

        original_img = x

        x = self.fromrgb_layers[curr_log2_scale - 2](x)

        # walk the pyramid from the current resolution down to 8x8
        for s in range(curr_log2_scale, 2, -1):
            x = self.conv_blocks[2 * s - 5](x)
            x = self.conv_blocks[2 * s - 4](x)
            if s == curr_log2_scale:
                # resolution transition: blend features from the high-res
                # path with features of the downsampled image
                img_down = self.downsample(original_img)
                y = self.fromrgb_layers[curr_log2_scale - 3](img_down)
                x = y + transition_weight * (x - y)

        if self.with_mbstd:
            x = self.mbstd_layer(x)

        x = self.decision(x)

        if self.label_size > 0:
            # split into realness score and label logits
            return x[:, :1], x[:, 1:]

        return x
build/lib/mmgen/models/architectures/pggan/modules.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn.bricks
import
(
NORM_LAYERS
,
PLUGIN_LAYERS
,
ConvModule
,
build_activation_layer
,
build_norm_layer
,
build_upsample_layer
)
from
mmcv.cnn.utils
import
normal_init
from
torch.nn.init
import
_calculate_correct_fan
from
mmgen.models.builder
import
MODULES
from
mmgen.models.common
import
AllGatherLayer
class EqualizedLR:
    r"""Equalized Learning Rate.

    This trick is proposed in:
    Progressive Growing of GANs for Improved Quality, Stability, and Variation

    The general idea is to dynamically rescale the weight in training instead
    of in initializing so that the variance of the responses in each layer is
    guaranteed with some statistical properties.

    Note that this function is always combined with a convolution module which
    is initialized with :math:`\mathcal{N}(0, 1)`.

    Args:
        name (str | optional): The name of weights. Defaults to 'weight'.
        gain (float, optional): Multiplier applied on top of the
            ``sqrt(1/fan)`` rescaling. Defaults to ``2**0.5``.
        mode (str, optional): The mode of computing ``fan`` which is the
            same as ``kaiming_init`` in pytorch. You can choose one from
            ['fan_in', 'fan_out']. Defaults to 'fan_in'.
        lr_mul (float, optional): Extra multiplier folded into the rescaled
            weight. Defaults to 1.0.
    """

    def __init__(self, name='weight', gain=2**0.5, mode='fan_in', lr_mul=1.0):
        self.name = name
        self.mode = mode
        self.gain = gain
        self.lr_mul = lr_mul

    def compute_weight(self, module):
        """Compute weight with equalized learning rate.

        Args:
            module (nn.Module): A module that is wrapped with equalized lr.

        Returns:
            torch.Tensor: Updated weight.
        """
        # the raw parameter lives under ``<name>_orig`` (see ``apply``)
        weight = getattr(module, self.name + '_orig')
        if weight.ndim == 5:
            # weight in shape of [b, out, in, k, k]
            fan = _calculate_correct_fan(weight[0], self.mode)
        else:
            assert weight.ndim <= 4
            fan = _calculate_correct_fan(weight, self.mode)
        # rescale by gain * sqrt(1/fan) * lr_mul on every forward pass
        weight = weight * torch.tensor(
            self.gain, device=weight.device) * torch.sqrt(
                torch.tensor(1. / fan, device=weight.device)) * self.lr_mul
        return weight

    def __call__(self, module, inputs):
        """Standard interface for forward pre hooks."""
        # refresh the plain ``<name>`` attribute before each forward
        setattr(module, self.name, self.compute_weight(module))

    @staticmethod
    def apply(module, name, gain=2**0.5, mode='fan_in', lr_mul=1.):
        """Apply function.

        This function is to register an equalized learning rate hook in an
        ``nn.Module``.

        Args:
            module (nn.Module): Module to be wrapped.
            name (str | optional): The name of weights. Defaults to 'weight'.
            mode (str, optional): The mode of computing ``fan`` which is the
                same as ``kaiming_init`` in pytorch. You can choose one from
                ['fan_in', 'fan_out']. Defaults to 'fan_in'.

        Returns:
            nn.Module: Module that is registered with equalized lr hook.
        """
        # sanity check for duplicated hooks.
        for _, hook in module._forward_pre_hooks.items():
            if isinstance(hook, EqualizedLR):
                raise RuntimeError(
                    'Cannot register two equalized_lr hooks on the same '
                    f'parameter {name} in {module} module.')

        fn = EqualizedLR(name, gain=gain, mode=mode, lr_mul=lr_mul)

        # move the real parameter to ``<name>_orig``; the hook recomputes
        # ``<name>`` from it before every forward
        weight = module._parameters[name]
        delattr(module, name)
        module.register_parameter(name + '_orig', weight)

        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a
        # plain attribute.
        setattr(module, name, weight.data)
        module.register_forward_pre_hook(fn)

        # TODO: register load state dict hook

        return fn
def equalized_lr(module, name='weight', gain=2**0.5, mode='fan_in', lr_mul=1.):
    r"""Attach an equalized learning rate hook to ``module``.

    The trick comes from:
    Progressive Growing of GANs for Improved Quality, Stability, and Variation

    Instead of baking the scaling into the initial values, the weight named
    ``name`` is rescaled dynamically at every forward pass, keeping the
    variance of the layer responses statistically controlled. This helper is
    always paired with a module whose weights are initialized with
    :math:`\mathcal{N}(0, 1)`.

    Args:
        module (nn.Module): Module to be wrapped.
        name (str | optional): The name of weights. Defaults to 'weight'.
        mode (str, optional): The mode of computing ``fan`` which is the
            same as ``kaiming_init`` in pytorch. You can choose one from
            ['fan_in', 'fan_out']. Defaults to 'fan_in'.

    Returns:
        nn.Module: The same module, now carrying the equalized lr hook.
    """
    # ``EqualizedLR.apply`` mutates the module in place; the hook object it
    # returns is not needed by callers, so the module itself is handed back.
    hook_kwargs = dict(gain=gain, mode=mode, lr_mul=lr_mul)
    EqualizedLR.apply(module, name, **hook_kwargs)
    return module
def pixel_norm(x, eps=1e-6):
    """Pixel Normalization.

    This normalization is proposed in:
    Progressive Growing of GANs for Improved Quality, Stability, and Variation

    Args:
        x (torch.Tensor): Tensor to be normalized.
        eps (float, optional): Epsilon to avoid dividing zero.
            Defaults to 1e-6.

    Returns:
        torch.Tensor: Normalized tensor.
    """
    # FIX: the previous check ``torch.__version__ >= '1.7.0'`` compared
    # version *strings* lexicographically, which is wrong for two-digit
    # minors ('1.10.0' < '1.7.0' as strings) and silently sent modern
    # PyTorch down the legacy branch. Feature-detect ``torch.linalg.norm``
    # (added in PyTorch 1.7) instead.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'norm'):
        norm = torch.linalg.norm(x, ord=2, dim=1, keepdim=True)
    else:
        # support older pytorch version
        norm = torch.norm(x, p=2, dim=1, keepdim=True)
    # divide by sqrt(C) so the result is the RMS over channels
    norm = norm / torch.sqrt(torch.tensor(x.shape[1]).to(x))

    return x / (norm + eps)
@MODULES.register_module()
@NORM_LAYERS.register_module()
class PixelNorm(nn.Module):
    """Pixel-wise feature normalization layer.

    Proposed in:
    Progressive Growing of GANs for Improved Quality, Stability, and Variation

    Each spatial position is rescaled to unit average magnitude across the
    channel dimension (see ``pixel_norm``).

    Args:
        in_channels (int | None, optional): Unused by the computation; kept
            so the layer can be constructed through the norm-layer registry,
            which passes a channel count. Defaults to None.
        eps (float, optional): Epsilon value. Defaults to 1e-6.
    """

    # abbreviation used when auto-naming norm layers
    _abbr_ = 'pn'

    def __init__(self, in_channels=None, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its channel dimension.

        Args:
            x (torch.Tensor): Tensor to be normalized.

        Returns:
            torch.Tensor: Normalized tensor of the same shape.
        """
        return pixel_norm(x, eps=self.eps)
@PLUGIN_LAYERS.register_module()
class EqualizedLRConvModule(ConvModule):
    r"""``mmcv.cnn.ConvModule`` with equalized learning rate.

    The convolution weight is dynamically rescaled at every forward pass, as
    proposed in:
    Progressive Growing of GANs for Improved Quality, Stability, and Variation

    When the trick is enabled, the initialization of ``self.conv`` is
    overwritten with :math:`\mathcal{N}(0, 1)`.

    Args:
        equalized_lr_cfg (dict | None, optional): Config for ``EqualizedLR``.
            If ``None``, equalized learning rate is ignored. Defaults to
            dict(mode='fan_in').
    """

    def __init__(self, *args, equalized_lr_cfg=dict(mode='fan_in'), **kwargs):
        super().__init__(*args, **kwargs)
        self.with_equalized_lr = equalized_lr_cfg is not None
        if not self.with_equalized_lr:
            return
        # wrap the conv with the dynamic-rescaling hook, then reset its
        # weights to a standard Gaussian as PGGAN prescribes
        self.conv = equalized_lr(self.conv, **equalized_lr_cfg)
        self._init_conv_weights()

    def _init_conv_weights(self):
        """Reset conv weights to the PGGAN initialization."""
        normal_init(self.conv)
@PLUGIN_LAYERS.register_module()
class EqualizedLRConvUpModule(EqualizedLRConvModule):
    r"""Equalized LR (Upsample + Conv) Module.

    In this module, we inherit ``EqualizedLRConvModule`` and adopt
    upsampling before convolution. As for upsampling, in addition to the
    sampling layer in MMCV, we also offer the "fused_nn" type. "fused_nn"
    denotes fusing upsampling and convolution. The fusion is modified from
    the official Tensorflow implementation in:
    https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L86

    Args:
        upsample (dict | None, optional): Config for upsampling operation. If
            ``None``, upsampling is ignored. If you need a faster fused version as
            the official PGGAN in Tensorflow, you should set it as
            ``dict(type='fused_nn')``. Defaults to
            ``dict(type='nearest', scale_factor=2)``.
    """

    def __init__(self,
                 *args,
                 upsample=dict(type='nearest', scale_factor=2),
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.with_upsample = upsample is not None
        if self.with_upsample:
            if upsample.get('type') == 'fused_nn':
                # the fused path rewrites the transposed-conv weight via a
                # forward pre-hook, so no separate upsample layer is built
                assert isinstance(self.conv, nn.ConvTranspose2d)
                self.conv.register_forward_pre_hook(
                    EqualizedLRConvUpModule.fused_nn_hook)
            else:
                self.upsample_layer = build_upsample_layer(upsample)

    def forward(self, x, **kwargs):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # ``upsample_layer`` only exists on the non-fused path
        if hasattr(self, 'upsample_layer'):
            x = self.upsample_layer(x)
        return super().forward(x, **kwargs)

    @staticmethod
    def fused_nn_hook(module, inputs):
        """Standard interface for forward pre hooks."""
        weight = module.weight
        # pad the last two dimensions
        weight = F.pad(weight, (1, 1, 1, 1))
        # sum the four shifted copies of the padded kernel; per the
        # referenced TF implementation this folds nearest-neighbour
        # upsampling into the transposed convolution
        weight = weight[..., 1:, 1:] + weight[..., 1:, :-1] + weight[
            ..., :-1, 1:] + weight[..., :-1, :-1]
        # NOTE(review): plain attribute assignment works here only because
        # the EqualizedLR hook has already demoted ``weight`` from an
        # nn.Parameter to a plain tensor attribute — confirm equalized lr
        # is always enabled when 'fused_nn' is used.
        module.weight = weight
@PLUGIN_LAYERS.register_module()
class EqualizedLRConvDownModule(EqualizedLRConvModule):
    r"""Equalized LR (Conv + Downsample) Module.

    In this module, we inherit ``EqualizedLRConvModule`` and adopt
    downsampling after convolution. As for downsampling, we provide two modes
    of "avgpool" and "fused_pool". "avgpool" denotes the commonly used average
    pooling operation, while "fused_pool" represents fusing downsampling and
    convolution. The fusion is modified from the official Tensorflow
    implementation in:
    https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L109

    Args:
        downsample (dict | callable | None, optional): Config for the
            downsampling operation. If ``None``, downsampling is ignored. A
            callable (e.g. an ``nn.AvgPool2d`` instance or a ``partial`` of
            ``F.interpolate``) is used as-is. For dict configs we support the
            types of ["avgpool", "fused_pool"]. Defaults to
            dict(type='fused_pool').
    """

    def __init__(self, *args, downsample=dict(type='fused_pool'), **kwargs):
        super().__init__(*args, **kwargs)
        self.with_downsample = downsample is not None
        if self.with_downsample:
            # FIX: check for a callable *before* touching ``.pop('type')``.
            # Previously ``deepcopy(downsample).pop('type')`` ran first and
            # raised AttributeError for non-dict inputs (e.g. the
            # ``nn.AvgPool2d`` / ``partial`` objects that
            # ``PGGANDiscriminator._get_convdown_block`` passes in), making
            # the ``callable(downsample)`` branch unreachable.
            if callable(downsample):
                self.downsample = downsample
            else:
                downsample_cfg = deepcopy(downsample)
                type_ = downsample_cfg.pop('type')
                if type_ == 'avgpool':
                    self.downsample = nn.AvgPool2d(2, 2)
                elif type_ == 'fused_pool':
                    # fold the 2x2 average pooling into the conv weight via
                    # a forward pre-hook instead of adding a pooling layer
                    self.conv.register_forward_pre_hook(
                        EqualizedLRConvDownModule.fused_avgpool_hook)
                else:
                    raise NotImplementedError(
                        'Currently, we only support ["avgpool", "fused_pool"] as '
                        f'the type of downsample, but got {type_} instead.')

    def forward(self, x, **kwargs):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            torch.Tensor: Normalized tensor.
        """
        x = super().forward(x, **kwargs)
        # ``downsample`` only exists for the 'avgpool' / callable paths;
        # the 'fused_pool' path downsamples inside the conv itself
        if hasattr(self, 'downsample'):
            x = self.downsample(x)
        return x

    @staticmethod
    def fused_avgpool_hook(module, inputs):
        """Standard interface for forward pre hooks."""
        weight = module.weight
        # pad the last two dimensions
        weight = F.pad(weight, (1, 1, 1, 1))
        # average the four shifted copies of the padded kernel; with the
        # stride-2 conv this is equivalent to conv followed by 2x2 avgpool
        weight = (weight[..., 1:, 1:] + weight[..., 1:, :-1] +
                  weight[..., :-1, 1:] + weight[..., :-1, :-1]) * 0.25
        module.weight = weight
@PLUGIN_LAYERS.register_module()
class EqualizedLRLinearModule(nn.Linear):
    r"""``nn.Linear`` with equalized learning rate.

    The equalized learning rate is proposed in:
    Progressive Growing of GANs for Improved Quality, Stability, and Variation

    Note that, the initialization of ``self.weight`` will be overwritten as
    :math:`\mathcal{N}(0, 1)`.

    Args:
        equalized_lr_cfg (dict | None, optional): Config for ``EqualizedLR``.
            If ``None``, equalized learning rate is ignored. Defaults to
            dict(mode='fan_in').
    """

    def __init__(self, *args, equalized_lr_cfg=dict(mode='fan_in'), **kwargs):
        super().__init__(*args, **kwargs)
        self.with_equalized_lr = equalized_lr_cfg is not None
        # lr_mul only matters for the EqualizedLR hook and for the std of
        # the weight initialization below; without the hook it stays at 1.
        self.lr_mul = (equalized_lr_cfg.get('lr_mul', 1.)
                       if self.with_equalized_lr else 1.)
        if self.with_equalized_lr:
            equalized_lr(self, **equalized_lr_cfg)
            self._init_linear_weights()

    def _init_linear_weights(self):
        """Reset the weights to the PGGAN initialization."""
        nn.init.normal_(self.weight, 0, 1. / self.lr_mul)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.)
@MODULES.register_module()
class PGGANNoiseTo2DFeat(nn.Module):
    """Map a 1D noise vector to a ``4x4`` 2D feature map.

    The noise is (optionally) pixel-normalized, projected by an equalized-lr
    linear layer, reshaped to ``(n, out_channels, 4, 4)``, shifted by a
    learnable per-channel bias, and then run through the configured
    activation/normalization in the given ``order``.

    Args:
        noise_size (int): Length of the input noise vector.
        out_channels (int): Number of channels of the output feature map.
        act_cfg (dict | None, optional): Config for the activation layer.
            Defaults to dict(type='LeakyReLU', negative_slope=0.2).
        norm_cfg (dict | None, optional): Config for the norm layer.
            Defaults to dict(type='PixelNorm').
        normalize_latent (bool, optional): Whether to pixel-normalize the
            input noise. Defaults to True.
        order (tuple[str], optional): Order of the 'linear', 'act' and
            'norm' steps. Defaults to ('linear', 'act', 'norm').
    """

    def __init__(self,
                 noise_size,
                 out_channels,
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
                 norm_cfg=dict(type='PixelNorm'),
                 normalize_latent=True,
                 order=('linear', 'act', 'norm')):
        super().__init__()
        self.noise_size = noise_size
        self.out_channels = out_channels
        self.normalize_latent = normalize_latent
        self.with_activation = act_cfg is not None
        self.with_norm = norm_cfg is not None
        self.order = order
        assert len(order) == 3 and set(order) == set(['linear', 'act', 'norm'])

        # w/o bias, because the bias is added after reshaping the tensor to
        # 2D feature
        self.linear = EqualizedLRLinearModule(
            noise_size,
            out_channels * 16,
            equalized_lr_cfg=dict(gain=np.sqrt(2) / 4),
            bias=False)

        if self.with_activation:
            self.activation = build_activation_layer(act_cfg)

        # add bias for reshaped 2D feature.
        self.register_parameter(
            'bias', nn.Parameter(torch.zeros(1, out_channels, 1, 1)))

        if self.with_norm:
            _, self.norm = build_norm_layer(norm_cfg, out_channels)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input noise tensor with shape (n, c).

        Returns:
            Tensor: Forward results with shape (n, c, 4, 4).
        """
        assert x.ndim == 2
        if self.normalize_latent:
            x = pixel_norm(x)
        for order in self.order:
            if order == 'linear':
                x = self.linear(x)
                # [n, c, 4, 4]
                x = torch.reshape(x, (-1, self.out_channels, 4, 4))
                # per-channel bias added only after reshaping to 2D
                x = x + self.bias
            elif order == 'act' and self.with_activation:
                x = self.activation(x)
            elif order == 'norm' and self.with_norm:
                x = self.norm(x)

        return x
class PGGANDecisionHead(nn.Module):
    """Final decision head used by the PGGAN discriminator.

    Flattens its input and applies two equalized-lr linear layers, with an
    optional activation between them and an optional output activation.

    Args:
        in_channels (int): Number of input features after flattening.
        mid_channels (int): Number of features of the hidden linear layer.
        out_channels (int): Number of output features (scores/logits).
        bias (bool, optional): Whether the linear layers use bias.
            Defaults to True.
        equalized_lr_cfg (dict | None, optional): Config for ``EqualizedLR``
            of the *second* linear layer; also toggles the default
            equalized-lr setup of the first one. Defaults to dict(gain=1).
        act_cfg (dict | None, optional): Config for the intermediate
            activation. Defaults to dict(type='LeakyReLU',
            negative_slope=0.2).
        out_act (dict | None, optional): Config for the output activation.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 bias=True,
                 equalized_lr_cfg=dict(gain=1),
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
                 out_act=None):
        super().__init__()
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.out_channels = out_channels
        self.with_activation = act_cfg is not None
        self.with_out_activation = out_act is not None

        # setup linear layers
        # dirty code for supporting default mode in PGGAN
        # (the hidden layer always uses gain=sqrt(2) when equalized lr is
        # enabled, regardless of the gain configured for the output layer)
        if equalized_lr_cfg:
            equalized_lr_cfg_ = dict(gain=2**0.5)
        else:
            equalized_lr_cfg_ = None

        self.linear0 = EqualizedLRLinearModule(
            self.in_channels,
            self.mid_channels,
            bias=bias,
            equalized_lr_cfg=equalized_lr_cfg_)

        self.linear1 = EqualizedLRLinearModule(
            self.mid_channels,
            self.out_channels,
            bias=bias,
            equalized_lr_cfg=equalized_lr_cfg)

        # setup activation layers
        if self.with_activation:
            self.activation = build_activation_layer(act_cfg)

        if self.with_out_activation:
            self.out_activation = build_activation_layer(out_act)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # flatten any spatial dimensions before the linear layers
        if x.ndim > 2:
            x = torch.reshape(x, (x.shape[0], -1))

        x = self.linear0(x)
        if self.with_activation:
            x = self.activation(x)

        x = self.linear1(x)

        if self.with_out_activation:
            x = self.out_activation(x)

        return x
@MODULES.register_module()
@PLUGIN_LAYERS.register_module()
class MiniBatchStddevLayer(nn.Module):
    """Minibatch standard deviation.

    Concatenates one extra feature map holding the (group-wise) standard
    deviation statistic of the batch, as used in PGGAN discriminators.

    Args:
        group_size (int, optional): The size of groups in batch dimension.
            Defaults to 4.
        eps (float, optional): Epsilon value to avoid computation error.
            Defaults to 1e-8.
        gather_all_batch (bool, optional): Whether gather batch from all GPUs.
            Defaults to False.
    """

    def __init__(self, group_size=4, eps=1e-8, gather_all_batch=False):
        super().__init__()
        self.group_size = group_size
        self.eps = eps
        self.gather_all_batch = gather_all_batch
        if self.gather_all_batch:
            assert torch.distributed.is_initialized(
            ), 'Only in distributed training can the tensors be all gathered.'

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.gather_all_batch:
            # concatenate samples from every rank along the batch dimension
            x = torch.cat(AllGatherLayer.apply(x), dim=0)
        # batch size should be smaller than or equal to group size. Otherwise,
        # batch size should be divisible by the group size.
        assert x.shape[0] <= self.group_size or x.shape[
            0] % self.group_size == 0, (
                'Batch size be smaller than or equal '
                'to group size. Otherwise,'
                ' batch size should be divisible by the group size.'
                f'But got batch size {x.shape[0]},'
                f' group size {self.group_size}')

        n, c, h, w = x.shape
        group_size = min(n, self.group_size)
        # [G, M, C, H, W] -- split the batch into groups of size G
        y = torch.reshape(x, (group_size, -1, c, h, w))
        # [G, M, C, H, W] -- deviation from the per-group mean
        y = y - y.mean(dim=0, keepdim=True)
        # In pt>=1.7, you can just use `.square()` function.
        # [M, C, H, W] -- variance over the group dimension
        y = y.pow(2).mean(dim=0, keepdim=False)
        y = torch.sqrt(y + self.eps)
        # [M, 1, 1, 1] -- average std over channels and spatial dims
        y = y.mean(dim=(1, 2, 3), keepdim=True)
        # broadcast the scalar statistic back to a full feature map
        y = y.repeat(group_size, 1, h, w)
        return torch.cat([x, y], dim=1)
build/lib/mmgen/models/architectures/pix2pix/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.generator_discriminator
import
PatchDiscriminator
,
UnetGenerator
from
.modules
import
UnetSkipConnectionBlock
,
generation_init_weights
__all__
=
[
'PatchDiscriminator'
,
'UnetGenerator'
,
'UnetSkipConnectionBlock'
,
'generation_init_weights'
]
build/lib/mmgen/models/architectures/pix2pix/generator_discriminator.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
,
build_conv_layer
from
mmcv.runner
import
load_checkpoint
from
mmgen.models.builder
import
MODULES
from
mmgen.utils
import
get_root_logger
from
.modules
import
UnetSkipConnectionBlock
,
generation_init_weights
@MODULES.register_module()
class UnetGenerator(nn.Module):
    """Construct the Unet-based generator from the innermost layer to the
    outermost layer, which is a recursive process.

    Args:
        in_channels (int): Number of channels in input images.
        out_channels (int): Number of channels in output images.
        num_down (int): Number of downsamplings in Unet. If `num_down` is 8,
            the image with size 256x256 will become 1x1 at the bottleneck.
            Default: 8.
        base_channels (int): Number of channels at the last conv layer.
            Default: 64.
        norm_cfg (dict): Config dict to build norm layer. Default:
            `dict(type='BN')`.
        use_dropout (bool): Whether to use dropout layers. Default: False.
        init_cfg (dict): Config dict for initialization.
            `type`: The name of our initialization method. Default: 'normal'.
            `gain`: Scaling factor for normal, xavier and orthogonal.
            Default: 0.02.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_down=8,
                 base_channels=64,
                 norm_cfg=dict(type='BN'),
                 use_dropout=False,
                 init_cfg=dict(type='normal', gain=0.02)):
        super().__init__()
        # We use norm layers in the unet generator.
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but"
                                            f'got {type(norm_cfg)}')
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"

        # add the innermost layer (bottleneck, no submodule below it)
        unet_block = UnetSkipConnectionBlock(
            base_channels * 8,
            base_channels * 8,
            in_channels=None,
            submodule=None,
            norm_cfg=norm_cfg,
            is_innermost=True)
        # add intermediate layers with base_channels * 8 filters
        for _ in range(num_down - 5):
            unet_block = UnetSkipConnectionBlock(
                base_channels * 8,
                base_channels * 8,
                in_channels=None,
                submodule=unet_block,
                norm_cfg=norm_cfg,
                use_dropout=use_dropout)
        # gradually reduce the number of filters
        # from base_channels * 8 to base_channels
        unet_block = UnetSkipConnectionBlock(
            base_channels * 4,
            base_channels * 8,
            in_channels=None,
            submodule=unet_block,
            norm_cfg=norm_cfg)
        unet_block = UnetSkipConnectionBlock(
            base_channels * 2,
            base_channels * 4,
            in_channels=None,
            submodule=unet_block,
            norm_cfg=norm_cfg)
        unet_block = UnetSkipConnectionBlock(
            base_channels,
            base_channels * 2,
            in_channels=None,
            submodule=unet_block,
            norm_cfg=norm_cfg)
        # add the outermost layer, which maps back to image channels
        self.model = UnetSkipConnectionBlock(
            out_channels,
            base_channels,
            in_channels=in_channels,
            submodule=unet_block,
            is_outermost=True,
            norm_cfg=norm_cfg)

        # remember init settings for a later ``init_weights`` call
        self.init_type = 'normal' if init_cfg is None else init_cfg.get(
            'type', 'normal')
        self.init_gain = 0.02 if init_cfg is None else init_cfg.get(
            'gain', 0.02)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        return self.model(x)

    def init_weights(self, pretrained=None, strict=True):
        """Initialize weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Default: None.
            strict (bool, optional): Whether to allow different params for the
                model and checkpoint. Default: True.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            generation_init_weights(
                self, init_type=self.init_type, init_gain=self.init_gain)
        else:
            raise TypeError("'pretrained' must be a str or None. "
                            f'But received {type(pretrained)}.')
@MODULES.register_module()
class PatchDiscriminator(nn.Module):
    """A PatchGAN discriminator.

    Args:
        in_channels (int): Number of channels in input images.
        base_channels (int): Number of channels at the first conv layer.
            Default: 64.
        num_conv (int): Number of stacked intermediate convs (excluding input
            and output conv). Default: 3.
        norm_cfg (dict): Config dict to build norm layer. Default:
            `dict(type='BN')`.
        init_cfg (dict): Config dict for initialization.
            `type`: The name of our initialization method. Default: 'normal'.
            `gain`: Scaling factor for normal, xavier and orthogonal.
            Default: 0.02.
    """

    def __init__(self,
                 in_channels,
                 base_channels=64,
                 num_conv=3,
                 norm_cfg=dict(type='BN'),
                 init_cfg=dict(type='normal', gain=0.02)):
        super().__init__()
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but"
                                            f'got {type(norm_cfg)}')
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"

        # We use norm layers in the patch discriminator.
        # Only for IN, use bias since it does not have affine parameters.
        use_bias = norm_cfg['type'] == 'IN'

        kernel_size = 4
        padding = 1

        def _lrelu_cfg():
            # fresh dict per layer so ConvModule instances never share cfg
            return dict(type='LeakyReLU', negative_slope=0.2)

        # input layer: no norm, always biased
        layers = [
            ConvModule(
                in_channels=in_channels,
                out_channels=base_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=padding,
                bias=True,
                norm_cfg=None,
                act_cfg=_lrelu_cfg())
        ]

        # stacked intermediate layers: the channel multiplier doubles each
        # step (capped at 8) while the spatial size halves
        multiple_now = 1
        for n in range(1, num_conv):
            multiple_prev, multiple_now = multiple_now, min(2**n, 8)
            layers.append(
                ConvModule(
                    in_channels=base_channels * multiple_prev,
                    out_channels=base_channels * multiple_now,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=padding,
                    bias=use_bias,
                    norm_cfg=norm_cfg,
                    act_cfg=_lrelu_cfg()))

        # one more stride-1 conv before the prediction layer
        multiple_prev, multiple_now = multiple_now, min(2**num_conv, 8)
        layers.append(
            ConvModule(
                in_channels=base_channels * multiple_prev,
                out_channels=base_channels * multiple_now,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
                bias=use_bias,
                norm_cfg=norm_cfg,
                act_cfg=_lrelu_cfg()))

        # output one-channel prediction map
        layers.append(
            build_conv_layer(
                dict(type='Conv2d'),
                base_channels * multiple_now,
                1,
                kernel_size=kernel_size,
                stride=1,
                padding=padding))

        self.model = nn.Sequential(*layers)

        # remember init settings for a later ``init_weights`` call
        self.init_type = ('normal' if init_cfg is None else init_cfg.get(
            'type', 'normal'))
        self.init_gain = (0.02 if init_cfg is None else init_cfg.get(
            'gain', 0.02))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        return self.model(x)

    def init_weights(self, pretrained=None):
        """Initialize weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Default: None.
        """
        if pretrained is None:
            generation_init_weights(
                self, init_type=self.init_type, init_gain=self.init_gain)
        elif isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        else:
            raise TypeError("'pretrained' must be a str or None. "
                            f'But received {type(pretrained)}.')
build/lib/mmgen/models/architectures/pix2pix/modules.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
,
kaiming_init
,
normal_init
,
xavier_init
from
torch.nn
import
init
def generation_init_weights(module, init_type='normal', init_gain=0.02):
    """Default initialization of network weights for image generation.

    By default, we use normal init, but xavier and kaiming might work
    better for some applications.

    Args:
        module (nn.Module): Module to be initialized.
        init_type (str): The name of an initialization method:
            normal | xavier | kaiming | orthogonal.
        init_gain (float): Scaling factor for normal, xavier and
            orthogonal.
    """

    def init_func(m):
        """Initialization function.

        Args:
            m (nn.Module): Module to be initialized.
        """
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1
                                     or classname.find('Linear') != -1):
            if init_type == 'normal':
                normal_init(m, 0.0, init_gain)
            elif init_type == 'xavier':
                xavier_init(m, gain=init_gain, distribution='normal')
            elif init_type == 'kaiming':
                kaiming_init(
                    m,
                    a=0,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    distribution='normal')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight, gain=init_gain)
                # FIX: Conv/Linear layers may be built without a bias term
                # (bias=False), in which case ``m.bias`` is None and the
                # unconditional ``m.bias.data`` access crashed. The mmcv
                # init helpers used by the other branches guard this
                # internally; do the same here.
                if getattr(m, 'bias', None) is not None:
                    init.constant_(m.bias.data, 0.0)
            else:
                raise NotImplementedError(
                    f"Initialization method '{init_type}' is not implemented")
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm Layer's weight is not a matrix;
            # only normal distribution applies.
            normal_init(m, 1.0, init_gain)

    module.apply(init_func)
class UnetSkipConnectionBlock(nn.Module):
    """Construct a Unet submodule with skip connections, with the following.
    structure: downsampling - `submodule` - upsampling.

    Args:
        outer_channels (int): Number of channels at the outer conv layer.
        inner_channels (int): Number of channels at the inner conv layer.
        in_channels (int): Number of channels in input images/features. If is
            None, equals to `outer_channels`. Default: None.
        submodule (UnetSkipConnectionBlock): Previously constructed submodule.
            Default: None.
        is_outermost (bool): Whether this module is the outermost module.
            Default: False.
        is_innermost (bool): Whether this module is the innermost module.
            Default: False.
        norm_cfg (dict): Config dict to build norm layer. Default:
            `dict(type='BN')`.
        use_dropout (bool): Whether to use dropout layers. Default: False.
    """

    def __init__(self,
                 outer_channels,
                 inner_channels,
                 in_channels=None,
                 submodule=None,
                 is_outermost=False,
                 is_innermost=False,
                 norm_cfg=dict(type='BN'),
                 use_dropout=False):
        super().__init__()
        # cannot be both outermost and innermost
        # Fix: the two adjacent string literals previously concatenated to
        # "...cannot be Trueat the same time." — a space was missing.
        assert not (is_outermost and is_innermost), (
            "'is_outermost' and 'is_innermost' cannot be True "
            'at the same time.')
        self.is_outermost = is_outermost
        # Fix: same missing-space defect ("butgot" -> "but got").
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but"
                                            f' got {type(norm_cfg)}')
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
        # We use norm layers in the unet skip connection block.
        # Only for IN, use bias since it does not have affine parameters.
        use_bias = norm_cfg['type'] == 'IN'

        kernel_size = 4
        stride = 2
        padding = 1
        if in_channels is None:
            in_channels = outer_channels
        down_conv_cfg = dict(type='Conv2d')
        down_norm_cfg = norm_cfg
        down_act_cfg = dict(type='LeakyReLU', negative_slope=0.2)
        up_conv_cfg = dict(type='deconv')
        up_norm_cfg = norm_cfg
        up_act_cfg = dict(type='ReLU')
        up_in_channels = inner_channels * 2
        up_bias = use_bias
        middle = [submodule]
        upper = []

        if is_outermost:
            # Outermost block: no activation/norm going down, Tanh output.
            down_act_cfg = None
            down_norm_cfg = None
            up_bias = True
            up_norm_cfg = None
            upper = [nn.Tanh()]
        elif is_innermost:
            # Innermost block: no submodule in the middle, no down-norm,
            # and the up-conv only sees `inner_channels` (no skip concat).
            down_norm_cfg = None
            up_in_channels = inner_channels
            middle = []
        else:
            # Intermediate block: optional dropout after the up path.
            upper = [nn.Dropout(0.5)] if use_dropout else []

        down = [
            ConvModule(
                in_channels=in_channels,
                out_channels=inner_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=use_bias,
                conv_cfg=down_conv_cfg,
                norm_cfg=down_norm_cfg,
                act_cfg=down_act_cfg,
                order=('act', 'conv', 'norm'))
        ]
        up = [
            ConvModule(
                in_channels=up_in_channels,
                out_channels=outer_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=up_bias,
                conv_cfg=up_conv_cfg,
                norm_cfg=up_norm_cfg,
                act_cfg=up_act_cfg,
                order=('act', 'conv', 'norm'))
        ]

        model = down + middle + up + upper

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results. The outermost block returns the model
            output directly; inner blocks concatenate the input along the
            channel dimension (skip connection).
        """
        if self.is_outermost:
            return self.model(x)

        # add skip connections
        return torch.cat([x, self.model(x)], 1)
build/lib/mmgen/models/architectures/positional_encoding.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmgen.models.builder
import
MODULES
@MODULES.register_module('SPE')
@MODULES.register_module('SPE2d')
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal Positional Embedding 1D or 2D (SPE/SPE2d).

    This module is a modified from:
    https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py # noqa

    Based on the original SPE in single dimension, we implement a 2D sinusoidal
    positional encoding (SPE2d), as introduced in Positional Encoding as
    Spatial Inductive Bias in GANs, CVPR'2021.

    Args:
        embedding_dim (int): The number of dimensions for the positional
            encoding.
        padding_idx (int | list[int]): The index for the padding contents. The
            padding positions will obtain an encoding vector filling in zeros.
        init_size (int, optional): The initial size of the positional buffer.
            Defaults to 1024.
        div_half_dim (bool, optional): If true, the embedding will be divided
            by :math:`d/2`. Otherwise, it will be divided by
            :math:`(d/2 -1)`. Defaults to False.
        center_shift (int | None, optional): Shift the center point to some
            index. Defaults to None.
    """

    def __init__(self,
                 embedding_dim,
                 padding_idx,
                 init_size=1024,
                 div_half_dim=False,
                 center_shift=None):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.div_half_dim = div_half_dim
        self.center_shift = center_shift

        # Pre-compute a table of `init_size` embeddings; `forward` rebuilds
        # it on demand when a longer sequence is encountered.
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx, self.div_half_dim)

        # Dummy buffer used solely to track the module's device/dtype so the
        # (non-buffer) weight table can be moved to the right device lazily.
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

        self.max_positions = int(1e5)

    @staticmethod
    def get_embedding(num_embeddings,
                      embedding_dim,
                      padding_idx=None,
                      div_half_dim=False):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Args:
            num_embeddings (int): Number of positions to encode.
            embedding_dim (int): Dimension of each positional vector. Must be
                divisible by 2 (half sin, half cos channels).
            padding_idx (int | None, optional): Row to zero out for padding.
                Defaults to None.
            div_half_dim (bool, optional): Divide the log-frequency by
                :math:`d/2` instead of :math:`d/2 - 1`. Defaults to False.

        Returns:
            Tensor: Embedding table of shape (num_embeddings, embedding_dim).
        """
        assert embedding_dim % 2 == 0, (
            'In this version, we request '
            f'embedding_dim divisible by 2 but got {embedding_dim}')

        # there is a little difference from the original paper.
        half_dim = embedding_dim // 2
        if not div_half_dim:
            emb = np.log(10000) / (half_dim - 1)
        else:
            emb = np.log(1e4) / half_dim
        # compute exp(-log10000 / d * i)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        # outer product: position index times per-channel frequency
        emb = torch.arange(
            num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)],
                        dim=1).view(num_embeddings, -1)
        if padding_idx is not None:
            # padding positions get an all-zero encoding
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, **kwargs):
        """Input is expected to be of size [bsz x seqlen].

        Returned tensor is expected to be of size [bsz x seq_len x emb_dim]
        """
        assert input.dim() == 2 or input.dim(
        ) == 4, 'Input dimension should be 2 (1D) or 4(2D)'

        # a 4D input is treated as an image-like tensor: delegate to the
        # 2D grid construction
        if input.dim() == 4:
            return self.make_grid2d_like(input, **kwargs)

        b, seq_len = input.shape
        max_pos = self.padding_idx + 1 + seq_len

        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embedding if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx)
        # move the table to the module's current device/dtype
        self.weights = self.weights.to(self._float_tensor)

        positions = self.make_positions(input, self.padding_idx).to(
            self._float_tensor.device)

        # gather per-position rows and detach: the encoding is fixed,
        # not a learnable parameter
        return self.weights.index_select(0, positions.view(-1)).view(
            b, seq_len, self.embedding_dim).detach()

    def make_positions(self, input, padding_idx):
        """Map each non-padding entry to its running position index.

        Non-padding entries are numbered from ``padding_idx + 1`` upwards
        (via a masked cumulative sum); padding entries stay at
        ``padding_idx``.
        """
        mask = input.ne(padding_idx).int()
        return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() +
padding_idx

    def make_grid2d(self, height, width, num_batches=1, center_shift=None):
        """Build a 2D positional-encoding grid of shape
        (num_batches, 2 x embedding_dim, height, width).

        The x/y axes are encoded independently with the 1D SPE and then
        broadcast over the other spatial dimension before concatenation.
        """
        h, w = height, width
        # if `center_shift` is not given from the outside, use
        # `self.center_shift`
        if center_shift is None:
            center_shift = self.center_shift

        h_shift = 0
        w_shift = 0
        # center shift to the input grid
        if center_shift is not None:
            # if h/w is even, the left center should be aligned with
            # center shift
            if h % 2 == 0:
                h_left_center = h // 2
                h_shift = center_shift - h_left_center
            else:
                h_center = h // 2 + 1
                h_shift = center_shift - h_center

            if w % 2 == 0:
                w_left_center = w // 2
                w_shift = center_shift - w_left_center
            else:
                w_center = w // 2 + 1
                w_shift = center_shift - w_center

        # Note that the index is started from 1 since zero will be padding idx.
        # axis -- (b, h or w)
        x_axis = torch.arange(1, w + 1).unsqueeze(0).repeat(num_batches,
                                                            1) + w_shift
        y_axis = torch.arange(1, h + 1).unsqueeze(0).repeat(num_batches,
                                                            1) + h_shift

        # emb -- (b, emb_dim, h or w)
        x_emb = self(x_axis).transpose(1, 2)
        y_emb = self(y_axis).transpose(1, 2)

        # make grid for x/y axis
        # Note that repeat will copy data. If use learned emb, expand may be
        # better.
        x_grid = x_emb.unsqueeze(2).repeat(1, 1, h, 1)
        y_grid = y_emb.unsqueeze(3).repeat(1, 1, 1, w)

        # cat grid -- (b, 2 x emb_dim, h, w)
        grid = torch.cat([x_grid, y_grid], dim=1)
        return grid.detach()

    def make_grid2d_like(self, x, center_shift=None):
        """Input tensor with shape of (b, ..., h, w) Return tensor with shape
        of (b, 2 x emb_dim, h, w)

        Note that the positional embedding highly depends on the function,
        ``make_positions``.
        """
        h, w = x.shape[-2:]
        grid = self.make_grid2d(h, w, x.size(0), center_shift)

        # match the device/dtype of the reference tensor
        return grid.to(x)
@MODULES.register_module('CSG2d')
@MODULES.register_module('CSG')
@MODULES.register_module()
class CatersianGrid(nn.Module):
    """Catersian Grid for 2d tensor.

    The Catersian Grid is a common-used positional encoding in deep learning.
    In this implementation, we follow the convention of ``grid_sample`` in
    PyTorch. In other words, ``[-1, -1]`` denotes the left-top corner while
    ``[1, 1]`` denotes the right-botton corner.
    """

    def forward(self, x, **kwargs):
        assert x.dim() == 4
        return self.make_grid2d_like(x, **kwargs)

    def make_grid2d(self, height, width, num_batches=1, requires_grad=False):
        """Construct a normalized coordinate grid.

        Returns a tensor of shape (num_batches, 2, height, width) whose two
        channels hold the x and y coordinates, each scaled to [-1, 1].
        """
        rows, cols = height, width
        ys, xs = torch.meshgrid(torch.arange(0, rows), torch.arange(0, cols))
        # rescale integer indices into the [-1, 1] range used by
        # ``grid_sample``; max(..., 1.) guards the degenerate 1-pixel case
        xs = 2 * xs / max(float(cols) - 1., 1.) - 1.
        ys = 2 * ys / max(float(rows) - 1., 1.) - 1.
        coords = torch.stack((xs, ys), 0)
        coords.requires_grad = requires_grad
        coords = torch.unsqueeze(coords, 0)
        coords = coords.repeat(num_batches, 1, 1, 1)
        return coords

    def make_grid2d_like(self, x, requires_grad=False):
        """Build a grid matching the spatial size, batch, device and dtype
        of the reference tensor ``x``."""
        spatial = x.shape[-2:]
        grid = self.make_grid2d(
            spatial[0], spatial[1], x.size(0), requires_grad=requires_grad)
        return grid.to(x)
build/lib/mmgen/models/architectures/singan/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.generator_discriminator
import
(
SinGANMultiScaleDiscriminator
,
SinGANMultiScaleGenerator
)
from
.positional_encoding
import
SinGANMSGeneratorPE
__all__
=
[
'SinGANMultiScaleDiscriminator'
,
'SinGANMultiScaleGenerator'
,
'SinGANMSGeneratorPE'
]
build/lib/mmgen/models/architectures/singan/generator_discriminator.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
functools
import
partial
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.runner
import
load_state_dict
from
mmcv.utils
import
print_log
from
mmgen.models.builder
import
MODULES
from
mmgen.utils
import
get_root_logger
from
.modules
import
DiscriminatorBlock
,
GeneratorBlock
@MODULES.register_module()
class SinGANMultiScaleGenerator(nn.Module):
    """Multi-Scale Generator used in SinGAN.

    More details can be found in: Singan: Learning a Generative Model from a
    Single Natural Image, ICCV'19.

    Notes:

    - In this version, we adopt the interpolation function from the official
      PyTorch APIs, which is different from the original implementation by the
      authors. However, in our experiments, this influence can be ignored.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        num_scales (int): The number of scales/stages in generator. Note
            that this number is counted from zero, which is the same as the
            original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            generator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the generator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the generator block. Defaults to 32.
        out_act_cfg (dict | None, optional): Configs for output activation
            layer. Defaults to dict(type='Tanh').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_scales,
                 kernel_size=3,
                 padding=0,
                 num_layers=5,
                 base_channels=32,
                 min_feat_channels=32,
                 out_act_cfg=dict(type='Tanh'),
                 **kwargs):
        super().__init__()
        # total spatial padding added in front of each stage: each of the
        # `num_layers` convs eats (kernel_size - 1) / 2 pixels per side
        self.pad_head = int((kernel_size - 1) / 2 * num_layers)
        self.blocks = nn.ModuleList()

        self.upsample = partial(
            F.interpolate, mode='bicubic', align_corners=True)

        for scale in range(num_scales + 1):
            # channel width doubles every 4 scales, capped at 128
            base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
                          128)
            min_feat_ch = min(
                min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)

            self.blocks.append(
                GeneratorBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    num_layers=num_layers,
                    base_channels=base_ch,
                    min_feat_channels=min_feat_ch,
                    out_act_cfg=out_act_cfg,
                    **kwargs))

        # NOTE(review): these padding layers are constructed but not used in
        # `forward` (F.pad is called directly there) — confirm intent.
        self.noise_padding_layer = nn.ZeroPad2d(self.pad_head)
        self.img_padding_layer = nn.ZeroPad2d(self.pad_head)

    def forward(self,
                input_sample,
                fixed_noises,
                noise_weights,
                rand_mode,
                curr_scale,
                num_batches=1,
                get_prev_res=False,
                return_noise=False):
        """Forward function.

        Args:
            input_sample (Tensor | None): The input for generator. In the
                original implementation, a tensor filled with zeros is adopted.
                If None is given, we will construct it from the first fixed
                noises.
            fixed_noises (list[Tensor]): List of the fixed noises in SinGAN.
            noise_weights (list[float]): List of the weights for random noises.
            rand_mode (str): Choices from ['rand', 'recon']. In ``rand`` mode,
                it will sample from random noises. Otherwise, the
                reconstruction for the single image will be returned.
            curr_scale (int): The scale for the current inference or training.
            num_batches (int, optional): The number of batches. Defaults to 1.
            get_prev_res (bool, optional): Whether to return results from
                previous stages. Defaults to False.
            return_noise (bool, optional): Whether to return noises tensor.
                Defaults to False.

        Returns:
            Tensor | dict: Generated image tensor or dictionary containing
                more data.
        """
        if get_prev_res or return_noise:
            prev_res_list = []
            noise_list = []

        if input_sample is None:
            # default input: zeros shaped like the coarsest fixed noise
            input_sample = torch.zeros(
                (num_batches, 3, fixed_noises[0].shape[-2],
                 fixed_noises[0].shape[-1])).to(fixed_noises[0])

        g_res = input_sample

        for stage in range(curr_scale + 1):
            if rand_mode == 'recon':
                # reconstruction mode: replay the stored per-stage noise
                noise_ = fixed_noises[stage]
            else:
                # sampling mode: fresh Gaussian noise with the stored shape
                noise_ = torch.randn(num_batches,
                                     *fixed_noises[stage].shape[1:]).to(g_res)
            if return_noise:
                noise_list.append(noise_)

            # add padding at head
            pad_ = (self.pad_head, ) * 4
            noise_ = F.pad(noise_, pad_)
            g_res_pad = F.pad(g_res, pad_)
            # mix weighted noise with the (padded) previous-stage result;
            # `detach` stops gradients flowing into earlier stages
            noise = noise_ * noise_weights[stage] + g_res_pad
            g_res = self.blocks[stage](noise.detach(), g_res)

            if get_prev_res and stage != curr_scale:
                prev_res_list.append(g_res)

            # upsample, here we use interpolation from PyTorch
            if stage != curr_scale:
                h_next, w_next = fixed_noises[stage + 1].shape[-2:]
                g_res = self.upsample(g_res, (h_next, w_next))

        if get_prev_res or return_noise:
            output_dict = dict(
                fake_img=g_res,
                prev_res_list=prev_res_list,
                noise_batch=noise_list)
            return output_dict

        return g_res

    def check_and_load_prev_weight(self, curr_scale):
        """Initialize the block at ``curr_scale`` from the previous scale's
        weights when both blocks share identical channel configuration."""
        if curr_scale == 0:
            return
        prev_ch = self.blocks[curr_scale - 1].base_channels
        curr_ch = self.blocks[curr_scale].base_channels

        prev_in_ch = self.blocks[curr_scale - 1].in_channels
        curr_in_ch = self.blocks[curr_scale].in_channels
        if prev_ch == curr_ch and prev_in_ch == curr_in_ch:
            load_state_dict(
                self.blocks[curr_scale],
                self.blocks[curr_scale - 1].state_dict(),
                logger=get_root_logger())
            print_log('Successfully load pretrianed model from last scale.')
        else:
            print_log(
                'Cannot load pretrained model from last scale since'
                f' prev_ch({prev_ch}) != curr_ch({curr_ch})'
                f' or prev_in_ch({prev_in_ch}) != curr_in_ch({curr_in_ch})')
@MODULES.register_module()
class SinGANMultiScaleDiscriminator(nn.Module):
    """Multi-Scale Discriminator used in SinGAN.

    More details can be found in: Singan: Learning a Generative Model from a
    Single Natural Image, ICCV'19.

    Args:
        in_channels (int): Input channels.
        num_scales (int): The number of scales/stages in generator. Note
            that this number is counted from zero, which is the same as the
            original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            generator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the generator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the generator block. Defaults to 32.
    """

    def __init__(self,
                 in_channels,
                 num_scales,
                 kernel_size=3,
                 padding=0,
                 num_layers=5,
                 base_channels=32,
                 min_feat_channels=32,
                 **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList()
        for scale in range(num_scales + 1):
            # channel width doubles every 4 scales, capped at 128 — mirrors
            # the schedule used in SinGANMultiScaleGenerator
            base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
                          128)
            min_feat_ch = min(
                min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)

            self.blocks.append(
                DiscriminatorBlock(
                    in_channels=in_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    num_layers=num_layers,
                    base_channels=base_ch,
                    min_feat_channels=min_feat_ch,
                    **kwargs))

    def forward(self, x, curr_scale):
        """Forward function.

        Args:
            x (Tensor): Input feature map.
            curr_scale (int): Current scale for discriminator. If in testing,
                you need to set it to the last scale.

        Returns:
            Tensor: Discriminative results.
        """
        # only the block belonging to the requested scale is evaluated
        out = self.blocks[curr_scale](x)
        return out

    def check_and_load_prev_weight(self, curr_scale):
        """Initialize the block at ``curr_scale`` from the previous scale's
        weights when both blocks share the same base channel width."""
        if curr_scale == 0:
            return
        prev_ch = self.blocks[curr_scale - 1].base_channels
        curr_ch = self.blocks[curr_scale].base_channels
        if prev_ch == curr_ch:
            self.blocks[curr_scale].load_state_dict(
                self.blocks[curr_scale - 1].state_dict())
            print_log('Successfully load pretrianed model from last scale.')
        else:
            print_log('Cannot load pretrained model from last scale since'
                      f' prev_ch({prev_ch}) != curr_ch({curr_ch})')
build/lib/mmgen/models/architectures/singan/modules.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
,
constant_init
,
normal_init
from
mmcv.runner
import
load_checkpoint
from
mmcv.utils.parrots_wrapper
import
_BatchNorm
from
mmgen.utils
import
get_root_logger
class GeneratorBlock(nn.Module):
    """Generator block used in SinGAN.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        num_scales (int): The number of scales/stages in generator. Note
            that this number is counted from zero, which is the same as the
            original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            generator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the generator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the generator block. Defaults to 32.
        out_act_cfg (dict | None, optional): Configs for output activation
            layer. Defaults to dict(type='Tanh').
        stride (int, optional): Same as :obj:`nn.Conv2d`. Defaults to 1.
        allow_no_residual (bool, optional): Whether to allow no residual link
            in this block. Defaults to False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 padding,
                 num_layers,
                 base_channels,
                 min_feat_channels,
                 out_act_cfg=dict(type='Tanh'),
                 stride=1,
                 allow_no_residual=False,
                 **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.allow_no_residual = allow_no_residual

        self.head = ConvModule(
            in_channels=in_channels,
            out_channels=base_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=1,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
            **kwargs)

        self.body = nn.Sequential()

        # body convs halve the channel count each layer, but never drop
        # below `min_feat_channels`
        for i in range(num_layers - 2):
            feat_channels_ = int(base_channels / pow(2, (i + 1)))
            block = ConvModule(
                max(2 * feat_channels_, min_feat_channels),
                max(feat_channels_, min_feat_channels),
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                norm_cfg=dict(type='BN'),
                act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
                **kwargs)
            self.body.add_module(f'block{i + 1}', block)

        # NOTE(review): `feat_channels_` is only bound inside the loop above,
        # so this block requires `num_layers >= 3`; smaller values raise
        # NameError here — confirm whether that restriction is intended.
        self.tail = ConvModule(
            max(feat_channels_, min_feat_channels),
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=1,
            norm_cfg=None,
            act_cfg=out_act_cfg,
            **kwargs)

        self.init_weights()

    def forward(self, x, prev):
        """Forward function.

        Args:
            x (Tensor): Input feature map.
            prev (Tensor): Previous feature map.

        Returns:
            Tensor: Output feature map with the shape of (N, C, H, W).
        """
        x = self.head(x)
        x = self.body(x)
        x = self.tail(x)

        # if prev and x are not in the same shape at the channel dimension
        if self.allow_no_residual and x.shape[1] != prev.shape[1]:
            return x

        return x + prev

    def init_weights(self, pretrained=None):
        """Initialize weights for the block.

        Args:
            pretrained (str, optional): Path for pretrained weights. If None,
                conv layers get N(0, 0.02) init and norm layers constant 1.

        Raises:
            TypeError: If ``pretrained`` is neither a string nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, 0, 0.02)
                elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None but'
                            f' got {type(pretrained)} instead.')
class DiscriminatorBlock(nn.Module):
    """Discriminator Block used in SinGAN.

    Args:
        in_channels (int): Input channels.
        base_channels (int): Base channels for this block.
        min_feat_channels (int): The minimum channels for feature map.
        kernel_size (int): Size of convolutional kernel, same as
            :obj:`nn.Conv2d`.
        padding (int): Padding for convolutional layer, same as
            :obj:`nn.Conv2d`.
        num_layers (int): The number of convolutional layers in this block.
        norm_cfg (dict | None, optional): Config for the normalization layer.
            Defaults to dict(type='BN').
        act_cfg (dict | None, optional): Config for the activation layer.
            Defaults to dict(type='LeakyReLU', negative_slope=0.2).
        stride (int, optional): The stride for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 1.
    """

    def __init__(self,
                 in_channels,
                 base_channels,
                 min_feat_channels,
                 kernel_size,
                 padding,
                 num_layers,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
                 stride=1,
                 **kwargs):
        super().__init__()
        self.base_channels = base_channels
        self.stride = stride

        self.head = ConvModule(
            in_channels,
            base_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.body = nn.Sequential()

        # body convs halve the channel count each layer, but never drop
        # below `min_feat_channels` (same schedule as GeneratorBlock)
        for i in range(num_layers - 2):
            feat_channels_ = int(base_channels / pow(2, (i + 1)))
            block = ConvModule(
                max(2 * feat_channels_, min_feat_channels),
                max(feat_channels_, min_feat_channels),
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                conv_cfg=None,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                **kwargs)
            self.body.add_module(f'block{i + 1}', block)

        # NOTE(review): `feat_channels_` is only bound inside the loop above,
        # so this block requires `num_layers >= 3` — confirm intended.
        # The tail maps to a single-channel (patch) prediction map with no
        # norm/activation.
        self.tail = ConvModule(
            max(feat_channels_, min_feat_channels),
            1,
            kernel_size=kernel_size,
            padding=padding,
            stride=1,
            norm_cfg=None,
            act_cfg=None,
            **kwargs)

        self.init_weights()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        x = self.head(x)
        x = self.body(x)
        x = self.tail(x)
        return x

    # TODO: study the effects of init functions
    def init_weights(self, pretrained=None):
        """Initialize weights for the block.

        Args:
            pretrained (str, optional): Path for pretrained weights. If None,
                conv layers get N(0, 0.02) init and norm layers constant 1.

        Raises:
            TypeError: If ``pretrained`` is neither a string nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, 0, 0.02)
                elif isinstance(m, (_BatchNorm, nn.InstanceNorm2d)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None but'
                            f' got {type(pretrained)} instead.')
build/lib/mmgen/models/architectures/singan/positional_encoding.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
"""Implementation for Positional Encoding as Spatial Inductive Bias in GANs.
In this module, we provide necessary components to conduct experiments
mentioned in the paper: Positional Encoding as Spatial Inductive Bias in GANs.
More details can be found in: https://arxiv.org/pdf/2012.05217.pdf
"""
from
functools
import
partial
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmgen.models.builder
import
MODULES
,
build_module
from
.generator_discriminator
import
SinGANMultiScaleGenerator
from
.modules
import
GeneratorBlock
@MODULES.register_module()
class SinGANMSGeneratorPE(SinGANMultiScaleGenerator):
    """Multi-Scale Generator used in SinGAN with positional encoding.

    More details can be found in: Positional Encoding as Spatial Inductive
    Bias in GANs, CVPR'2021.

    Notes:

    - In this version, we adopt the interpolation function from the official
      PyTorch APIs, which is different from the original implementation by the
      authors. However, in our experiments, this influence can be ignored.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        num_scales (int): The number of scales/stages in generator. Note
            that this number is counted from zero, which is the same as the
            original paper.
        kernel_size (int, optional): Kernel size, same as :obj:`nn.Conv2d`.
            Defaults to 3.
        padding (int, optional): Padding for the convolutional layer, same as
            :obj:`nn.Conv2d`. Defaults to 0.
        num_layers (int, optional): The number of convolutional layers in each
            generator block. Defaults to 5.
        base_channels (int, optional): The basic channels for convolutional
            layers in the generator block. Defaults to 32.
        min_feat_channels (int, optional): Minimum channels for the feature
            maps in the generator block. Defaults to 32.
        out_act_cfg (dict | None, optional): Configs for output activation
            layer. Defaults to dict(type='Tanh').
        padding_mode (str, optional): The mode of convolutional padding, same
            as :obj:`nn.Conv2d`. Defaults to 'zero'.
        pad_at_head (bool, optional): Whether to add padding at head.
            Defaults to True.
        interp_pad (bool, optional): The padding value of interpolating feature
            maps. Defaults to False.
        noise_with_pad (bool, optional): Whether the input fixed noises are
            with explicit padding. Defaults to False.
        positional_encoding (dict | None, optional): Configs for the positional
            encoding. Defaults to None.
        first_stage_in_channels (int | None, optional): The input channel of
            the first generator block. If None, the first stage will adopt the
            same input channels as other stages. Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_scales,
                 kernel_size=3,
                 padding=0,
                 num_layers=5,
                 base_channels=32,
                 min_feat_channels=32,
                 out_act_cfg=dict(type='Tanh'),
                 padding_mode='zero',
                 pad_at_head=True,
                 interp_pad=False,
                 noise_with_pad=False,
                 positional_encoding=None,
                 first_stage_in_channels=None,
                 **kwargs):
        # NOTE: deliberately skip ``SinGANMultiScaleGenerator.__init__`` and
        # call its parent's init instead -- this class rebuilds the whole
        # block list and padding layers itself below.
        super(SinGANMultiScaleGenerator, self).__init__()

        self.pad_at_head = pad_at_head
        self.interp_pad = interp_pad
        self.noise_with_pad = noise_with_pad

        # optional positional-encoding module, applied to the first-stage
        # noise only (see ``forward``)
        self.with_positional_encode = positional_encoding is not None
        if self.with_positional_encode:
            self.head_position_encode = build_module(positional_encoding)

        # total spatial shrinkage of one generator block: each of the
        # ``num_layers`` convs removes (kernel_size - 1) / 2 pixels per side
        self.pad_head = int((kernel_size - 1) / 2 * num_layers)
        self.blocks = nn.ModuleList()

        self.upsample = partial(
            F.interpolate, mode='bicubic', align_corners=True)

        for scale in range(num_scales + 1):
            # channels double every 4 scales, capped at 128
            base_ch = min(base_channels * pow(2, int(np.floor(scale / 4))),
                          128)
            min_feat_ch = min(
                min_feat_channels * pow(2, int(np.floor(scale / 4))), 128)

            # the first stage may take a different number of input channels
            # (e.g. when the positional encoding changes the input layout)
            if scale == 0:
                in_ch = (first_stage_in_channels
                         if first_stage_in_channels else in_channels)
            else:
                in_ch = in_channels

            self.blocks.append(
                GeneratorBlock(
                    in_channels=in_ch,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    num_layers=num_layers,
                    base_channels=base_ch,
                    min_feat_channels=min_feat_ch,
                    out_act_cfg=out_act_cfg,
                    padding_mode=padding_mode,
                    **kwargs))

        # head padding layers compensating the shrinkage of a block;
        # note the mask padding is always reflection, regardless of mode
        if padding_mode == 'zero':
            self.noise_padding_layer = nn.ZeroPad2d(self.pad_head)
            self.img_padding_layer = nn.ZeroPad2d(self.pad_head)
            self.mask_padding_layer = nn.ReflectionPad2d(self.pad_head)
        elif padding_mode == 'reflect':
            self.noise_padding_layer = nn.ReflectionPad2d(self.pad_head)
            self.img_padding_layer = nn.ReflectionPad2d(self.pad_head)
            self.mask_padding_layer = nn.ReflectionPad2d(self.pad_head)
            mmcv.print_log('Using Reflection padding', 'mmgen')
        else:
            raise NotImplementedError(
                f'Padding mode {padding_mode} is not supported')

    def forward(self,
                input_sample,
                fixed_noises,
                noise_weights,
                rand_mode,
                curr_scale,
                num_batches=1,
                get_prev_res=False,
                return_noise=False):
        """Forward function.

        Args:
            input_sample (Tensor | None): The input for generator. In the
                original implementation, a tensor filled with zeros is adopted.
                If None is given, we will construct it from the first fixed
                noises.
            fixed_noises (list[Tensor]): List of the fixed noises in SinGAN.
            noise_weights (list[float]): List of the weights for random noises.
            rand_mode (str): Choices from ['rand', 'recon']. In ``rand`` mode,
                it will sample from random noises. Otherwise, the
                reconstruction for the single image will be returned.
            curr_scale (int): The scale for the current inference or training.
            num_batches (int, optional): The number of batches. Defaults to 1.
            get_prev_res (bool, optional): Whether to return results from
                previous stages. Defaults to False.
            return_noise (bool, optional): Whether to return noises tensor.
                Defaults to False.

        Returns:
            Tensor | dict: Generated image tensor or dictionary containing \
                more data.
        """
        if get_prev_res or return_noise:
            prev_res_list = []
            noise_list = []

        # default input: zeros with the spatial size of the first fixed noise
        # (``.to`` matches both device and dtype of the fixed noises)
        if input_sample is None:
            input_sample = torch.zeros(
                (num_batches, 3, fixed_noises[0].shape[-2],
                 fixed_noises[0].shape[-1])).to(fixed_noises[0])

        g_res = input_sample

        for stage in range(curr_scale + 1):
            # 'recon' reuses the fixed noises; otherwise sample fresh noise
            # with the same per-stage shape
            if rand_mode == 'recon':
                noise_ = fixed_noises[stage]
            else:
                noise_ = torch.randn(num_batches,
                                     *fixed_noises[stage].shape[1:]).to(g_res)
            if return_noise:
                noise_list.append(noise_)

            # positional encoding is injected only at the first stage
            if self.with_positional_encode and stage == 0:
                head_grid = self.head_position_encode(fixed_noises[0])
                noise_ = noise_ + head_grid

            # add padding at head
            if self.pad_at_head:
                if self.interp_pad:
                    # grow feature maps by interpolation instead of padding
                    if self.noise_with_pad:
                        size = noise_.shape[-2:]
                    else:
                        size = (noise_.size(2) + 2 * self.pad_head,
                                noise_.size(3) + 2 * self.pad_head)
                        noise_ = self.upsample(noise_, size)
                    g_res_pad = self.upsample(g_res, size)
                else:
                    # explicit padding layers (zero or reflection)
                    if not self.noise_with_pad:
                        noise_ = self.noise_padding_layer(noise_)
                    g_res_pad = self.img_padding_layer(g_res)
            else:
                g_res_pad = g_res

            # with positional encoding, the first stage feeds the weighted
            # (encoded) noise alone -- ``g_res_pad`` is NOT added here
            if stage == 0 and self.with_positional_encode:
                noise = noise_ * noise_weights[stage]
            else:
                noise = noise_ * noise_weights[stage] + g_res_pad

            # the noise branch is detached so gradients only flow through the
            # residual (``g_res``) input of the block
            g_res = self.blocks[stage](noise.detach(), g_res)

            if get_prev_res and stage != curr_scale:
                prev_res_list.append(g_res)

            # upsample, here we use interpolation from PyTorch
            if stage != curr_scale:
                h_next, w_next = fixed_noises[stage + 1].shape[-2:]
                if self.noise_with_pad:
                    # remove the additional padding if noise with pad
                    h_next -= 2 * self.pad_head
                    w_next -= 2 * self.pad_head
                g_res = self.upsample(g_res, (h_next, w_next))

        if get_prev_res or return_noise:
            output_dict = dict(
                fake_img=g_res,
                prev_res_list=prev_res_list,
                noise_batch=noise_list)
            return output_dict

        return g_res
build/lib/mmgen/models/architectures/sngan_proj/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.generator_discriminator
import
ProjDiscriminator
,
SNGANGenerator
from
.modules
import
SNGANDiscHeadResBlock
,
SNGANDiscResBlock
,
SNGANGenResBlock
# Public API of the ``sngan_proj`` architecture package.
__all__ = [
    'ProjDiscriminator', 'SNGANGenerator', 'SNGANGenResBlock',
    'SNGANDiscResBlock', 'SNGANDiscHeadResBlock'
]
build/lib/mmgen/models/architectures/sngan_proj/generator_discriminator.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
(
ConvModule
,
build_activation_layer
,
constant_init
,
xavier_init
)
from
mmcv.runner
import
load_checkpoint
from
mmcv.runner.checkpoint
import
_load_checkpoint_with_prefix
from
mmcv.utils
import
is_list_of
from
torch.nn.init
import
xavier_uniform_
from
torch.nn.utils
import
spectral_norm
from
mmgen.models.builder
import
MODULES
,
build_module
from
mmgen.utils
import
check_dist_init
from
mmgen.utils.logger
import
get_root_logger
from
..common
import
get_module_device
@MODULES.register_module('SAGANGenerator')
@MODULES.register_module()
class SNGANGenerator(nn.Module):
    r"""Generator for SNGAN / Proj-GAN. The implementation refers to
    https://github.com/pfnet-research/sngan_projection/tree/master/gen_models

    In our implementation, we have two notable designs. Namely,
    ``channels_cfg`` and ``blocks_cfg``.

    ``channels_cfg``: In default config of SNGAN / Proj-GAN, the number of
        ResBlocks and the channels of those blocks are corresponding to the
        resolution of the output image. Therefore, we allow user to define
        ``channels_cfg`` to try their own models. We also provide a default
        config to allow users to build the model only from the output
        resolution.

    ``block_cfg``: In reference code, the generator consists of a group of
        ResBlock. However, in our implementation, to make this model more
        generalized, we support defining ``blocks_cfg`` by users and loading
        the blocks by calling the build_module method.

    Args:
        output_scale (int): Output scale for the generated image.
        num_classes (int, optional): The number of classes you would like to
            generate. This argument would influence the structure of the
            intermediate blocks and label sampling operation in ``forward``
            (e.g. If num_classes=0, ConditionalNormalization layers would
            degrade to unconditional ones.). This argument would be passed
            to intermediate blocks by overwriting their config. Defaults to 0.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contain channels based on this number.
            Default to 64.
        out_channels (int, optional): Channels of the output images.
            Default to 3.
        input_scale (int, optional): Input scale for the features.
            Defaults to 4.
        noise_size (int, optional): Size of the input noise vector.
            Default to 128.
        attention_cfg (dict, optional): Config for the self-attention block.
            Default to ``dict(type='SelfAttentionBlock')``.
        attention_after_nth_block (int | list[int], optional): Self attention
            block would be added after which *ConvBlock*. If ``int`` is passed,
            only one attention block would be added. If ``list`` is passed,
            self-attention blocks would be added after multiple ConvBlocks.
            To be noted that if the input is smaller than ``1``,
            self-attention corresponding to this index would be ignored.
            Default to 0.
        channels_cfg (list | dict[list], optional): Config for input channels
            of the intermediate blocks. If list is passed, each element of the
            list means the input channels of current block is how many times
            compared to the ``base_channels``. For block ``i``, the input and
            output channels should be ``channels_cfg[i]`` and
            ``channels_cfg[i+1]`` If dict is provided, the key of the dict
            should be the output scale and corresponding value should be a
            list to define channels. Default: Please refer to
            ``_default_channels_cfg``.
        blocks_cfg (dict, optional): Config for the intermediate blocks.
            Defaults to ``dict(type='SNGANGenResBlock')``
        act_cfg (dict, optional): Activation config for the final output
            layer. Defaults to ``dict(type='ReLU')``.
        use_cbn (bool, optional): Whether use conditional normalization. This
            argument would pass to norm layers. Defaults to True.
        auto_sync_bn (bool, optional): Whether convert Batch Norm to
            Synchronized ones when Distributed training is on. Defaults to
            True.
        with_spectral_norm (bool, optional): Whether use spectral norm for
            conv blocks or not. Default to False.
        with_embedding_spectral_norm (bool, optional): Whether use spectral
            norm for embedding layers in normalization blocks or not. If not
            specified (set as ``None``), ``with_embedding_spectral_norm``
            would be set as the same value as ``with_spectral_norm``.
            Defaults to None.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        norm_eps (float, optional): eps for Normalization layers (both
            conditional and non-conditional ones). Default to `1e-4`.
        sn_eps (float, optional): eps for spectral normalization operation.
            Defaults to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretrained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
            the generator part from the whole state dict. Defaults to None.
    """

    # default channel factors, keyed by output resolution
    _default_channels_cfg = {
        32: [1, 1, 1],
        64: [16, 8, 4, 2],
        128: [16, 16, 8, 4, 2]
    }

    def __init__(self,
                 output_scale,
                 num_classes=0,
                 base_channels=64,
                 out_channels=3,
                 input_scale=4,
                 noise_size=128,
                 attention_cfg=dict(type='SelfAttentionBlock'),
                 attention_after_nth_block=0,
                 channels_cfg=None,
                 blocks_cfg=dict(type='SNGANGenResBlock'),
                 act_cfg=dict(type='ReLU'),
                 use_cbn=True,
                 auto_sync_bn=True,
                 with_spectral_norm=False,
                 with_embedding_spectral_norm=None,
                 sn_style='torch',
                 norm_eps=1e-4,
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN'),
                 pretrained=None):
        super().__init__()

        self.input_scale = input_scale
        self.output_scale = output_scale
        self.noise_size = noise_size
        self.num_classes = num_classes
        # kept for `init_weights`; `.upper()` there will fail fast if missing
        self.init_type = init_cfg.get('type', None)

        # propagate shared settings into the block config without clobbering
        # user-provided values (hence `setdefault`)
        self.blocks_cfg = deepcopy(blocks_cfg)
        self.blocks_cfg.setdefault('num_classes', num_classes)
        self.blocks_cfg.setdefault('act_cfg', act_cfg)
        self.blocks_cfg.setdefault('use_cbn', use_cbn)
        self.blocks_cfg.setdefault('auto_sync_bn', auto_sync_bn)
        self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)

        # set `norm_spectral_norm` as `with_spectral_norm` if not defined
        with_embedding_spectral_norm = with_embedding_spectral_norm \
            if with_embedding_spectral_norm is not None else with_spectral_norm
        self.blocks_cfg.setdefault('with_embedding_spectral_norm',
                                   with_embedding_spectral_norm)
        self.blocks_cfg.setdefault('init_cfg', init_cfg)
        self.blocks_cfg.setdefault('sn_style', sn_style)
        self.blocks_cfg.setdefault('norm_eps', norm_eps)
        self.blocks_cfg.setdefault('sn_eps', sn_eps)

        channels_cfg = deepcopy(self._default_channels_cfg) \
            if channels_cfg is None else deepcopy(channels_cfg)
        if isinstance(channels_cfg, dict):
            if output_scale not in channels_cfg:
                raise KeyError(
                    f'`output_scale={output_scale}` is not found in '
                    '`channel_cfg`, only support configs for '
                    f'{[chn for chn in channels_cfg.keys()]}')
            self.channel_factor_list = channels_cfg[output_scale]
        elif isinstance(channels_cfg, list):
            self.channel_factor_list = channels_cfg
        else:
            raise ValueError('Only support list or dict for `channel_cfg`, '
                             f'receive {type(channels_cfg)}')

        # project the noise vector to the initial (input_scale x input_scale)
        # feature map
        self.noise2feat = nn.Linear(
            noise_size,
            input_scale**2 * base_channels * self.channel_factor_list[0])
        if with_spectral_norm:
            self.noise2feat = spectral_norm(self.noise2feat)

        # check `attention_after_nth_block`
        if not isinstance(attention_after_nth_block, list):
            attention_after_nth_block = [attention_after_nth_block]
        if not is_list_of(attention_after_nth_block, int):
            raise ValueError('`attention_after_nth_block` only support int or '
                             'a list of int. Please check your input type.')

        self.conv_blocks = nn.ModuleList()
        self.attention_block_idx = []
        for idx in range(len(self.channel_factor_list)):
            factor_input = self.channel_factor_list[idx]
            factor_output = self.channel_factor_list[idx+1] \
                if idx < len(self.channel_factor_list)-1 else 1

            # get block-specific config
            block_cfg_ = deepcopy(self.blocks_cfg)
            block_cfg_['in_channels'] = factor_input * base_channels
            block_cfg_['out_channels'] = factor_output * base_channels
            self.conv_blocks.append(build_module(block_cfg_))

            # build self-attention block
            # `idx` is start from 0, add 1 to get the index
            if idx + 1 in attention_after_nth_block:
                # record the position BEFORE appending, so the stored index
                # points at the attention block itself
                self.attention_block_idx.append(len(self.conv_blocks))
                attn_cfg_ = deepcopy(attention_cfg)
                attn_cfg_['in_channels'] = factor_output * base_channels
                attn_cfg_['sn_style'] = sn_style
                self.conv_blocks.append(build_module(attn_cfg_))

        to_rgb_norm_cfg = dict(type='BN', eps=norm_eps)
        if check_dist_init() and auto_sync_bn:
            to_rgb_norm_cfg['type'] = 'SyncBN'

        self.to_rgb = ConvModule(
            factor_output * base_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            norm_cfg=to_rgb_norm_cfg,
            act_cfg=act_cfg,
            order=('norm', 'act', 'conv'),
            with_spectral_norm=with_spectral_norm)
        self.final_act = build_activation_layer(dict(type='Tanh'))

        self.init_weights(pretrained)

    def forward(self, noise, num_batches=0, label=None, return_noise=False):
        """Forward function.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            label (torch.Tensor | callable | None): You can directly give a
                batch of label through a ``torch.Tensor`` or offer a callable
                function to sample a batch of label data. Otherwise, the
                ``None`` indicates to use the default label sampler.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.

        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output
                image will be returned. Otherwise, a dict contains
                ``fake_img``, ``noise_batch`` and ``label``
                would be returned.
        """
        if isinstance(noise, torch.Tensor):
            assert noise.shape[1] == self.noise_size
            assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
                                     f'but got {noise.shape}')
            noise_batch = noise
        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, self.noise_size))
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            noise_batch = torch.randn((num_batches, self.noise_size))

        if isinstance(label, torch.Tensor):
            assert label.ndim == 1, ('The label should be in shape of (n, ) '
                                     f'but got {label.shape}.')
            label_batch = label
        elif callable(label):
            label_generator = label
            assert num_batches > 0
            label_batch = label_generator(num_batches)
        elif self.num_classes == 0:
            # unconditional generation: no label is needed
            label_batch = None
        else:
            assert num_batches > 0
            label_batch = torch.randint(0, self.num_classes, (num_batches, ))

        # dirty code for putting data on the right device
        noise_batch = noise_batch.to(get_module_device(self))
        if label_batch is not None:
            label_batch = label_batch.to(get_module_device(self))

        x = self.noise2feat(noise_batch)
        x = x.reshape(x.size(0), -1, self.input_scale, self.input_scale)
        for idx, conv_block in enumerate(self.conv_blocks):
            # attention blocks are unconditional and take no label
            if idx in self.attention_block_idx:
                x = conv_block(x)
            else:
                x = conv_block(x, label_batch)
        out_feat = self.to_rgb(x)
        out_img = self.final_act(out_feat)

        if return_noise:
            return dict(
                fake_img=out_img, noise_batch=noise_batch, label=label_batch)
        return out_img

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for SNGAN-Proj and SAGAN. If ``pretrained=None``,
        weight initialization would follow the ``INIT_TYPE`` in
        ``init_cfg=dict(type=INIT_TYPE)``.

        For SNGAN-Proj
        (``INIT_TYPE.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']``),
        we follow the initialization method in the official Chainer's
        implementation (https://github.com/pfnet-research/sngan_projection).

        For SAGAN (``INIT_TYPE.upper() == 'SAGAN'``), we follow the
        initialization method in official tensorflow's implementation
        (https://github.com/brain-research/self-attention-gan).

        Besides the reimplementation of the official code's initialization, we
        provide BigGAN's and Pytorch-StudioGAN's style initialization
        (``INIT_TYPE.upper() == 'BIGGAN'`` and
        ``INIT_TYPE.upper() == 'STUDIO'``). Please refer to
        https://github.com/ajbrock/BigGAN-PyTorch and
        https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.

        Args:
            pretrained (str | dict, optional): Path for the pretrained model
                or dict containing information for pretrained models whose
                necessary key is 'ckpt_path'. Besides, you can also provide
                'prefix' to load the generator part from the whole state dict.
                Defaults to None.
            strict (bool, optional): Whether to load the checkpoint strictly
                when ``pretrained`` is a path string. Defaults to True.

        Raises:
            NotImplementedError: If ``init_cfg['type']`` is unknown.
            TypeError: If ``pretrained`` is neither str, dict nor ``None``.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif isinstance(pretrained, dict):
            ckpt_path = pretrained.get('ckpt_path', None)
            assert ckpt_path is not None
            prefix = pretrained.get('prefix', '')
            map_location = pretrained.get('map_location', 'cpu')
            strict = pretrained.get('strict', True)
            state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                      map_location)
            self.load_state_dict(state_dict, strict=strict)
        elif pretrained is None:
            # BUGFIX: was `in 'STUDIO'`, which is a substring test -- e.g.
            # init_type 'S', 'DIO' or '' would wrongly match. Use equality,
            # consistent with the other branches.
            if self.init_type.upper() == 'STUDIO':
                # initialization method from Pytorch-StudioGAN
                #   * weight: orthogonal_init gain=1
                #   * bias  : 0
                for m in self.modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                        nn.init.orthogonal_(m.weight)
                        if hasattr(m, 'bias') and m.bias is not None:
                            m.bias.data.fill_(0.)
            elif self.init_type.upper() == 'BIGGAN':
                # initialization method from BigGAN-pytorch
                #   * weight: xavier_init gain=1
                #   * bias  : default
                for m in self.modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                        xavier_uniform_(m.weight, gain=1)
            elif self.init_type.upper() == 'SAGAN':
                # initialization method from official tensorflow code
                #   * weight          : xavier_init gain=1
                #   * bias            : 0
                #   * weight_embedding: 1
                #   * bias_embedding  : 0
                for n, m in self.named_modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear)):
                        xavier_init(m, gain=1, distribution='uniform')
                    if isinstance(m, nn.Embedding):
                        # To be noted that here we initialize the embedding
                        # layer in cBN with specific prefix. If you implement
                        # your own cBN and want to use this initialization
                        # method, please make sure the embedding layers in
                        # your implementation have the same prefix as ours.
                        if 'weight' in n:
                            constant_init(m, 1)
                        if 'bias' in n:
                            constant_init(m, 0)
            elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
                # initialization method from the official chainer code
                #   * conv.weight     : xavier_init gain=sqrt(2)
                #   * shortcut.weight : xavier_init gain=1
                #   * bias            : 0
                #   * weight_embedding: 1
                #   * bias_embedding  : 0
                for n, m in self.named_modules():
                    if isinstance(m, nn.Conv2d):
                        if 'shortcut' in n or 'to_rgb' in n:
                            xavier_init(m, gain=1, distribution='uniform')
                        else:
                            xavier_init(
                                m, gain=np.sqrt(2), distribution='uniform')
                    if isinstance(m, nn.Linear):
                        xavier_init(m, gain=1, distribution='uniform')
                    if isinstance(m, nn.Embedding):
                        # To be noted that here we initialize the embedding
                        # layer in cBN with specific prefix. If you implement
                        # your own cBN and want to use this initialization
                        # method, please make sure the embedding layers in
                        # your implementation have the same prefix as ours.
                        if 'weight' in n:
                            constant_init(m, 1)
                        if 'bias' in n:
                            constant_init(m, 0)
            else:
                raise NotImplementedError('Unknown initialization method: '
                                          f'\'{self.init_type}\'')
        else:
            raise TypeError("'pretrained' must be a str or None. "
                            f'But receive {type(pretrained)}.')
@MODULES.register_module('SAGANDiscriminator')
@MODULES.register_module()
class ProjDiscriminator(nn.Module):
    r"""Discriminator for SNGAN / Proj-GAN. The implementation refers to
    https://github.com/pfnet-research/sngan_projection/tree/master/dis_models

    The overall structure of the projection discriminator can be split into a
    ``from_rgb`` layer, a group of ResBlocks, a linear decision layer, and a
    projection layer. To support defining custom layers, we introduce
    ``from_rgb_cfg`` and ``blocks_cfg``.

    The design of the model structure is highly corresponding to the output
    resolution. Therefore, we provide `channels_cfg` and `downsample_cfg` to
    control the input channels and the downsample behavior of the intermedia
    blocks.

    ``downsample_cfg``: In default config of SNGAN / Proj-GAN, whether to apply
        downsample in each intermedia blocks is quite flexible and
        corresponding to the resolution of the output image. Therefore, we
        support user to define the ``downsample_cfg`` by themselves, and to
        control the structure of the discriminator.

    ``channels_cfg``: In default config of SNGAN / Proj-GAN, the number of
        ResBlocks and the channels of those blocks are corresponding to the
        resolution of the output image. Therefore, we allow user to define
        `channels_cfg` for try their own models. We also provide a default
        config to allow users to build the model only from the output
        resolution.

    Args:
        input_scale (int): The scale of the input image.
        num_classes (int, optional): The number classes you would like to
            generate. If num_classes=0, no label projection would be used.
            Default to 0.
        base_channels (int, optional): The basic channel number of the
            discriminator. The other layers contains channels based on this
            number. Defaults to 128.
        input_channels (int, optional): Channels of the input image.
            Defaults to 3.
        attention_cfg (dict, optional): Config for the self-attention block.
            Default to ``dict(type='SelfAttentionBlock')``.
        attention_after_nth_block (int | list[int], optional): Self-attention
            block would be added after which *ConvBlock* (including the head
            block). If ``int`` is passed, only one attention block would be
            added. If ``list`` is passed, self-attention blocks would be added
            after multiple ConvBlocks. To be noted that if the input is
            smaller than ``1``, self-attention corresponding to this index
            would be ignored. Default to -1 (no self-attention block).
        channels_cfg (list | dict[list], optional): Config for input channels
            of the intermedia blocks. If list is passed, each element of the
            list means the input channels of current block is how many times
            compared to the ``base_channels``. For block ``i``, the input and
            output channels should be ``channels_cfg[i]`` and
            ``channels_cfg[i+1]`` If dict is provided, the key of the dict
            should be the output scale and corresponding value should be a
            list to define channels. Default: Please refer to
            ``_defualt_channels_cfg``.
        downsample_cfg (list[bool] | dict[list], optional): Config for
            downsample behavior of the intermedia layers. If a list is passed,
            ``downsample_cfg[idx] == True`` means apply downsample in idx-th
            block, and vice versa. If dict is provided, the key dict should
            be the input scale of the image and corresponding value should be
            a list to define the downsample behavior. Default: Please refer
            to ``_defualt_downsample_cfg``.
        from_rgb_cfg (dict, optional): Config for the first layer to convert
            rgb image to feature map. Defaults to
            ``dict(type='SNGANDiscHeadResBlock')``.
        blocks_cfg (dict, optional): Config for the intermedia blocks.
            Defaults to ``dict(type='SNGANDiscResBlock')``
        act_cfg (dict, optional): Activation config for the final output
            layer. Defaults to ``dict(type='ReLU')``.
        with_spectral_norm (bool, optional): Whether use spectral norm for
            all conv blocks or not. Default to True.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        sn_eps (float, optional): eps for spectral normalization operation.
            Defaults to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Default to ``dict(type='BigGAN')``.
        pretrained (str | dict , optional): Path for the pretrained model or
            dict containing information for pretrained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
            the generator part from the whole state dict. Defaults to None.
    """

    # default channel factors (key: input scale of the image)
    _defualt_channels_cfg = {
        32: [1, 1, 1],
        64: [2, 4, 8, 16],
        128: [2, 4, 8, 16, 16],
    }

    # default downsample behavior (key: input scale of the image)
    _defualt_downsample_cfg = {
        32: [True, False, False],
        64: [True, True, True, True],
        128: [True, True, True, True, False]
    }

    def __init__(self,
                 input_scale,
                 num_classes=0,
                 base_channels=128,
                 input_channels=3,
                 attention_cfg=dict(type='SelfAttentionBlock'),
                 attention_after_nth_block=-1,
                 channels_cfg=None,
                 downsample_cfg=None,
                 from_rgb_cfg=dict(type='SNGANDiscHeadResBlock'),
                 blocks_cfg=dict(type='SNGANDiscResBlock'),
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=True,
                 sn_style='torch',
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN'),
                 pretrained=None):
        super().__init__()

        self.init_type = init_cfg.get('type', None)

        # add SN options and activation function options to cfg
        self.from_rgb_cfg = deepcopy(from_rgb_cfg)
        self.from_rgb_cfg.setdefault('act_cfg', act_cfg)
        self.from_rgb_cfg.setdefault('with_spectral_norm', with_spectral_norm)
        self.from_rgb_cfg.setdefault('sn_style', sn_style)
        self.from_rgb_cfg.setdefault('init_cfg', init_cfg)

        # add SN options and activation function options to cfg
        self.blocks_cfg = deepcopy(blocks_cfg)
        self.blocks_cfg.setdefault('act_cfg', act_cfg)
        self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)
        self.blocks_cfg.setdefault('sn_style', sn_style)
        self.blocks_cfg.setdefault('sn_eps', sn_eps)
        self.blocks_cfg.setdefault('init_cfg', init_cfg)

        channels_cfg = deepcopy(self._defualt_channels_cfg) \
            if channels_cfg is None else deepcopy(channels_cfg)
        if isinstance(channels_cfg, dict):
            if input_scale not in channels_cfg:
                raise KeyError(f'`input_scale={input_scale} is not found in '
                               '`channel_cfg`, only support configs for '
                               f'{[chn for chn in channels_cfg.keys()]}')
            self.channel_factor_list = channels_cfg[input_scale]
        elif isinstance(channels_cfg, list):
            self.channel_factor_list = channels_cfg
        else:
            raise ValueError('Only support list or dict for `channel_cfg`, '
                             f'receive {type(channels_cfg)}')

        downsample_cfg = deepcopy(self._defualt_downsample_cfg) \
            if downsample_cfg is None else deepcopy(downsample_cfg)
        if isinstance(downsample_cfg, dict):
            if input_scale not in downsample_cfg:
                raise KeyError(f'`output_scale={input_scale} is not found in '
                               '`downsample_cfg`, only support configs for '
                               f'{[chn for chn in downsample_cfg.keys()]}')
            self.downsample_list = downsample_cfg[input_scale]
        elif isinstance(downsample_cfg, list):
            self.downsample_list = downsample_cfg
        else:
            # NOTE: fixed error message -- this branch validates
            # `downsample_cfg`, not `channel_cfg`
            raise ValueError('Only support list or dict for `downsample_cfg`, '
                             f'receive {type(downsample_cfg)}')

        if len(self.downsample_list) != len(self.channel_factor_list):
            raise ValueError('`downsample_cfg` should have same length with '
                             '`channels_cfg`, but receive '
                             f'{len(self.downsample_list)} and '
                             f'{len(self.channel_factor_list)}.')

        # check `attention_after_nth_block`
        if not isinstance(attention_after_nth_block, list):
            attention_after_nth_block = [attention_after_nth_block]
        if not all([isinstance(idx, int)
                    for idx in attention_after_nth_block]):
            raise ValueError('`attention_after_nth_block` only support int or '
                             'a list of int. Please check your input type.')

        self.from_rgb = build_module(
            self.from_rgb_cfg,
            dict(in_channels=input_channels, out_channels=base_channels))

        self.conv_blocks = nn.ModuleList()
        # add self-attention block after the first block
        if 1 in attention_after_nth_block:
            attn_cfg_ = deepcopy(attention_cfg)
            attn_cfg_['in_channels'] = base_channels
            attn_cfg_['sn_style'] = sn_style
            self.conv_blocks.append(build_module(attn_cfg_))

        for idx in range(len(self.downsample_list)):
            factor_input = 1 if idx == 0 else self.channel_factor_list[idx - 1]
            factor_output = self.channel_factor_list[idx]

            # get block-specific config
            block_cfg_ = deepcopy(self.blocks_cfg)
            block_cfg_['downsample'] = self.downsample_list[idx]
            block_cfg_['in_channels'] = factor_input * base_channels
            block_cfg_['out_channels'] = factor_output * base_channels
            self.conv_blocks.append(build_module(block_cfg_))

            # build self-attention block
            # the first ConvBlock is `from_rgb` block,
            # add 2 to get the index of the ConvBlocks
            if idx + 2 in attention_after_nth_block:
                attn_cfg_ = deepcopy(attention_cfg)
                attn_cfg_['in_channels'] = factor_output * base_channels
                # pass `sn_style` for consistency with the attention block
                # built after the head block above
                attn_cfg_['sn_style'] = sn_style
                self.conv_blocks.append(build_module(attn_cfg_))

        self.decision = nn.Linear(factor_output * base_channels, 1)
        if with_spectral_norm:
            self.decision = spectral_norm(self.decision)

        self.num_classes = num_classes

        # In this case, discriminator is designed for conditional synthesis.
        if num_classes > 0:
            self.proj_y = nn.Embedding(num_classes,
                                       factor_output * base_channels)
            if with_spectral_norm:
                self.proj_y = spectral_norm(self.proj_y)

        self.activate = build_activation_layer(act_cfg)
        self.init_weights(pretrained)

    def forward(self, x, label=None):
        """Forward function. If `self.num_classes` is larger than 0, label
        projection would be used.

        Args:
            x (torch.Tensor): Fake or real image tensor.
            label (torch.Tensor, optional): Label correspond to the input
                image. Noted that, if `self.num_classes` is larger than 0,
                `label` should not be None. Default to None.

        Returns:
            torch.Tensor: Prediction for the reality of the input image.
        """
        h = self.from_rgb(x)
        for conv_block in self.conv_blocks:
            h = conv_block(h)
        h = self.activate(h)
        # global sum pooling over the spatial dimensions
        h = torch.sum(h, dim=[2, 3])
        out = self.decision(h)

        if self.num_classes > 0:
            # label projection: inner product between feature and
            # class embedding (Proj-GAN)
            w_y = self.proj_y(label)
            out = out + torch.sum(w_y * h, dim=1, keepdim=True)
        return out.view(out.size(0), -1)

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for SNGAN-Proj and SAGAN. If ``pretrained=None`` and
        weight initialization would follow the ``INIT_TYPE`` in
        ``init_cfg=dict(type=INIT_TYPE)``.

        For SNGAN-Proj
        (``INIT_TYPE.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']``),
        we follow the initialization method in the official Chainer's
        implementation (https://github.com/pfnet-research/sngan_projection).

        For SAGAN (``INIT_TYPE.upper() == 'SAGAN'``), we follow the
        initialization method in official tensorflow's implementation
        (https://github.com/brain-research/self-attention-gan).

        Besides the reimplementation of the official code's initialization, we
        provide BigGAN's and Pytorch-StudioGAN's style initialization
        (``INIT_TYPE.upper() == BIGGAN`` and ``INIT_TYPE.upper() == STUDIO``).
        Please refer to https://github.com/ajbrock/BigGAN-PyTorch and
        https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.

        Args:
            pretrained (str | dict, optional): Path for the pretrained model
                or dict containing information for pretrained models whose
                necessary key is 'ckpt_path'. Besides, you can also provide
                'prefix' to load the generator part from the whole state
                dict. Defaults to None.
            strict (bool, optional): Whether to load the state dict strictly.
                Defaults to True.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif isinstance(pretrained, dict):
            ckpt_path = pretrained.get('ckpt_path', None)
            assert ckpt_path is not None
            prefix = pretrained.get('prefix', '')
            map_location = pretrained.get('map_location', 'cpu')
            strict = pretrained.get('strict', True)
            state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                      map_location)
            self.load_state_dict(state_dict, strict=strict)
        elif pretrained is None:
            if self.init_type.upper() == 'STUDIO':
                # initialization method from Pytorch-StudioGAN
                #   * weight: orthogonal_init gain=1
                #   * bias  : 0
                for m in self.modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                        nn.init.orthogonal_(m.weight, gain=1)
                        if hasattr(m, 'bias') and m.bias is not None:
                            m.bias.data.fill_(0.)
            elif self.init_type.upper() == 'BIGGAN':
                # initialization method from BigGAN-pytorch
                #   * weight: xavier_init gain=1
                #   * bias  : default
                for m in self.modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                        xavier_uniform_(m.weight, gain=1)
            elif self.init_type.upper() == 'SAGAN':
                # initialization method from official tensorflow code
                #   * weight: xavier_init gain=1
                #   * bias  : 0
                for m in self.modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                        xavier_init(m, gain=1, distribution='uniform')
            elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
                # initialization method from the official chainer code
                #   * embedding.weight: xavier_init gain=1
                #   * conv.weight     : xavier_init gain=sqrt(2)
                #   * shortcut.weight : xavier_init gain=1
                #   * bias            : 0
                for n, m in self.named_modules():
                    if isinstance(m, nn.Conv2d):
                        if 'shortcut' in n:
                            xavier_init(m, gain=1, distribution='uniform')
                        else:
                            xavier_init(
                                m, gain=np.sqrt(2), distribution='uniform')
                    if isinstance(m, (nn.Linear, nn.Embedding)):
                        xavier_init(m, gain=1, distribution='uniform')
            else:
                raise NotImplementedError('Unknown initialization method: '
                                          f'\'{self.init_type}\'')
        else:
            # NOTE: fixed typo in the error message ('must by' -> 'must be')
            raise TypeError("'pretrained' must be a str or None. "
                            f'But receive {type(pretrained)}.')
build/lib/mmgen/models/architectures/sngan_proj/modules.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
numpy
as
np
import
torch.nn
as
nn
from
mmcv.cnn
import
(
build_activation_layer
,
build_norm_layer
,
build_upsample_layer
,
constant_init
,
xavier_init
)
from
torch.nn.init
import
xavier_uniform_
from
torch.nn.utils
import
spectral_norm
from
mmgen.models.architectures.biggan.biggan_snmodule
import
SNEmbedding
from
mmgen.models.architectures.biggan.modules
import
SNConvModule
from
mmgen.models.builder
import
MODULES
from
mmgen.utils
import
check_dist_init
@MODULES.register_module()
class SNGANGenResBlock(nn.Module):
    """ResBlock used in Generator of SNGAN / Proj-GAN.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        hidden_channels (int, optional): Input channels of the second Conv
            layer of the block. If ``None`` is given, would be set as
            ``out_channels``. Default to None.
        num_classes (int, optional): Number of classes would like to generate.
            This argument would pass to norm layers and influence the
            structure and behavior of the normalization process. Default to 0.
        use_cbn (bool, optional): Whether use conditional normalization. This
            argument would pass to norm layers. Default to True.
        use_norm_affine (bool, optional): Whether use learnable affine
            parameters in norm operation when cbn is off. Default False.
        act_cfg (dict, optional): Config for activate function. Default
            to ``dict(type='ReLU')``.
        norm_cfg (dict, optional): Config for normalization method. Default
            to ``dict(type='BN')``.
        upsample_cfg (dict, optional): Config for the upsample method.
            Default to ``dict(type='nearest', scale_factor=2)``.
        upsample (bool, optional): Whether apply upsample operation in this
            module. Default to True.
        auto_sync_bn (bool, optional): Whether convert Batch Norm to
            Synchronized ones when Distributed training is on. Default to
            True.
        conv_cfg (dict | None): Config for conv blocks of this module. If pass
            ``None``, would use ``_default_conv_cfg``. Default to ``None``.
        with_spectral_norm (bool, optional): Whether use spectral norm for
            conv blocks and norm layers. Default to False.
        with_embedding_spectral_norm (bool, optional): Whether use spectral
            norm for embedding layers in normalization blocks or not. If not
            specified (set as ``None``), ``with_embedding_spectral_norm``
            would be set as the same value as ``with_spectral_norm``.
            Default to None.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        norm_eps (float, optional): eps for Normalization layers (both
            conditional and non-conditional ones). Default to `1e-4`.
        sn_eps (float, optional): eps for spectral normalization operation.
            Default to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Default to ``dict(type='BigGAN')``.
    """

    _default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=None,
                 num_classes=0,
                 use_cbn=True,
                 use_norm_affine=False,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN'),
                 upsample_cfg=dict(type='nearest', scale_factor=2),
                 upsample=True,
                 auto_sync_bn=True,
                 conv_cfg=None,
                 with_spectral_norm=False,
                 with_embedding_spectral_norm=None,
                 sn_style='torch',
                 norm_eps=1e-4,
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN')):
        super().__init__()

        # shortcut needs a learnable 1x1 conv when channels change or
        # spatial size is changed by upsampling
        self.learnable_sc = in_channels != out_channels or upsample
        self.with_upsample = upsample
        self.init_type = init_cfg.get('type', None)

        self.activate = build_activation_layer(act_cfg)
        hidden_channels = out_channels if hidden_channels is None \
            else hidden_channels

        if self.with_upsample:
            self.upsample = build_upsample_layer(upsample_cfg)

        self.conv_cfg = deepcopy(self._default_conv_cfg)
        if conv_cfg is not None:
            self.conv_cfg.update(conv_cfg)

        # set `norm_spectral_norm` as `with_spectral_norm` if not defined
        with_embedding_spectral_norm = with_embedding_spectral_norm \
            if with_embedding_spectral_norm is not None else with_spectral_norm

        sn_cfg = dict(eps=sn_eps, sn_style=sn_style)
        self.conv_1 = SNConvModule(
            in_channels,
            hidden_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)
        self.conv_2 = SNConvModule(
            hidden_channels,
            out_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)

        self.norm_1 = SNConditionNorm(in_channels, num_classes, use_cbn,
                                      norm_cfg, use_norm_affine, auto_sync_bn,
                                      with_embedding_spectral_norm, sn_style,
                                      norm_eps, sn_eps, init_cfg)
        self.norm_2 = SNConditionNorm(hidden_channels, num_classes, use_cbn,
                                      norm_cfg, use_norm_affine, auto_sync_bn,
                                      with_embedding_spectral_norm, sn_style,
                                      norm_eps, sn_eps, init_cfg)

        if self.learnable_sc:
            # use hyperparameters-fixed shortcut here
            self.shortcut = SNConvModule(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                act_cfg=None,
                with_spectral_norm=with_spectral_norm,
                spectral_norm_cfg=sn_cfg)

        self.init_weights()

    def forward(self, x, y=None):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            y (Tensor): Input label with shape (n, ).
                Default None.

        Returns:
            Tensor: Forward results.
        """
        out = self.norm_1(x, y)
        out = self.activate(out)
        if self.with_upsample:
            out = self.upsample(out)
        out = self.conv_1(out)

        out = self.norm_2(out, y)
        out = self.activate(out)
        out = self.conv_2(out)

        shortcut = self.forward_shortcut(x)
        return out + shortcut

    def forward_shortcut(self, x):
        """Forward the shortcut branch (identity, or upsample + 1x1 conv when
        the block is learnable)."""
        out = x
        if self.learnable_sc:
            if self.with_upsample:
                out = self.upsample(out)
            out = self.shortcut(out)
        return out

    def init_weights(self):
        """Initialize weights for the model."""
        if self.init_type.upper() == 'STUDIO':
            nn.init.orthogonal_(self.conv_1.conv.weight)
            nn.init.orthogonal_(self.conv_2.conv.weight)
            self.conv_1.conv.bias.data.fill_(0.)
            self.conv_2.conv.bias.data.fill_(0.)
            if self.learnable_sc:
                nn.init.orthogonal_(self.shortcut.conv.weight)
                self.shortcut.conv.bias.data.fill_(0.)
        elif self.init_type.upper() == 'BIGGAN':
            xavier_uniform_(self.conv_1.conv.weight, gain=1)
            xavier_uniform_(self.conv_2.conv.weight, gain=1)
            if self.learnable_sc:
                xavier_uniform_(self.shortcut.conv.weight, gain=1)
        elif self.init_type.upper() == 'SAGAN':
            xavier_init(self.conv_1, gain=1, distribution='uniform')
            xavier_init(self.conv_2, gain=1, distribution='uniform')
            if self.learnable_sc:
                xavier_init(self.shortcut, gain=1, distribution='uniform')
        elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
            xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
            xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
            if self.learnable_sc:
                xavier_init(self.shortcut, gain=1, distribution='uniform')
        else:
            raise NotImplementedError('Unknown initialization method: '
                                      f'\'{self.init_type}\'')
@MODULES.register_module()
class SNGANDiscResBlock(nn.Module):
    """ResBlock used in discriminator of SNGAN / Proj-GAN.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        hidden_channels (int, optional): Input channels of the second conv
            layer of the block. If ``None`` is given, would be set as
            ``out_channels``. Defaults to None.
        downsample (bool, optional): Whether apply downsample operation in
            this module. Defaults to False.
        act_cfg (dict, optional): Config for activate function. Defaults
            to ``dict(type='ReLU')``.
        conv_cfg (dict | None): Config for conv blocks of this module. If
            pass ``None``, would use ``_default_conv_cfg``. Defaults to
            ``None``.
        with_spectral_norm (bool, optional): Whether use spectral norm for
            conv blocks and norm layers. Defaults to True.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        sn_eps (float, optional): eps for spectral normalization operation.
            Defaults to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
    """

    _default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=None,
                 downsample=False,
                 act_cfg=dict(type='ReLU'),
                 conv_cfg=None,
                 with_spectral_norm=True,
                 sn_style='torch',
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN')):
        super().__init__()
        hidden_channels = out_channels if hidden_channels is None \
            else hidden_channels
        self.with_downsample = downsample
        self.init_type = init_cfg.get('type', None)

        self.conv_cfg = deepcopy(self._default_conv_cfg)
        if conv_cfg is not None:
            self.conv_cfg.update(conv_cfg)

        self.activate = build_activation_layer(act_cfg)
        sn_cfg = dict(eps=sn_eps, sn_style=sn_style)

        self.conv_1 = SNConvModule(
            in_channels,
            hidden_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)

        self.conv_2 = SNConvModule(
            hidden_channels,
            out_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)

        if self.with_downsample:
            self.downsample = nn.AvgPool2d(2, 2)

        # shortcut needs a learnable 1x1 conv when channels or spatial
        # size change
        self.learnable_sc = in_channels != out_channels or downsample
        if self.learnable_sc:
            # use hyperparameters-fixed shortcut here
            self.shortcut = SNConvModule(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                act_cfg=None,
                with_spectral_norm=with_spectral_norm,
                spectral_norm_cfg=sn_cfg)

        self.init_weights()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        out = self.activate(x)
        out = self.conv_1(out)
        out = self.activate(out)
        out = self.conv_2(out)
        if self.with_downsample:
            out = self.downsample(out)
        shortcut = self.forward_shortcut(x)
        return out + shortcut

    def forward_shortcut(self, x):
        """Forward the shortcut branch (identity, or 1x1 conv + downsample
        when the block is learnable)."""
        out = x
        if self.learnable_sc:
            out = self.shortcut(out)
            if self.with_downsample:
                out = self.downsample(out)
        return out

    def init_weights(self):
        """Initialize weights following the method named by
        ``init_cfg['type']`` (STUDIO / BIGGAN / SAGAN / SNGAN families)."""
        if self.init_type.upper() == 'STUDIO':
            nn.init.orthogonal_(self.conv_1.conv.weight)
            nn.init.orthogonal_(self.conv_2.conv.weight)
            self.conv_1.conv.bias.data.fill_(0.)
            self.conv_2.conv.bias.data.fill_(0.)
            if self.learnable_sc:
                nn.init.orthogonal_(self.shortcut.conv.weight)
                self.shortcut.conv.bias.data.fill_(0.)
        elif self.init_type.upper() == 'BIGGAN':
            xavier_uniform_(self.conv_1.conv.weight, gain=1)
            xavier_uniform_(self.conv_2.conv.weight, gain=1)
            if self.learnable_sc:
                xavier_uniform_(self.shortcut.conv.weight, gain=1)
        elif self.init_type.upper() == 'SAGAN':
            xavier_init(self.conv_1, gain=1, distribution='uniform')
            xavier_init(self.conv_2, gain=1, distribution='uniform')
            if self.learnable_sc:
                xavier_init(self.shortcut, gain=1, distribution='uniform')
        elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
            xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
            xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
            if self.learnable_sc:
                xavier_init(self.shortcut, gain=1, distribution='uniform')
        else:
            raise NotImplementedError('Unknown initialization method: '
                                      f'\'{self.init_type}\'')
@MODULES.register_module()
class SNGANDiscHeadResBlock(nn.Module):
    """The first ResBlock used in discriminator of SNGAN / Proj-GAN. Compared
    to ``SNGANDiscResBlock``, this module has a different forward order and
    always applies downsampling.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        conv_cfg (dict | None): Config for conv blocks of this module. If
            pass ``None``, would use ``_default_conv_cfg``. Defaults to
            ``None``.
        act_cfg (dict, optional): Config for activate function. Defaults
            to ``dict(type='ReLU')``.
        with_spectral_norm (bool, optional): Whether use spectral norm for
            conv blocks and norm layers. Defaults to True.
        sn_eps (float, optional): eps for spectral normalization operation.
            Defaults to `1e-12`.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
    """

    _default_conv_cfg = dict(kernel_size=3, stride=1, padding=1, act_cfg=None)

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=True,
                 sn_eps=1e-12,
                 sn_style='torch',
                 init_cfg=dict(type='BigGAN')):
        super().__init__()
        self.init_type = init_cfg.get('type', None)

        self.conv_cfg = deepcopy(self._default_conv_cfg)
        if conv_cfg is not None:
            self.conv_cfg.update(conv_cfg)

        self.activate = build_activation_layer(act_cfg)
        sn_cfg = dict(eps=sn_eps, sn_style=sn_style)

        self.conv_1 = SNConvModule(
            in_channels,
            out_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)

        self.conv_2 = SNConvModule(
            out_channels,
            out_channels,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg,
            **self.conv_cfg)

        self.downsample = nn.AvgPool2d(2, 2)

        # use hyperparameters-fixed shortcut here
        self.shortcut = SNConvModule(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            act_cfg=None,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=sn_cfg)

        self.init_weights()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        out = self.conv_1(x)
        out = self.activate(out)
        out = self.conv_2(out)
        out = self.downsample(out)
        shortcut = self.forward_shortcut(x)
        return out + shortcut

    def forward_shortcut(self, x):
        """Forward the shortcut branch (downsample followed by 1x1 conv)."""
        out = self.downsample(x)
        out = self.shortcut(out)
        return out

    def init_weights(self):
        """Initialize weights following the method named by
        ``init_cfg['type']`` (STUDIO / BIGGAN / SAGAN / SNGAN families)."""
        if self.init_type.upper() == 'STUDIO':
            for m in [self.conv_1, self.conv_2, self.shortcut]:
                nn.init.orthogonal_(m.conv.weight)
                m.conv.bias.data.fill_(0.)
        elif self.init_type.upper() == 'BIGGAN':
            xavier_uniform_(self.conv_1.conv.weight, gain=1)
            xavier_uniform_(self.conv_2.conv.weight, gain=1)
            xavier_uniform_(self.shortcut.conv.weight, gain=1)
        elif self.init_type.upper() == 'SAGAN':
            xavier_init(self.conv_1, gain=1, distribution='uniform')
            xavier_init(self.conv_2, gain=1, distribution='uniform')
            xavier_init(self.shortcut, gain=1, distribution='uniform')
        elif self.init_type.upper() in ['SNGAN', 'SNGAN-PROJ', 'GAN-PROJ']:
            xavier_init(self.conv_1, gain=np.sqrt(2), distribution='uniform')
            xavier_init(self.conv_2, gain=np.sqrt(2), distribution='uniform')
            xavier_init(self.shortcut, gain=1, distribution='uniform')
        else:
            raise NotImplementedError('Unknown initialization method: '
                                      f'\'{self.init_type}\'')
@MODULES.register_module()
class SNConditionNorm(nn.Module):
    """Conditional Normalization for SNGAN / Proj-GAN.

    The implementation refers to
    https://github.com/pfnet-research/sngan_projection/blob/master/source/links/conditional_batch_normalization.py  # noqa
    and
    https://github.com/POSTECH-CVLab/PyTorch-StudioGAN/blob/master/src/utils/model_ops.py  # noqa

    Args:
        in_channels (int): Number of the channels of the input feature map.
        num_classes (int): Number of the classes in the dataset. If
            ``use_cbn`` is True, ``num_classes`` must be larger than 0.
        use_cbn (bool, optional): Whether to use conditional normalization.
            If ``use_cbn`` is True, two embedding layers are used to map a
            label to the per-channel weight and bias applied after
            normalization. Defaults to True.
        norm_cfg (dict, optional): Config for normalization method. Defaults
            to ``dict(type='BN')``.
        cbn_norm_affine (bool): Whether to set ``affine=True`` when using
            conditional batch norm. This argument only works when ``use_cbn``
            is True. Defaults to False.
        auto_sync_bn (bool, optional): Whether to convert Batch Norm to
            Synchronized ones when distributed training is on. Defaults to
            True.
        with_spectral_norm (bool, optional): Whether to use spectral norm for
            the embedding layers. Defaults to False.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by ajbrock
            (https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `torch`.
        norm_eps (float, optional): eps for Normalization layers (both
            conditional and non-conditional ones). Defaults to `1e-4`.
        sn_eps (float, optional): eps for spectral normalization operation.
            Defaults to `1e-12`.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to ``dict(type='BigGAN')``.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 use_cbn=True,
                 norm_cfg=dict(type='BN'),
                 cbn_norm_affine=False,
                 auto_sync_bn=True,
                 with_spectral_norm=False,
                 sn_style='torch',
                 norm_eps=1e-4,
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN')):
        super().__init__()
        self.use_cbn = use_cbn
        self.init_type = init_cfg.get('type', None)

        # Copy before mutating: the default dict is shared across instances.
        norm_cfg = deepcopy(norm_cfg)
        norm_type = norm_cfg['type']
        if norm_type not in ['IN', 'BN', 'SyncBN']:
            raise ValueError('Only support `IN` (InstanceNorm), '
                             '`BN` (BatchNorm) and `SyncBN` for '
                             'Class-conditional bn. '
                             f'Receive norm_type: {norm_type}')

        if self.use_cbn:
            # With conditional norm, the affine transform comes from the
            # label embeddings, so the norm layer itself is non-affine by
            # default.
            norm_cfg.setdefault('affine', cbn_norm_affine)
        norm_cfg.setdefault('eps', norm_eps)
        if check_dist_init() and auto_sync_bn and norm_type == 'BN':
            norm_cfg['type'] = 'SyncBN'

        _, self.norm = build_norm_layer(norm_cfg, in_channels)

        if self.use_cbn:
            if num_classes <= 0:
                raise ValueError('`num_classes` must be larger '
                                 'than 0 with `use_cbn=True`')
            # BigGAN / StudioGAN treat the embedded weight as a residual
            # around 1 (see ``forward``).
            self.reweight_embedding = (
                self.init_type.upper() in ('BIGGAN', 'STUDIO'))
            if with_spectral_norm:
                if sn_style == 'torch':
                    self.weight_embedding = spectral_norm(
                        nn.Embedding(num_classes, in_channels), eps=sn_eps)
                    self.bias_embedding = spectral_norm(
                        nn.Embedding(num_classes, in_channels), eps=sn_eps)
                elif sn_style == 'ajbrock':
                    self.weight_embedding = SNEmbedding(
                        num_classes, in_channels, eps=sn_eps)
                    self.bias_embedding = SNEmbedding(
                        num_classes, in_channels, eps=sn_eps)
                else:
                    raise NotImplementedError(
                        f'{sn_style} style spectral Norm is not '
                        'supported yet')
            else:
                self.weight_embedding = nn.Embedding(num_classes, in_channels)
                self.bias_embedding = nn.Embedding(num_classes, in_channels)

        self.init_weights()

    def forward(self, x, y=None):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            y (Tensor, optional): Input label with shape (n, ). Required
                when ``use_cbn`` is True. Default None.

        Returns:
            Tensor: Forward results.
        """
        out = self.norm(x)
        if self.use_cbn:
            # Broadcast per-class (n, c) embeddings over spatial dims.
            weight = self.weight_embedding(y)[:, :, None, None]
            bias = self.bias_embedding(y)[:, :, None, None]
            if self.reweight_embedding:
                # BigGAN/StudioGAN style: embedding is an offset around 1.
                weight = weight + 1.
            out = out * weight + bias
        return out

    def init_weights(self):
        """Initialize the label-embedding weights.

        The scheme is chosen by ``init_cfg['type']`` (case-insensitive):
        ``STUDIO``, ``BIGGAN``, or one of
        ``SNGAN`` / ``SNGAN-PROJ`` / ``GAN-PROJ`` / ``SAGAN``.
        No-op when ``use_cbn`` is False.

        Raises:
            NotImplementedError: If the initialization method is unknown.
        """
        if self.use_cbn:
            if self.init_type.upper() == 'STUDIO':
                nn.init.orthogonal_(self.weight_embedding.weight)
                nn.init.orthogonal_(self.bias_embedding.weight)
            elif self.init_type.upper() == 'BIGGAN':
                xavier_uniform_(self.weight_embedding.weight, gain=1)
                xavier_uniform_(self.bias_embedding.weight, gain=1)
            elif self.init_type.upper() in [
                    'SNGAN', 'SNGAN-PROJ', 'GAN-PROJ', 'SAGAN'
            ]:
                constant_init(self.weight_embedding, 1)
                constant_init(self.bias_embedding, 0)
            else:
                raise NotImplementedError('Unknown initialization method: '
                                          f'\'{self.init_type}\'')
build/lib/mmgen/models/architectures/stylegan/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
"""Public entry point for the StyleGAN architecture family.

Re-exports the v1/v2/v3 generators and discriminators together with the
MS-PIE variants so callers can import them from this package directly.
"""
from .generator_discriminator_v1 import (StyleGAN1Discriminator,
                                         StyleGANv1Generator)
from .generator_discriminator_v2 import (StyleGAN2Discriminator,
                                         StyleGANv2Generator)
from .generator_discriminator_v3 import StyleGANv3Generator
from .mspie import MSStyleGAN2Discriminator, MSStyleGANv2Generator

# Names exposed via ``from ...stylegan import *``.
__all__ = [
    'StyleGAN2Discriminator', 'StyleGANv2Generator', 'StyleGANv1Generator',
    'StyleGAN1Discriminator', 'MSStyleGAN2Discriminator',
    'MSStyleGANv2Generator', 'StyleGANv3Generator'
]
Prev
1
…
10
11
12
13
14
15
16
17
18
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment