Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
stylegan2_mmcv
Commits
1401de15
Commit
1401de15
authored
Jun 28, 2024
by
dongchy920
Browse files
stylegan2_mmcv
parents
Pipeline
#1274
canceled with stages
Changes
463
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
6019 additions
and
0 deletions
+6019
-0
build/lib/mmgen/models/architectures/stylegan/generator_discriminator_v1.py
...dels/architectures/stylegan/generator_discriminator_v1.py
+523
-0
build/lib/mmgen/models/architectures/stylegan/generator_discriminator_v2.py
...dels/architectures/stylegan/generator_discriminator_v2.py
+704
-0
build/lib/mmgen/models/architectures/stylegan/generator_discriminator_v3.py
...dels/architectures/stylegan/generator_discriminator_v3.py
+197
-0
build/lib/mmgen/models/architectures/stylegan/modules/__init__.py
...b/mmgen/models/architectures/stylegan/modules/__init__.py
+12
-0
build/lib/mmgen/models/architectures/stylegan/modules/styleganv1_modules.py
...dels/architectures/stylegan/modules/styleganv1_modules.py
+174
-0
build/lib/mmgen/models/architectures/stylegan/modules/styleganv2_modules.py
...dels/architectures/stylegan/modules/styleganv2_modules.py
+1168
-0
build/lib/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py
...dels/architectures/stylegan/modules/styleganv3_modules.py
+693
-0
build/lib/mmgen/models/architectures/stylegan/mspie.py
build/lib/mmgen/models/architectures/stylegan/mspie.py
+556
-0
build/lib/mmgen/models/architectures/stylegan/utils.py
build/lib/mmgen/models/architectures/stylegan/utils.py
+81
-0
build/lib/mmgen/models/architectures/wgan_gp/__init__.py
build/lib/mmgen/models/architectures/wgan_gp/__init__.py
+4
-0
build/lib/mmgen/models/architectures/wgan_gp/generator_discriminator.py
...n/models/architectures/wgan_gp/generator_discriminator.py
+242
-0
build/lib/mmgen/models/architectures/wgan_gp/modules.py
build/lib/mmgen/models/architectures/wgan_gp/modules.py
+191
-0
build/lib/mmgen/models/builder.py
build/lib/mmgen/models/builder.py
+37
-0
build/lib/mmgen/models/common/__init__.py
build/lib/mmgen/models/common/__init__.py
+5
-0
build/lib/mmgen/models/common/dist_utils.py
build/lib/mmgen/models/common/dist_utils.py
+29
-0
build/lib/mmgen/models/common/model_utils.py
build/lib/mmgen/models/common/model_utils.py
+76
-0
build/lib/mmgen/models/diffusions/__init__.py
build/lib/mmgen/models/diffusions/__init__.py
+5
-0
build/lib/mmgen/models/diffusions/base_diffusion.py
build/lib/mmgen/models/diffusions/base_diffusion.py
+1017
-0
build/lib/mmgen/models/diffusions/sampler.py
build/lib/mmgen/models/diffusions/sampler.py
+37
-0
build/lib/mmgen/models/diffusions/utils.py
build/lib/mmgen/models/diffusions/utils.py
+268
-0
No files found.
Too many changes to show.
To preserve performance only
463 of 463+
files are displayed.
Plain diff
Email patch
build/lib/mmgen/models/architectures/stylegan/generator_discriminator_v1.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
random
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmgen.models.architectures
import
PixelNorm
from
mmgen.models.architectures.common
import
get_module_device
from
mmgen.models.architectures.pggan
import
(
EqualizedLRConvDownModule
,
EqualizedLRConvModule
)
from
mmgen.models.architectures.stylegan.modules
import
Blur
from
mmgen.models.builder
import
MODULES
from
..
import
MiniBatchStddevLayer
from
.modules.styleganv1_modules
import
StyleConv
from
.modules.styleganv2_modules
import
EqualLinearActModule
from
.utils
import
get_mean_latent
,
style_mixing
@MODULES.register_module()
class StyleGANv1Generator(nn.Module):
    """StyleGAN1 Generator.

    In StyleGAN1, we use a progressive growing architecture composed of a
    style mapping module and a number of convolutional style blocks. More
    details can be found in: A Style-Based Generator Architecture for
    Generative Adversarial Networks CVPR2019.

    Args:
        out_size (int): The output size of the StyleGAN1 generator.
        style_channels (int): The number of channels for style code.
        num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 2, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we defaultly adopt mixing style mode. However, in the
            evaluation, we use 'single' style mode. `['mix', 'single']` are
            currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The evaluation mode of style mixing.
            Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in range of [0, 1]. Defaults to 0.9.
    """

    def __init__(self,
                 out_size,
                 style_channels,
                 num_mlps=8,
                 blur_kernel=[1, 2, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9):
        super().__init__()
        self.out_size = out_size
        self.style_channels = style_channels
        self.num_mlps = num_mlps
        self.lr_mlp = lr_mlp
        # `_default_style_mode` remembers the training-time mode so that
        # `train()` can restore it after an eval() switch.
        self._default_style_mode = default_style_mode
        self.default_style_mode = default_style_mode
        self.eval_style_mode = eval_style_mode
        self.mix_prob = mix_prob
        # NOTE(review): `blur_kernel` is accepted but never used in this
        # constructor — presumably kept for interface symmetry with the
        # discriminator; confirm before removing.

        # define style mapping layers (z -> w): PixelNorm followed by
        # `num_mlps` equalized-lr fully-connected layers with LeakyReLU.
        mapping_layers = [PixelNorm()]
        for _ in range(num_mlps):
            mapping_layers.append(
                EqualLinearActModule(
                    style_channels,
                    style_channels,
                    equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
                    act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))

        self.style_mapping = nn.Sequential(*mapping_layers)

        # Channel width per output resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16,
        }

        # generator backbone (4x4 --> higher resolutions)
        self.log_size = int(np.log2(self.out_size))

        self.convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        in_channels_ = self.channels[4]
        # One StyleConv + one to-RGB head per resolution (2**2 .. out_size).
        # The first StyleConv (i == 2) is the `initial` block which starts
        # from a learned constant instead of upsampling a previous feature.
        for i in range(2, self.log_size + 1):
            out_channels_ = self.channels[2**i]
            self.convs.append(
                StyleConv(
                    in_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    initial=(i == 2),
                    upsample=True,
                    fused=True))
            self.to_rgbs.append(
                EqualizedLRConvModule(out_channels_, 3, 1, act_cfg=None))
            in_channels_ = out_channels_

        # Two latent slots per resolution block.
        self.num_latents = self.log_size * 2 - 2
        self.num_injected_noises = self.num_latents

        # register buffer for injected noises (one fixed noise map per
        # style-conv layer, used when `randomize_noise=False`)
        for layer_idx in range(self.num_injected_noises):
            res = (layer_idx + 4) // 2
            shape = [1, 1, 2**res, 2**res]
            self.register_buffer(f'injected_noise_{layer_idx}',
                                 torch.randn(*shape))

    def train(self, mode=True):
        """Switch train/eval modes and the matching style-mixing mode.

        Args:
            mode (bool, optional): Whether to set training mode (``True``)
                or evaluation mode (``False``). Defaults to True.

        Returns:
            nn.Module: ``self`` (standard ``nn.Module.train`` contract).
        """
        if mode:
            if self.default_style_mode != self._default_style_mode:
                mmcv.print_log(
                    f'Switch to train style mode: {self._default_style_mode}',
                    'mmgen')
            self.default_style_mode = self._default_style_mode

        else:
            if self.default_style_mode != self.eval_style_mode:
                mmcv.print_log(
                    f'Switch to evaluation style mode: {self.eval_style_mode}',
                    'mmgen')
            self.default_style_mode = self.eval_style_mode

        return super(StyleGANv1Generator, self).train(mode)

    def make_injected_noise(self):
        """make noises that will be injected into feature maps.

        Returns:
            list[Tensor]: List of layer-wise noise tensor.
        """
        device = get_module_device(self)

        # noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]
        # Two noise maps per resolution, matching the two style-conv
        # injections inside each StyleConv block.
        noises = []

        for i in range(2, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_mean_latent(self, num_samples=4096, **kwargs):
        """Get mean latent of W space in this generator.

        Args:
            num_samples (int, optional): Number of sample times. Defaults
                to 4096.

        Returns:
            Tensor: Mean latent of this generator.
        """
        return get_mean_latent(self, num_samples, **kwargs)

    def style_mixing(self,
                     n_source,
                     n_target,
                     inject_index=1,
                     truncation_latent=None,
                     truncation=0.7,
                     curr_scale=-1,
                     transition_weight=1):
        """Generate style-mixed images (delegates to the shared helper).

        Args:
            n_source (int): Number of source images.
            n_target (int): Number of target images.
            inject_index (int, optional): Index from which target latents
                replace source latents. Defaults to 1.
            truncation_latent (torch.Tensor | None, optional): Mean latent
                used for the truncation trick. Defaults to None.
            truncation (float, optional): Truncation factor. Defaults to 0.7.
            curr_scale (int, optional): Output resolution scale; -1 means the
                maximum scale. Defaults to -1.
            transition_weight (float, optional): Progressive-growing
                transition weight. Defaults to 1.
        """
        return style_mixing(
            self,
            n_source=n_source,
            n_target=n_target,
            inject_index=inject_index,
            truncation=truncation,
            truncation_latent=truncation_latent,
            style_channels=self.style_channels,
            curr_scale=curr_scale,
            transition_weight=transition_weight)

    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True,
                transition_weight=1.,
                curr_scale=-1):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN1, you can provide noise tensor or latent tensor. Given
                a list containing more than one noise or latent tensors, style
                mixing trick will be used in training. Of course, You can
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, the ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. Give value less
                than 1., the truncation trick will be adopted. Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensor injected to the style conv
                block. Defaults to True.
            transition_weight (float, optional): The weight used in resolution
                transition. Defaults to 1..
            curr_scale (int, optional): The resolution scale of generated image
                tensor. -1 means the max resolution scale of the StyleGAN1.
                Defaults to -1.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            # In 'mix' mode, sample two noise batches with prob `mix_prob`
            # so that style mixing can be applied downstream.
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        # Map z-space noise into w-space latents unless callers already pass
        # latents (`input_is_latent=True`).
        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                # `None` entries make each StyleConv sample fresh noise.
                injected_noise = [None] * self.num_injected_noises
            else:
                # Reuse the fixed noise buffers registered in __init__.
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]

        # use truncation trick
        if truncation < 1:
            style_t = []
            # calculate truncation latent on the fly
            if truncation_latent is None and not hasattr(
                    self, 'truncation_latent'):
                self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            elif truncation_latent is None and hasattr(self,
                                                       'truncation_latent'):
                truncation_latent = self.truncation_latent
            # Interpolate each latent towards the mean latent.
            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t

        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                # Broadcast a single w-code to every layer slot.
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            # First `inject_index` slots from source style, the rest from
            # the second style.
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        curr_log_size = self.log_size if curr_scale < 0 else int(
            np.log2(curr_scale))
        step = curr_log_size - 2

        _index = 0
        # NOTE: the initial StyleConv ignores this input feature and starts
        # from its own learned constant, so passing `latent` here is only a
        # placeholder — presumably; confirm against StyleConv's `initial`
        # branch.
        out = latent

        # 4x4 ---> higher resolutions
        for i, (conv, to_rgb) in enumerate(zip(self.convs, self.to_rgbs)):
            if i > 0 and step > 0:
                # Keep the previous-resolution feature for the progressive
                # growing skip connection below.
                out_prev = out

            out = conv(
                out,
                latent[:, _index],
                latent[:, _index + 1],
                noise1=injected_noise[2 * i],
                noise2=injected_noise[2 * i + 1])

            if i == step:
                out = to_rgb(out)

                # During a resolution transition, blend the upsampled RGB of
                # the previous stage with the current stage's RGB.
                if i > 0 and 0 <= transition_weight < 1:
                    skip_rgb = self.to_rgbs[i - 1](out_prev)
                    skip_rgb = F.interpolate(
                        skip_rgb, scale_factor=2, mode='nearest')
                    out = (1 - transition_weight
                           ) * skip_rgb + transition_weight * out

                break

            _index += 2

        img = out

        if return_latents or return_noise:
            output_dict = dict(
                fake_img=img,
                latent=latent,
                inject_index=inject_index,
                noise_batch=noise_batch)
            return output_dict

        return img
@MODULES.register_module()
class StyleGAN1Discriminator(nn.Module):
    """StyleGAN1 Discriminator.

    The architecture of this discriminator is proposed in StyleGAN1. More
    details can be found in: A Style-Based Generator Architecture for
    Generative Adversarial Networks CVPR2019.

    Args:
        in_size (int): The input size of images.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 2, 1].
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4).
    """

    def __init__(self,
                 in_size,
                 blur_kernel=[1, 2, 1],
                 mbstd_cfg=dict(group_size=4)):
        super().__init__()
        self.with_mbstd = mbstd_cfg is not None

        # Channel width per input resolution (mirrors the generator).
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16,
        }

        log_size = int(np.log2(in_size))
        self.log_size = log_size
        in_channels = channels[in_size]

        self.convs = nn.ModuleList()
        self.from_rgb = nn.ModuleList()

        # Downsampling stages from `in_size` down to 8x8. Each stage has a
        # from-RGB entry head plus a conv + blur + strided-downsample body.
        for i in range(log_size, 2, -1):
            out_channel = channels[2**(i - 1)]

            self.from_rgb.append(
                EqualizedLRConvModule(
                    3,
                    in_channels,
                    kernel_size=3,
                    padding=1,
                    act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))

            self.convs.append(
                nn.Sequential(
                    EqualizedLRConvModule(
                        in_channels,
                        out_channel,
                        kernel_size=3,
                        padding=1,
                        bias=True,
                        norm_cfg=None,
                        act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
                    Blur(blur_kernel, pad=(1, 1)),
                    EqualizedLRConvDownModule(
                        out_channel,
                        out_channel,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        act_cfg=None),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True)))

            in_channels = out_channel

        # Final 4x4 stage from-RGB head.
        # NOTE(review): kernel_size=3 with padding=0 shrinks the spatial
        # size here — verify this is intended (other from_rgb heads use
        # padding=1 to preserve resolution).
        self.from_rgb.append(
            EqualizedLRConvModule(
                3,
                in_channels,
                kernel_size=3,
                padding=0,
                act_cfg=dict(type='LeakyReLU', negative_slope=0.2)))

        # Final conv stage: `in_channels + 1` accounts for the extra feature
        # channel appended by the minibatch-stddev layer.
        self.convs.append(
            nn.Sequential(
                EqualizedLRConvModule(
                    in_channels + 1,
                    512,
                    kernel_size=3,
                    padding=1,
                    bias=True,
                    norm_cfg=None,
                    act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
                EqualizedLRConvModule(
                    512,
                    512,
                    kernel_size=4,
                    padding=0,
                    bias=True,
                    norm_cfg=None,
                    act_cfg=None),
            ))

        if self.with_mbstd:
            self.mbstd_layer = MiniBatchStddevLayer(**mbstd_cfg)

        # Final classification head producing a single realness score.
        self.final_linear = nn.Sequential(EqualLinearActModule(
            channels[4],
            1))

        self.n_layer = len(self.convs)

    def forward(self, input, transition_weight=1., curr_scale=-1):
        """Forward function.

        Args:
            input (torch.Tensor): Input image tensor.
            transition_weight (float, optional): The weight used in resolution
                transition. Defaults to 1..
            curr_scale (int, optional): The resolution scale of image tensor.
                -1 means the max resolution scale of the StyleGAN1.
                Defaults to -1.

        Returns:
            torch.Tensor: Predict score for the input image.
        """
        curr_log_size = self.log_size if curr_scale < 0 else int(
            np.log2(curr_scale))
        step = curr_log_size - 2

        # Walk the stages from the input resolution down to 4x4. `index`
        # maps the (reversed) step counter onto the module lists, which are
        # stored highest-resolution first.
        for i in range(step, -1, -1):
            index = self.n_layer - i - 1

            if i == step:
                out = self.from_rgb[index](input)

            # minibatch standard deviation
            # NOTE(review): this is called unconditionally at the last stage;
            # with `mbstd_cfg=None` the attribute does not exist — confirm
            # `with_mbstd` is always True in supported configs.
            if i == 0:
                out = self.mbstd_layer(out)

            out = self.convs[index](out)

            if i > 0:
                # During a resolution transition, blend features from the
                # downsampled input (via the next from-RGB head) with the
                # current stage output.
                if i == step and 0 <= transition_weight < 1:
                    skip_rgb = F.avg_pool2d(input, 2)
                    skip_rgb = self.from_rgb[index + 1](skip_rgb)

                    out = (1 - transition_weight
                           ) * skip_rgb + transition_weight * out

        # Flatten and score.
        out = out.view(out.shape[0], -1)
        out = self.final_linear(out)

        return out
build/lib/mmgen/models/architectures/stylegan/generator_discriminator_v2.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
random
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.runner.checkpoint
import
_load_checkpoint_with_prefix
from
mmgen.core.runners.fp16_utils
import
auto_fp16
from
mmgen.models.architectures
import
PixelNorm
from
mmgen.models.architectures.common
import
get_module_device
from
mmgen.models.builder
import
MODULES
,
build_module
from
.ada.augment
import
AugmentPipe
from
.ada.misc
import
constant
from
.modules.styleganv2_modules
import
(
ConstantInput
,
ConvDownLayer
,
EqualLinearActModule
,
ModMBStddevLayer
,
ModulatedStyleConv
,
ModulatedToRGB
,
ResBlock
)
from
.utils
import
get_mean_latent
,
style_mixing
@MODULES.register_module()
class StyleGANv2Generator(nn.Module):
    r"""StyleGAN2 Generator.

    In StyleGAN2, we use a static architecture composing of a style mapping
    module and number of convolutional style blocks. More details can be found
    in: Analyzing and Improving the Image Quality of StyleGAN CVPR2020.

    You can load pretrained model through passing information into
    ``pretrained`` argument. We have already offered official weights as
    follows:

    - stylegan2-ffhq-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth  # noqa
    - stylegan2-horse-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth  # noqa
    - stylegan2-car-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth  # noqa
    - stylegan2-cat-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth  # noqa
    - stylegan2-church-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth  # noqa

    If you want to load the ema model, you can just use following codes:

    .. code-block:: python

        # ckpt_http is one of the valid path from http source
        generator = StyleGANv2Generator(1024, 512,
                                        pretrained=dict(
                                            ckpt_path=ckpt_http,
                                            prefix='generator_ema'))

    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` with local path. If you just want to load the original
    generator (not the ema model), please set the prefix with 'generator'.

    Note that our implementation allows to generate BGR image, while the
    original StyleGAN2 outputs RGB images by default. Thus, we provide
    ``bgr2rgb`` argument to convert the image space.

    Args:
        out_size (int): The output size of the StyleGAN2 generator.
        style_channels (int): The number of channels for style code.
        num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we defaultly adopt mixing style mode. However, in the
            evaluation, we use 'single' style mode. `['mix', 'single']` are
            currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The evaluation mode of style mixing.
            Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in range of [0, 1]. Defaults to ``0.9``.
        num_fp16_scales (int, optional): The number of resolutions to use auto
            fp16 training. Different from ``fp16_enabled``, this argument
            allows users to adopt FP16 training only in several blocks.
            This behaviour is much more similar to the official implementation
            by Tero. Defaults to 0.
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. If this flag is `True`, the whole module will be wrapped
            with ``auto_fp16``. Defaults to False.
        pretrained (dict | None, optional): Information for pretained models.
            The necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the generator part from the whole state dict.
            Defaults to None.
    """

    def __init__(self,
                 out_size,
                 style_channels,
                 num_mlps=8,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9,
                 num_fp16_scales=0,
                 fp16_enabled=False,
                 pretrained=None):
        super().__init__()
        self.out_size = out_size
        self.style_channels = style_channels
        self.num_mlps = num_mlps
        self.channel_multiplier = channel_multiplier
        self.lr_mlp = lr_mlp
        # `_default_style_mode` remembers the training-time mode so that
        # `train()` can restore it after an eval() switch.
        self._default_style_mode = default_style_mode
        self.default_style_mode = default_style_mode
        self.eval_style_mode = eval_style_mode
        self.mix_prob = mix_prob
        self.num_fp16_scales = num_fp16_scales
        self.fp16_enabled = fp16_enabled

        # define style mapping layers (z -> w): PixelNorm followed by
        # `num_mlps` equalized-lr FC layers with fused-bias activation.
        mapping_layers = [PixelNorm()]
        for _ in range(num_mlps):
            mapping_layers.append(
                EqualLinearActModule(
                    style_channels,
                    style_channels,
                    equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
                    act_cfg=dict(type='fused_bias')))

        self.style_mapping = nn.Sequential(*mapping_layers)

        # Channel width per output resolution; higher resolutions scale with
        # `channel_multiplier`.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # constant input layer
        self.constant_input = ConstantInput(self.channels[4])
        # 4x4 stage
        self.conv1 = ModulatedStyleConv(
            self.channels[4],
            self.channels[4],
            kernel_size=3,
            style_channels=style_channels,
            blur_kernel=blur_kernel)
        self.to_rgb1 = ModulatedToRGB(
            self.channels[4],
            style_channels,
            upsample=False,
            fp16_enabled=fp16_enabled)

        # generator backbone (8x8 --> higher resolutions)
        self.log_size = int(np.log2(self.out_size))
        self.convs = nn.ModuleList()
        # NOTE(review): `upsamples` is created but never populated or used
        # in this class — presumably kept for state-dict compatibility;
        # confirm before removing.
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        in_channels_ = self.channels[4]

        for i in range(3, self.log_size + 1):
            out_channels_ = self.channels[2**i]

            # If `fp16_enabled` is True, all of layers will be run in auto
            # FP16. In the case of `num_fp16_sacles` > 0, only partial
            # layers will be run in fp16.
            _use_fp16 = (self.log_size - i) < num_fp16_scales or fp16_enabled

            self.convs.append(
                ModulatedStyleConv(
                    in_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=True,
                    blur_kernel=blur_kernel,
                    fp16_enabled=_use_fp16))
            self.convs.append(
                ModulatedStyleConv(
                    out_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=False,
                    blur_kernel=blur_kernel,
                    fp16_enabled=_use_fp16))
            self.to_rgbs.append(
                ModulatedToRGB(
                    out_channels_,
                    style_channels,
                    upsample=True,
                    fp16_enabled=_use_fp16))  # set to global fp16

            in_channels_ = out_channels_

        self.num_latents = self.log_size * 2 - 2
        # One fewer noise than latents: the 4x4 stage consumes a single
        # noise map while every higher stage consumes two.
        self.num_injected_noises = self.num_latents - 1

        # register buffer for injected noises (fixed noise maps used when
        # `randomize_noise=False`)
        for layer_idx in range(self.num_injected_noises):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.register_buffer(f'injected_noise_{layer_idx}',
                                 torch.randn(*shape))

        if pretrained is not None:
            self._load_pretrained_model(**pretrained)

    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        """Load a (possibly prefixed) state dict from ``ckpt_path``.

        Args:
            ckpt_path (str): Path or URL of the checkpoint.
            prefix (str, optional): Sub-dict prefix (e.g. 'generator_ema')
                to extract from the full checkpoint. Defaults to ''.
            map_location (str, optional): Device mapping for loading.
                Defaults to 'cpu'.
            strict (bool, optional): Whether to strictly enforce matching
                keys. Defaults to True.
        """
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')

    def train(self, mode=True):
        """Switch train/eval modes and the matching style-mixing mode.

        Args:
            mode (bool, optional): Whether to set training mode (``True``)
                or evaluation mode (``False``). Defaults to True.

        Returns:
            nn.Module: ``self`` (standard ``nn.Module.train`` contract).
        """
        if mode:
            if self.default_style_mode != self._default_style_mode:
                mmcv.print_log(
                    f'Switch to train style mode: {self._default_style_mode}',
                    'mmgen')
            self.default_style_mode = self._default_style_mode

        else:
            if self.default_style_mode != self.eval_style_mode:
                mmcv.print_log(
                    f'Switch to evaluation style mode: {self.eval_style_mode}',
                    'mmgen')
            self.default_style_mode = self.eval_style_mode

        return super(StyleGANv2Generator, self).train(mode)

    def make_injected_noise(self):
        """make noises that will be injected into feature maps.

        Returns:
            list[Tensor]: List of layer-wise noise tensor.
        """
        device = get_module_device(self)

        # One noise map for the 4x4 stage, then two per higher resolution.
        noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_mean_latent(self, num_samples=4096, **kwargs):
        """Get mean latent of W space in this generator.

        Args:
            num_samples (int, optional): Number of sample times. Defaults
                to 4096.

        Returns:
            Tensor: Mean latent of this generator.
        """
        return get_mean_latent(self, num_samples, **kwargs)

    def style_mixing(self,
                     n_source,
                     n_target,
                     inject_index=1,
                     truncation_latent=None,
                     truncation=0.7):
        """Generate style-mixed images (delegates to the shared helper).

        Args:
            n_source (int): Number of source images.
            n_target (int): Number of target images.
            inject_index (int, optional): Index from which target latents
                replace source latents. Defaults to 1.
            truncation_latent (torch.Tensor | None, optional): Mean latent
                used for the truncation trick. Defaults to None.
            truncation (float, optional): Truncation factor. Defaults to 0.7.
        """
        return style_mixing(
            self,
            n_source=n_source,
            n_target=n_target,
            inject_index=inject_index,
            truncation=truncation,
            truncation_latent=truncation_latent,
            style_channels=self.style_channels)

    @auto_fp16()
    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide noise tensor or latent tensor. Given
                a list containing more than one noise or latent tensors, style
                mixing trick will be used in training. Of course, You can
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, the ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. Give value less
                than 1., the truncation trick will be adopted. Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensor injected to the style conv
                block. Defaults to True.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            # In 'mix' mode, sample two noise batches with prob `mix_prob`
            # so that style mixing can be applied downstream.
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        # Map z-space noise into w-space latents unless callers already pass
        # latents (`input_is_latent=True`).
        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                # `None` entries make each conv sample fresh noise.
                injected_noise = [None] * self.num_injected_noises
            else:
                # Reuse the fixed noise buffers registered in __init__.
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]

        # use truncation trick
        if truncation < 1:
            style_t = []
            # calculate truncation latent on the fly
            if truncation_latent is None and not hasattr(
                    self, 'truncation_latent'):
                self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            elif truncation_latent is None and hasattr(self,
                                                       'truncation_latent'):
                truncation_latent = self.truncation_latent
            # Interpolate each latent towards the mean latent.
            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t

        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                # Broadcast a single w-code to every layer slot.
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            # First `inject_index` slots from source style, the rest from
            # the second style.
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        # 4x4 stage
        out = self.constant_input(latent)
        out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        _index = 1

        # 8x8 ---> higher resolutions: pairs of (upsampling conv, plain
        # conv) consume interleaved noise maps; `skip` accumulates the RGB
        # output across resolutions.
        for up_conv, conv, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], injected_noise[1::2],
                injected_noise[2::2], self.to_rgbs):
            out = up_conv(out, latent[:, _index], noise=noise1)
            out = conv(out, latent[:, _index + 1], noise=noise2)
            skip = to_rgb(out, latent[:, _index + 2], skip)
            _index += 2

        # make sure the output image is torch.float32 to avoid RunTime Error
        # in other modules
        img = skip.to(torch.float32)

        if return_latents or return_noise:
            output_dict = dict(
                fake_img=img,
                latent=latent,
                inject_index=inject_index,
                noise_batch=noise_batch)
            return output_dict

        return img
@MODULES.register_module()
class StyleGAN2Discriminator(nn.Module):
    """StyleGAN2 Discriminator.

    The architecture of this discriminator is proposed in StyleGAN2. More
    details can be found in: Analyzing and Improving the Image Quality of
    StyleGAN CVPR2020.

    You can load pretrained model through passing information into
    ``pretrained`` argument. We have already offered official weights as
    follows:

    - stylegan2-ffhq-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth  # noqa
    - stylegan2-horse-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth  # noqa
    - stylegan2-car-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth  # noqa
    - stylegan2-cat-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth  # noqa
    - stylegan2-church-config-f: https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth  # noqa

    If you want to load the ema model, you can just use following codes:

    .. code-block:: python

        # ckpt_http is one of the valid path from http source
        discriminator = StyleGAN2Discriminator(1024, 512,
                                               pretrained=dict(
                                                   ckpt_path=ckpt_http,
                                                   prefix='discriminator'))

    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` with local path.

    Note that our implementation adopts BGR image as input, while the
    original StyleGAN2 provides RGB images to the discriminator. Thus, we
    provide ``bgr2rgb`` argument to convert the image space. If your images
    follow the RGB order, please set it to ``True`` accordingly.

    Args:
        in_size (int): The input size of images.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4, channel_groups=1).
        num_fp16_scales (int, optional): The number of resolutions to use auto
            fp16 training. Defaults to 0.
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        out_fp32 (bool, optional): Whether to convert the output feature map to
            `torch.float32`. Defaults to `True`.
        convert_input_fp32 (bool, optional): Whether to convert input type to
            fp32 if not `fp16_enabled`. This argument is designed to deal with
            the cases where some modules are run in FP16 and others in FP32.
            Defaults to True.
        input_bgr2rgb (bool, optional): Whether to reformat the input channels
            with order `rgb`. Since we provide several converted weights,
            whose input order is `rgb`. You can set this argument to True if
            you want to finetune on converted weights. Defaults to False.
        pretrained (dict | None, optional): Information for pretained models.
            The necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the generator part from the whole state dict.
            Defaults to None.
    """

    def __init__(self,
                 in_size,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 mbstd_cfg=dict(group_size=4, channel_groups=1),
                 num_fp16_scales=0,
                 fp16_enabled=False,
                 out_fp32=True,
                 convert_input_fp32=True,
                 input_bgr2rgb=False,
                 pretrained=None):
        super().__init__()
        self.num_fp16_scale = num_fp16_scales
        self.fp16_enabled = fp16_enabled
        self.convert_input_fp32 = convert_input_fp32
        self.out_fp32 = out_fp32

        # Channel width for each resolution; widths above 32x32 scale with
        # ``channel_multiplier``.
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        log_size = int(np.log2(in_size))

        in_channels = channels[in_size]

        # from-RGB layer at the highest resolution; uses fp16 only when any
        # fp16 scales are requested.
        _use_fp16 = num_fp16_scales > 0
        convs = [
            ConvDownLayer(3, channels[in_size], 1, fp16_enabled=_use_fp16)
        ]

        # Progressively downsample from ``in_size`` to 4x4 with residual
        # blocks.
        for i in range(log_size, 2, -1):
            out_channel = channels[2**(i - 1)]

            # add fp16 training for higher resolutions: only the first
            # ``num_fp16_scales`` resolutions (or all, if ``fp16_enabled``)
            # run in fp16.
            _use_fp16 = (log_size - i) < num_fp16_scales or fp16_enabled
            convs.append(
                ResBlock(
                    in_channels,
                    out_channel,
                    blur_kernel,
                    fp16_enabled=_use_fp16,
                    convert_input_fp32=convert_input_fp32))

            in_channels = out_channel

        self.convs = nn.Sequential(*convs)

        # Minibatch standard deviation adds one extra feature channel,
        # hence ``in_channels + 1`` below.
        self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)

        self.final_conv = ConvDownLayer(in_channels + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinearActModule(
                channels[4] * 4 * 4,
                channels[4],
                act_cfg=dict(type='fused_bias')),
            EqualLinearActModule(channels[4], 1),
        )

        self.input_bgr2rgb = input_bgr2rgb
        if pretrained is not None:
            self._load_pretrained_model(**pretrained)

    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        """Load a (possibly sub-prefixed) state dict from ``ckpt_path``."""
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')

    @auto_fp16()
    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input image tensor.

        Returns:
            torch.Tensor: Predict score for the input image.
        """
        # This setting was used to finetune on converted weights
        if self.input_bgr2rgb:
            x = x[:, [2, 1, 0], ...]

        x = self.convs(x)

        x = self.mbstd_layer(x)

        # Cast back to fp32 before the final (fp32) conv when requested.
        if not self.final_conv.fp16_enabled and self.convert_input_fp32:
            x = x.to(torch.float32)
        x = self.final_conv(x)
        x = x.view(x.shape[0], -1)
        x = self.final_linear(x)

        return x
@MODULES.register_module()
class ADAStyleGAN2Discriminator(StyleGAN2Discriminator):

    def __init__(self, in_size, *args, data_aug=None, **kwargs):
        """StyleGANv2 Discriminator with adaptive augmentation.

        Args:
            in_size (int): The input size of images.
            data_aug (dict, optional): Config for data
                augmentation. Defaults to None.
        """
        super().__init__(in_size, *args, **kwargs)
        self.log_size = int(np.log2(in_size))

        # Augmentation is optional; only build the pipeline when a config
        # is provided.
        self.with_ada = data_aug is not None
        if self.with_ada:
            aug_module = build_module(data_aug)
            aug_module.requires_grad = False
            self.ada_aug = aug_module

    def forward(self, x):
        """Forward function."""
        if not self.with_ada:
            return super().forward(x)
        # Augment the images before scoring them.
        augmented = self.ada_aug.aug_pipeline(x)
        return super().forward(augmented)
@MODULES.register_module()
class ADAAug(nn.Module):
    """Data Augmentation Module for Adaptive Discriminator augmentation.

    Args:
        aug_pipeline (dict, optional): Config for augmentation pipeline.
            Defaults to None.
        update_interval (int, optional): Interval for updating
            augmentation probability. Defaults to 4.
        augment_initial_p (float, optional): Initial augmentation
            probability. Defaults to 0..
        ada_target (float, optional): ADA target. Defaults to 0.6.
        ada_kimg (int, optional): ADA training duration. Defaults to 500.
    """

    def __init__(self,
                 aug_pipeline=None,
                 update_interval=4,
                 augment_initial_p=0.,
                 ada_target=0.6,
                 ada_kimg=500):
        super().__init__()
        self.aug_pipeline = AugmentPipe(**aug_pipeline)
        self.update_interval = update_interval
        self.ada_kimg = ada_kimg
        self.ada_target = ada_target

        # Seed the augmentation probability with the initial value.
        self.aug_pipeline.p.copy_(torch.tensor(augment_initial_p))

        # this log buffer stores two numbers: num_scalars, sum_scalars.
        self.register_buffer('log_buffer', torch.zeros((2, )))

    def update(self, iteration=0, num_batches=0):
        """Update Augment probability.

        Args:
            iteration (int, optional): Training iteration.
                Defaults to 0.
            num_batches (int, optional): The number of reals batches.
                Defaults to 0.
        """

        if (iteration + 1) % self.update_interval == 0:

            # Step size: fraction of the ADA ramp covered by the images seen
            # since the last update (ada_kimg is expressed in kilo-images).
            adjust_step = float(num_batches * self.update_interval) / float(
                self.ada_kimg * 1000.)

            # get the mean value as the ada heuristic
            # NOTE(review): assumes log_buffer[0] (num_scalars) is non-zero,
            # i.e. the buffer has been filled elsewhere since the last reset.
            ada_heuristic = self.log_buffer[1] / self.log_buffer[0]
            adjust = np.sign(ada_heuristic.item() -
                             self.ada_target) * adjust_step
            # update the augment p
            # Note that p may be bigger than 1.0
            # ``constant`` presumably builds a scalar tensor on the right
            # device; clamp keeps p non-negative.
            self.aug_pipeline.p.copy_((self.aug_pipeline.p + adjust).max(
                constant(0, device=self.log_buffer.device)))

            # Reset the running statistics for the next interval.
            self.log_buffer = self.log_buffer * 0.
build/lib/mmgen/models/architectures/stylegan/generator_discriminator_v3.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
mmcv
import
torch
import
torch.nn
as
nn
from
mmcv.runner.checkpoint
import
_load_checkpoint_with_prefix
from
mmgen.models.architectures.common
import
get_module_device
from
mmgen.models.builder
import
MODULES
,
build_module
from
.utils
import
get_mean_latent
@MODULES.register_module()
class StyleGANv3Generator(nn.Module):
    """StyleGAN3 Generator.

    In StyleGAN3, we make several changes to StyleGANv2's generator which
    include transformed fourier features, filtered nonlinearities and
    non-critical sampling, etc. More details can be found in: Alias-Free
    Generative Adversarial Networks NeurIPS'2021.

    Ref: https://github.com/NVlabs/stylegan3

    Args:
        out_size (int): The output size of the StyleGAN3 generator.
        style_channels (int): The number of channels for style code.
        img_channels (int): The number of output's channels.
        noise_size (int, optional): Size of the input noise vector.
            Defaults to 512.
        rgb2bgr (bool, optional): Whether to reformat the output channels
            with order `bgr`. We provide several pre-trained StyleGAN3
            weights whose output channels order is `rgb`. You can set
            this argument to True to use the weights.
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
            the generator part from the whole state dict. Defaults to None.
        synthesis_cfg (dict, optional): Config for synthesis network. Defaults
            to dict(type='SynthesisNetwork').
        mapping_cfg (dict, optional): Config for mapping network. Defaults to
            dict(type='MappingNetwork').
    """

    def __init__(self,
                 out_size,
                 style_channels,
                 img_channels,
                 noise_size=512,
                 rgb2bgr=False,
                 pretrained=None,
                 synthesis_cfg=dict(type='SynthesisNetwork'),
                 mapping_cfg=dict(type='MappingNetwork')):
        super().__init__()
        self.noise_size = noise_size
        self.style_channels = style_channels
        self.out_size = out_size
        self.img_channels = img_channels
        self.rgb2bgr = rgb2bgr

        # Fill missing keys of the synthesis config from the constructor
        # arguments (deepcopy so the caller's dict is left untouched).
        self._synthesis_cfg = deepcopy(synthesis_cfg)
        self._synthesis_cfg.setdefault('style_channels', style_channels)
        self._synthesis_cfg.setdefault('out_size', out_size)
        self._synthesis_cfg.setdefault('img_channels', img_channels)
        self.synthesis = build_module(self._synthesis_cfg)

        # The number of ws (per-layer style vectors) is dictated by the
        # synthesis network.
        self.num_ws = self.synthesis.num_ws
        self._mapping_cfg = deepcopy(mapping_cfg)
        self._mapping_cfg.setdefault('noise_size', noise_size)
        self._mapping_cfg.setdefault('style_channels', style_channels)
        self._mapping_cfg.setdefault('num_ws', self.num_ws)
        self.style_mapping = build_module(self._mapping_cfg)

        if pretrained is not None:
            self._load_pretrained_model(**pretrained)

    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        """Load a (possibly sub-prefixed) state dict from ``ckpt_path``."""
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmgen')

    def forward(self,
                noise,
                num_batches=0,
                input_is_latent=False,
                truncation=1,
                num_truncation_layer=None,
                update_emas=False,
                force_fp32=True,
                return_noise=False,
                return_latents=False):
        """Forward Function for stylegan3.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            truncation (float, optional): Truncation factor. Give value less
                than 1., the truncation trick will be adopted. Defaults to 1.
            num_truncation_layer (int, optional): Number of layers use
                truncated latent. Defaults to None.
            update_emas (bool, optional): Whether update moving average of
                mean latent. Defaults to False.
            force_fp32 (bool, optional): Force fp32 ignore the weights.
                Defaults to True.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary \
                containing more data.
        """
        # if input is latent, set noise size as the style_channels
        noise_size = (
            self.style_channels if input_is_latent else self.noise_size)
        if isinstance(noise, torch.Tensor):
            assert noise.shape[1] == noise_size
            assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
                                     f'but got {noise.shape}')
            noise_batch = noise
        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, noise_size))
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            noise_batch = torch.randn((num_batches, noise_size))

        device = get_module_device(self)
        noise_batch = noise_batch.to(device)

        if input_is_latent:
            # A single latent is broadcast to all ``num_ws`` layers.
            ws = noise_batch.unsqueeze(1).repeat([1, self.num_ws, 1])
        else:
            ws = self.style_mapping(
                noise_batch,
                truncation=truncation,
                num_truncation_layer=num_truncation_layer,
                update_emas=update_emas)
        out_img = self.synthesis(
            ws, update_emas=update_emas, force_fp32=force_fp32)

        if self.rgb2bgr:
            out_img = out_img[:, [2, 1, 0], ...]

        if return_noise or return_latents:
            output = dict(
                fake_img=out_img, noise_batch=noise_batch, latent=ws)
            return output

        return out_img

    def get_mean_latent(self, num_samples=4096, **kwargs):
        """Get mean latent of W space in this generator.

        Args:
            num_samples (int, optional): Number of sample times. Defaults
                to 4096.

        Returns:
            Tensor: Mean latent of this generator.
        """
        # Prefer the mapping network's tracked average when available;
        # otherwise estimate it by sampling.
        if hasattr(self.style_mapping, 'w_avg'):
            return self.style_mapping.w_avg
        return get_mean_latent(self, num_samples, **kwargs)

    def get_training_kwargs(self, phase):
        """Get training kwargs. In StyleGANv3, we enable fp16, and update
        mangitude ema during training of discriminator. This function is used
        to pass related arguments.

        Args:
            phase (str): Current training phase.

        Returns:
            dict: Training kwargs.
        """
        if phase == 'disc':
            return dict(update_emas=True, force_fp32=False)
        if phase == 'gen':
            return dict(force_fp32=False)
        return {}
build/lib/mmgen/models/architectures/stylegan/modules/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.styleganv2_modules
import
(
Blur
,
ConstantInput
,
ModulatedConv2d
,
ModulatedStyleConv
,
ModulatedToRGB
,
NoiseInjection
)
from
.styleganv3_modules
import
(
MappingNetwork
,
SynthesisInput
,
SynthesisLayer
,
SynthesisNetwork
)
# Public names re-exported by this subpackage.
__all__ = [
    'Blur', 'ModulatedStyleConv', 'ModulatedToRGB', 'NoiseInjection',
    'ConstantInput', 'MappingNetwork', 'SynthesisInput', 'SynthesisLayer',
    'SynthesisNetwork', 'ModulatedConv2d'
]
build/lib/mmgen/models/architectures/stylegan/modules/styleganv1_modules.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch.nn
as
nn
from
mmgen.models.architectures.pggan
import
(
EqualizedLRConvModule
,
EqualizedLRConvUpModule
,
EqualizedLRLinearModule
)
from
mmgen.models.architectures.stylegan.modules
import
(
Blur
,
ConstantInput
,
NoiseInjection
)
class AdaptiveInstanceNorm(nn.Module):
    r"""Adaptive Instance Normalization Module.

    Ref: https://github.com/rosinality/style-based-gan-pytorch/blob/master/model.py  # noqa

    Args:
        in_channel (int): The number of input's channel.
        style_dim (int): Style latent dimension.
    """

    def __init__(self, in_channel, style_dim):
        super().__init__()

        # Affine projection mapping the style code to per-channel
        # (gamma, beta) pairs.
        self.affine = EqualizedLRLinearModule(style_dim, in_channel * 2)
        # Start as an identity modulation: gamma half of the bias -> 1,
        # beta half -> 0.
        self.affine.bias.data[:in_channel] = 1
        self.affine.bias.data[in_channel:] = 0

        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        """Forward function.

        Args:
            input (Tensor): Input tensor with shape (n, c, h, w).
            style (Tensor): Input style tensor with shape (n, c).

        Returns:
            Tensor: Forward results.
        """
        # (n, 2c) -> (n, 2c, 1, 1) so it broadcasts over spatial dims.
        scale_shift = self.affine(style)[:, :, None, None]
        gamma, beta = scale_shift.chunk(2, 1)

        normalized = self.norm(input)
        return gamma * normalized + beta
class StyleConv(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 padding=1,
                 initial=False,
                 blur_kernel=[1, 2, 1],
                 upsample=False,
                 fused=False):
        """Convolutional style blocks composing of noise injector, AdaIN module
        and convolution layers.

        Args:
            in_channels (int): The channel number of the input tensor.
            out_channels (itn): The channel number of the output tensor.
            kernel_size (int): The kernel size of convolution layers.
            style_channels (int): The number of channels for style code.
            padding (int, optional): Padding of convolution layers.
                Defaults to 1.
            initial (bool, optional): Whether this is the first StyleConv of
                StyleGAN's generator. Defaults to False.
            blur_kernel (list, optional): The blurry kernel.
                Defaults to [1, 2, 1].
            upsample (bool, optional): Whether perform upsampling.
                Defaults to False.
            fused (bool, optional): Whether use fused upconv.
                Defaults to False.
        """
        super().__init__()

        # The first "conv" is either the learned constant input (first block
        # of the generator), an upsampling conv (fused transposed conv or
        # nearest-upsample + conv, both followed by Blur), or a plain conv.
        if initial:
            self.conv1 = ConstantInput(in_channels)
        else:
            if upsample:
                if fused:
                    self.conv1 = nn.Sequential(
                        EqualizedLRConvUpModule(
                            in_channels,
                            out_channels,
                            kernel_size,
                            padding=padding,
                            act_cfg=dict(type='LeakyReLU',
                                         negative_slope=0.2)),
                        Blur(blur_kernel, pad=(1, 1)),
                    )
                else:
                    self.conv1 = nn.Sequential(
                        nn.Upsample(scale_factor=2, mode='nearest'),
                        EqualizedLRConvModule(
                            in_channels,
                            out_channels,
                            kernel_size,
                            padding=padding,
                            act_cfg=None),
                        Blur(blur_kernel, pad=(1, 1)))
            else:
                self.conv1 = EqualizedLRConvModule(
                    in_channels,
                    out_channels,
                    kernel_size,
                    padding=padding,
                    act_cfg=None)

        # First noise-injection / activation / AdaIN stage.
        self.noise_injector1 = NoiseInjection()
        self.activate1 = nn.LeakyReLU(0.2)
        self.adain1 = AdaptiveInstanceNorm(out_channels, style_channels)

        # Second conv stage mirrors the first one.
        self.conv2 = EqualizedLRConvModule(
            out_channels,
            out_channels,
            kernel_size,
            padding=padding,
            act_cfg=None)
        self.noise_injector2 = NoiseInjection()
        self.activate2 = nn.LeakyReLU(0.2)
        self.adain2 = AdaptiveInstanceNorm(out_channels, style_channels)

    def forward(self,
                x,
                style1,
                style2,
                noise1=None,
                noise2=None,
                return_noise=False):
        """Forward function.

        Args:
            x (Tensor): Input tensor.
            style1 (Tensor): Input style tensor with shape (n, c).
            style2 (Tensor): Input style tensor with shape (n, c).
            noise1 (Tensor, optional): Noise tensor with shape (n, c, h, w).
                Defaults to None.
            noise2 (Tensor, optional): Noise tensor with shape (n, c, h, w).
                Defaults to None.
            return_noise (bool, optional): If True, ``noise1`` and ``noise2``
                will be returned with ``out``. Defaults to False.

        Returns:
            Tensor | tuple[Tensor]: Forward results.
        """
        out = self.conv1(x)

        # When ``return_noise`` is set, the injector also returns the
        # (possibly freshly sampled) noise map it used.
        if return_noise:
            out, noise1 = self.noise_injector1(
                out, noise=noise1, return_noise=return_noise)
        else:
            out = self.noise_injector1(
                out, noise=noise1, return_noise=return_noise)

        out = self.activate1(out)
        out = self.adain1(out, style1)

        out = self.conv2(out)

        if return_noise:
            out, noise2 = self.noise_injector2(
                out, noise=noise2, return_noise=return_noise)
        else:
            out = self.noise_injector2(
                out, noise=noise2, return_noise=return_noise)

        out = self.activate2(out)
        out = self.adain2(out, style2)

        if return_noise:
            return out, noise1, noise2

        return out
build/lib/mmgen/models/architectures/stylegan/modules/styleganv2_modules.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
from
functools
import
partial
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn.bricks.activation
import
build_activation_layer
from
mmcv.ops.fused_bias_leakyrelu
import
(
FusedBiasLeakyReLU
,
fused_bias_leakyrelu
)
from
mmcv.ops.upfirdn2d
import
upfirdn2d
from
mmcv.runner.dist_utils
import
get_dist_info
from
mmgen.core.runners.fp16_utils
import
auto_fp16
from
mmgen.models.architectures.pggan
import
(
EqualizedLRConvModule
,
EqualizedLRLinearModule
,
equalized_lr
)
from
mmgen.models.common
import
AllGatherLayer
from
mmgen.ops
import
conv2d
,
conv_transpose2d
class _FusedBiasLeakyReLU(FusedBiasLeakyReLU):
    """Wrap FusedBiasLeakyReLU to support FP16 training."""

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, ...).

        Returns:
            Tensor: Output feature map.
        """
        # Cast the bias parameter to the input's dtype so the fused op also
        # accepts half-precision inputs (the parent class keeps bias in fp32).
        return fused_bias_leakyrelu(x, self.bias.to(x.dtype),
                                    self.negative_slope, self.scale)
class EqualLinearActModule(nn.Module):
    """Equalized LR Linear Module with Activation Layer.

    This module is modified from ``EqualizedLRLinearModule`` defined in PGGAN.
    The major features updated in this module is adding support for activation
    layers used in StyleGAN2.

    Args:
        equalized_lr_cfg (dict | None, optional): Config for equalized lr.
            Defaults to dict(gain=1., lr_mul=1.).
        bias (bool, optional): Whether to use bias item. Defaults to True.
        bias_init (float, optional): The value for bias initialization.
            Defaults to ``0.``.
        act_cfg (dict | None, optional): Config for activation layer.
            Defaults to None.
    """

    def __init__(self,
                 *args,
                 equalized_lr_cfg=dict(gain=1., lr_mul=1.),
                 bias=True,
                 bias_init=0.,
                 act_cfg=None,
                 **kwargs):
        super().__init__()
        self.with_activation = act_cfg is not None
        # w/o bias in linear layer
        self.linear = EqualizedLRLinearModule(
            *args, bias=False, equalized_lr_cfg=equalized_lr_cfg, **kwargs)

        if equalized_lr_cfg is not None:
            self.lr_mul = equalized_lr_cfg.get('lr_mul', 1.)
        else:
            self.lr_mul = 1.

        # define bias outside linear layer
        if bias:
            self.bias = nn.Parameter(
                torch.zeros(self.linear.out_features).fill_(bias_init))
        else:
            self.bias = None

        if self.with_activation:
            act_cfg = deepcopy(act_cfg)
            if act_cfg['type'] == 'fused_bias':
                # 'fused_bias' folds the bias addition into the fused
                # leaky-relu op, so a bias parameter is mandatory.
                self.act_type = act_cfg.pop('type')
                assert self.bias is not None
                self.activate = partial(fused_bias_leakyrelu, **act_cfg)
            else:
                self.act_type = 'normal'
                self.activate = build_activation_layer(act_cfg)
        else:
            self.act_type = None

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, ...).

        Returns:
            Tensor: Output feature map.
        """
        # Flatten any trailing dimensions so the linear layer sees (N, C).
        if x.ndim >= 3:
            x = x.reshape(x.size(0), -1)
        x = self.linear(x)

        # Dispatch on how the bias and activation interact:
        # 1) fused op (bias applied inside the op), 2) bias then activation,
        # 3) bias only, 4) activation only. The bias is rescaled by lr_mul
        # to match the equalized-lr parameterization.
        if self.with_activation and self.act_type == 'fused_bias':
            x = self.activate(x, self.bias * self.lr_mul)
        elif self.bias is not None and self.with_activation:
            x = self.activate(x + self.bias * self.lr_mul)
        elif self.bias is not None:
            x = x + self.bias * self.lr_mul
        elif self.with_activation:
            x = self.activate(x)

        return x
def
_make_kernel
(
k
):
k
=
torch
.
tensor
(
k
,
dtype
=
torch
.
float32
)
if
k
.
ndim
==
1
:
k
=
k
[
None
,
:]
*
k
[:,
None
]
k
/=
k
.
sum
()
return
k
class UpsampleUpFIRDn(nn.Module):
    """UpFIRDn for Upsampling.

    This module is used in the ``to_rgb`` layers in StyleGAN2 for upsampling
    the images.

    Args:
        kernel (Array): Blur kernel/filter used in UpFIRDn.
        factor (int, optional): Upsampling factor. Defaults to 2.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        # Scale the filter by factor**2 to keep the signal magnitude after
        # zero-insertion upsampling.
        fir = _make_kernel(kernel) * (factor**2)
        self.register_buffer('kernel', fir)

        pad_total = fir.shape[0] - factor
        self.pad = ((pad_total + 1) // 2 + factor - 1, pad_total // 2)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        return upfirdn2d(
            x,
            self.kernel.to(x.dtype),
            up=self.factor,
            down=1,
            pad=self.pad)
class DownsampleUpFIRDn(nn.Module):
    """UpFIRDn for Downsampling.

    This module is mentioned in StyleGAN2 for dowampling the feature maps.

    Args:
        kernel (Array): Blur kernel/filter used in UpFIRDn.
        factor (int, optional): Downsampling factor. Defaults to 2.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        fir = _make_kernel(kernel)
        self.register_buffer('kernel', fir)

        # Split the required padding, rounding the extra pixel to the front.
        pad_total = fir.shape[0] - factor
        self.pad = ((pad_total + 1) // 2, pad_total // 2)

    def forward(self, input):
        """Forward function.

        Args:
            input (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        return upfirdn2d(
            input,
            self.kernel.to(input.dtype),
            up=1,
            down=self.factor,
            pad=self.pad)
class Blur(nn.Module):
    """Blur module.

    This module is adopted rightly after upsampling operation in StyleGAN2.

    Args:
        kernel (Array): Blur kernel/filter used in UpFIRDn.
        pad (list[int]): Padding for features.
        upsample_factor (int, optional): Upsampling factor. Defaults to 1.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        fir = _make_kernel(kernel)
        # When applied after an upsampling op, rescale the filter to
        # compensate for the magnitude lost by zero insertion.
        if upsample_factor > 1:
            fir = fir * (upsample_factor**2)

        self.register_buffer('kernel', fir)
        self.pad = pad

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        # In Tero's implementation, he uses fp32
        return upfirdn2d(x, self.kernel.to(x.dtype), pad=self.pad)
class ModulatedConv2d(nn.Module):
    r"""Modulated Conv2d in StyleGANv2.

    This module implements the modulated convolution layers proposed in
    StyleGAN2. Details can be found in Analyzing and Improving the Image
    Quality of StyleGAN, CVPR2020.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Con2d`.
        style_channels (int): Channels for the style codes.
        demodulate (bool, optional): Whether to adopt demodulation.
            Defaults to True.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to False.
        downsample (bool, optional): Whether to adopt downsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
            Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to 0..
        padding (int | None, optional): Self-defined padding for the plain
            (no up/down-sampling) convolution branch. ``None`` means the
            "same" padding ``kernel_size // 2``. Defaults to None.
        eps (float, optional): Epsilon value to avoid computation error.
            Defaults to 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 demodulate=True,
                 upsample=False,
                 downsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 equalized_lr_cfg=dict(mode='fan_in', lr_mul=1., gain=1.),
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 padding=None,  # self define padding
                 eps=1e-8):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.style_channels = style_channels
        self.demodulate = demodulate
        # sanity check for kernel size: odd and positive, so the "same"
        # padding below is well defined.
        assert isinstance(self.kernel_size, int) and (
            self.kernel_size >= 1 and self.kernel_size % 2 == 1)
        self.upsample = upsample
        self.downsample = downsample
        self.style_bias = style_bias
        self.eps = eps

        # build style modulation module
        style_mod_cfg = dict() if style_mod_cfg is None else style_mod_cfg

        self.style_modulation = EqualLinearActModule(style_channels,
                                                     in_channels,
                                                     **style_mod_cfg)
        # set lr_mul for conv weight
        lr_mul_ = 1.
        if equalized_lr_cfg is not None:
            lr_mul_ = equalized_lr_cfg.get('lr_mul', 1.)
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size,
                        kernel_size).div_(lr_mul_))

        # build blurry layer for upsampling
        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
            self.blur = Blur(blur_kernel, (pad0, pad1), upsample_factor=factor)

        # build blurry layer for downsampling
        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        # add equalized_lr hook for conv weight
        if equalized_lr_cfg is not None:
            equalized_lr(self, **equalized_lr_cfg)

        # BUGFIX: the previous ``padding if padding else ...`` treated an
        # explicit ``padding=0`` as falsy and silently replaced it with
        # ``kernel_size // 2``; test on ``None`` instead.
        self.padding = padding if padding is not None else (kernel_size // 2)

    def forward(self, x, style, input_gain=None):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape (n, c, h, w).
            style (Tensor): Style codes with shape (n, style_channels).
            input_gain (Tensor | None, optional): Per-sample gain applied to
                the input channels of the modulated weight. Defaults to None.

        Returns:
            Tensor: Output feature map with shape (n, out_channels, h', w').
        """
        n, c, h, w = x.shape
        weight = self.weight
        # Pre-normalize inputs to avoid FP16 overflow.
        if x.dtype == torch.float16 and self.demodulate:
            weight = weight * (
                1 / np.sqrt(
                    self.in_channels * self.kernel_size * self.kernel_size) /
                weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True)
            )  # max_Ikk
            style = style / style.norm(
                float('inf'), dim=1, keepdim=True)  # max_I

        # process style code
        style = self.style_modulation(style).view(n, 1, c, 1,
                                                  1) + self.style_bias

        # combine weight and style
        weight = weight * style
        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(n, self.out_channels, 1, 1, 1)

        if input_gain is not None:
            # input_gain shape [batch, in_ch]
            input_gain = input_gain.expand(n, self.in_channels)
            # weight shape [batch, out_ch, in_ch, kernel_size, kernel_size]
            weight = weight * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4)

        # Fold the batch into the group dimension so every sample gets its
        # own modulated kernel in a single grouped convolution.
        weight = weight.view(n * self.out_channels, c, self.kernel_size,
                             self.kernel_size)

        weight = weight.to(x.dtype)

        if self.upsample:
            x = x.reshape(1, n * c, h, w)
            weight = weight.view(n, self.out_channels, c, self.kernel_size,
                                 self.kernel_size)
            weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
                                                    self.kernel_size,
                                                    self.kernel_size)
            x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
            x = x.reshape(n, self.out_channels, *x.shape[-2:])
            x = self.blur(x)
        elif self.downsample:
            x = self.blur(x)
            x = x.view(1, n * self.in_channels, *x.shape[-2:])
            x = conv2d(x, weight, stride=2, padding=0, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])
        else:
            x = x.reshape(1, n * c, h, w)
            x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])

        return x
class NoiseInjection(nn.Module):
    """Noise Injection Module.

    In StyleGAN2, they adopt this module to inject spatial random noise map in
    the generators.

    Args:
        noise_weight_init (float, optional): Initialization weight for noise
            injection. Defaults to ``0.``.
    """

    def __init__(self, noise_weight_init=0.):
        super().__init__()
        # Learnable scalar scaling the noise; initialized to
        # ``noise_weight_init`` (0. by default, i.e. no noise at start).
        self.weight = nn.Parameter(torch.zeros(1).fill_(noise_weight_init))

    def forward(self, image, noise=None, return_noise=False):
        """Forward Function.

        Args:
            image (Tensor): Spatial features with a shape of (N, C, H, W).
            noise (Tensor, optional): Noises from the outside.
                Defaults to None.
            return_noise (bool, optional): Whether to return noise tensor.
                Defaults to False.

        Returns:
            Tensor: Output features.
        """
        if noise is None:
            # Sample one noise map per sample, broadcast over channels.
            n, _, h, w = image.shape
            noise = image.new_empty(n, 1, h, w).normal_()

        noise = noise.to(image.dtype)
        injected = image + self.weight.to(image.dtype) * noise

        if return_noise:
            return injected, noise
        return injected
class ConstantInput(nn.Module):
    """Learnable constant input tensor.

    StyleGAN2 replaces the noise input at the head of the generator with
    this learned constant feature map.

    Args:
        channel (int): Channels for the constant input tensor.
        size (int | Sequence[int], optional): Spatial size of the constant
            tensor; an int means a square map. Defaults to 4.
    """

    def __init__(self, channel, size=4):
        super().__init__()
        if isinstance(size, int):
            spatial = [size, size]
        elif mmcv.is_seq_of(size, int):
            assert len(size) == 2, (
                f'The length of size should be 2 but got {len(size)}')
            spatial = size
        else:
            raise ValueError(f'Got invalid value in size, {size}')
        self.input = nn.Parameter(torch.randn(1, channel, *spatial))

    def forward(self, x):
        """Broadcast the constant to the batch size of ``x``.

        Args:
            x (Tensor): Input feature map with shape of (N, C, ...); only
                its batch dimension is used.

        Returns:
            Tensor: The constant tensor repeated N times along dim 0.
        """
        return self.input.repeat(x.shape[0], 1, 1, 1)
class ModulatedPEConv2d(nn.Module):
    r"""Modulated Conv2d in StyleGANv2 with Positional Encoding (PE).

    This module is modified from the ``ModulatedConv2d`` in StyleGAN2 to
    support the experiments in: Positional Encoding as Spatial Inductive Bias
    in GANs, CVPR'2021.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`. Must be an
            odd positive integer (asserted below).
        style_channels (int): Channels for the style codes.
        demodulate (bool, optional): Whether to adopt demodulation.
            Defaults to True.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to False.
        downsample (bool, optional): Whether to adopt downsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        equalized_lr_cfg (dict | None, optional): Configs for equalized lr.
            Defaults to dict(mode='fan_in', lr_mul=1., gain=1.).
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to 0..
        eps (float, optional): Epsilon value to avoid computation error.
            Defaults to 1e-8.
        no_pad (bool, optional): Whether to removing the padding in
            convolution. Defaults to False.
        deconv2conv (bool, optional): Whether to substitute the transposed conv
            with (conv2d, upsampling). Defaults to False.
        interp_pad (int | None, optional): The padding number of interpolation
            pad. Defaults to None.
        up_config (dict, optional): Upsampling config.
            Defaults to dict(scale_factor=2, mode='nearest').
        up_after_conv (bool, optional): Whether to adopt upsampling after
            convolution. Defaults to False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 demodulate=True,
                 upsample=False,
                 downsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 equalized_lr_cfg=dict(mode='fan_in', lr_mul=1., gain=1.),
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 eps=1e-8,
                 no_pad=False,
                 deconv2conv=False,
                 interp_pad=None,
                 up_config=dict(scale_factor=2, mode='nearest'),
                 up_after_conv=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.style_channels = style_channels
        self.demodulate = demodulate
        # sanity check for kernel size: odd size keeps `same` padding symmetric
        assert isinstance(self.kernel_size, int) and (
            self.kernel_size >= 1 and self.kernel_size % 2 == 1)
        self.upsample = upsample
        self.downsample = downsample
        self.style_bias = style_bias
        self.eps = eps
        self.no_pad = no_pad
        self.deconv2conv = deconv2conv
        self.interp_pad = interp_pad
        self.with_interp_pad = interp_pad is not None
        # deepcopy so later `pop('scale_factor')` in forward cannot mutate
        # the caller's config dict
        self.up_config = deepcopy(up_config)
        self.up_after_conv = up_after_conv

        # build style modulation module (maps style code -> per-input-channel
        # modulation factors)
        style_mod_cfg = dict() if style_mod_cfg is None else style_mod_cfg

        self.style_modulation = EqualLinearActModule(style_channels,
                                                     in_channels,
                                                     **style_mod_cfg)
        # set lr_mul for conv weight: weights are pre-divided here and
        # re-scaled at runtime by the equalized_lr hook below
        lr_mul_ = 1.
        if equalized_lr_cfg is not None:
            lr_mul_ = equalized_lr_cfg.get('lr_mul', 1.)
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size,
                        kernel_size).div_(lr_mul_))

        # build blurry layer for upsampling (only for the true transposed-conv
        # path; the deconv2conv path uses F.interpolate instead)
        if upsample and not self.deconv2conv:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
            self.blur = Blur(blur_kernel, (pad0, pad1), upsample_factor=factor)

        # build blurry layer for downsampling
        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        # add equalized_lr hook for conv weight
        if equalized_lr_cfg is not None:
            equalized_lr(self, **equalized_lr_cfg)

        # if `no_pad`, remove all of the padding in conv
        self.padding = kernel_size // 2 if not no_pad else 0

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).

        Returns:
            Tensor: Output feature with shape of (N, C, H, W).
        """
        n, c, h, w = x.shape
        # process style code: per-sample, per-input-channel scale
        # shaped for broadcasting against the (1, O, I, k, k) weight
        style = self.style_modulation(style).view(n, 1, c, 1, 1) + \
            self.style_bias
        # combine weight and style -> per-sample weights (N, O, I, k, k)
        weight = self.weight * style
        if self.demodulate:
            # normalize each output filter to unit norm (StyleGAN2 demod)
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(n, self.out_channels, 1, 1, 1)
        # fold the batch dim into the output-channel dim so a single grouped
        # conv (groups=n) applies a distinct filter bank per sample
        weight = weight.view(n * self.out_channels, c, self.kernel_size,
                             self.kernel_size)

        if self.upsample and not self.deconv2conv:
            # transposed-conv upsampling path (original StyleGAN2 behavior)
            x = x.reshape(1, n * c, h, w)
            # conv_transpose2d expects (in, out, k, k) per group -> swap dims
            weight = weight.view(n, self.out_channels, c, self.kernel_size,
                                 self.kernel_size)
            weight = weight.transpose(1,
                                      2).reshape(n * c, self.out_channels,
                                                 self.kernel_size,
                                                 self.kernel_size)
            x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
            x = x.reshape(n, self.out_channels, *x.shape[-2:])
            x = self.blur(x)
        elif self.upsample and self.deconv2conv:
            # PE variant: replace transposed conv with (conv2d, interpolate)
            if self.up_after_conv:
                x = x.reshape(1, n * c, h, w)
                x = conv2d(x, weight, padding=self.padding, groups=n)
                x = x.view(n, self.out_channels, *x.shape[2:4])

            if self.with_interp_pad:
                # enlarge the interpolation target by `interp_pad` pixels
                h_, w_ = x.shape[-2:]
                up_cfg_ = deepcopy(self.up_config)
                up_scale = up_cfg_.pop('scale_factor')
                size_ = (h_ * up_scale + self.interp_pad,
                         w_ * up_scale + self.interp_pad)
                x = F.interpolate(x, size=size_, **up_cfg_)
            else:
                x = F.interpolate(x, **self.up_config)

            if not self.up_after_conv:
                h_, w_ = x.shape[-2:]
                x = x.view(1, n * c, h_, w_)
                x = conv2d(x, weight, padding=self.padding, groups=n)
                x = x.view(n, self.out_channels, *x.shape[2:4])
        elif self.downsample:
            # anti-alias blur first, then strided grouped conv
            x = self.blur(x)
            x = x.view(1, n * self.in_channels, *x.shape[-2:])
            x = conv2d(x, weight, stride=2, padding=0, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])
        else:
            # plain (non-resampling) modulated conv
            x = x.view(1, n * c, h, w)
            x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])

        return x
class ModulatedStyleConv(nn.Module):
    """Modulated Style Convolution.

    In this module, we integrate the modulated conv2d, noise injector and
    activation layers into together.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
        style_channels (int): Channels for the style codes.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        demodulate (bool, optional): Whether to adopt demodulation.
            Defaults to True.
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to ``0.``.
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        conv_clamp (float, optional): Clamp the convolutional layer results to
            avoid gradient overflow. Defaults to `256.0`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 upsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 demodulate=True,
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 fp16_enabled=False,
                 conv_clamp=256):
        super().__init__()
        # add support for fp16
        self.fp16_enabled = fp16_enabled
        self.conv_clamp = float(conv_clamp)
        self.conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            style_channels,
            demodulate=demodulate,
            upsample=upsample,
            blur_kernel=blur_kernel,
            style_mod_cfg=style_mod_cfg,
            style_bias=style_bias)

        self.noise_injector = NoiseInjection()
        self.activate = _FusedBiasLeakyReLU(out_channels)

        # if self.fp16_enabled:
        #     self.half()

    @auto_fp16(apply_to=('x', 'noise'))
    def forward(self, x, style, noise=None, return_noise=False):
        """Forward Function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            noise (Tensor, optional): Noise for injection. Defaults to None.
            return_noise (bool, optional): Whether to return noise tensors.
                Defaults to False.

        Returns:
            Tensor: Output features with shape of (N, C, H, W)
        """
        out = self.conv(x, style)

        # inject noise; when `return_noise`, capture the (possibly sampled)
        # noise so it can be replayed by the caller
        if return_noise:
            out, noise = self.noise_injector(
                out, noise=noise, return_noise=return_noise)
        else:
            out = self.noise_injector(
                out, noise=noise, return_noise=return_noise)

        # TODO: FP16 in activate layers
        out = self.activate(out)
        if self.fp16_enabled:
            # clamp activations to keep fp16 values from overflowing
            out = torch.clamp(out, min=-self.conv_clamp, max=self.conv_clamp)

        if return_noise:
            return out, noise

        return out
class ModulatedPEStyleConv(nn.Module):
    """Modulated Style Convolution with Positional Encoding.

    This module is modified from the ``ModulatedStyleConv`` in StyleGAN2 to
    support the experiments in: Positional Encoding as Spatial Inductive Bias
    in GANs, CVPR'2021. It chains a ``ModulatedPEConv2d``, a noise injector
    and a fused-bias leaky-ReLU activation.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
        style_channels (int): Channels for the style codes.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        demodulate (bool, optional): Whether to adopt demodulation.
            Defaults to True.
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to 0..
        **kwargs: Extra arguments forwarded to ``ModulatedPEConv2d`` (e.g.
            ``no_pad``, ``deconv2conv``, ``up_config``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 style_channels,
                 upsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 demodulate=True,
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 **kwargs):
        super().__init__()
        self.conv = ModulatedPEConv2d(
            in_channels,
            out_channels,
            kernel_size,
            style_channels,
            demodulate=demodulate,
            upsample=upsample,
            blur_kernel=blur_kernel,
            style_mod_cfg=style_mod_cfg,
            style_bias=style_bias,
            **kwargs)
        self.noise_injector = NoiseInjection()
        self.activate = _FusedBiasLeakyReLU(out_channels)

    def forward(self, x, style, noise=None, return_noise=False):
        """Forward Function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            noise (Tensor, optional): Noise for injection. Defaults to None.
            return_noise (bool, optional): Whether to return noise tensors.
                Defaults to False.

        Returns:
            Tensor: Output features with shape of (N, C, H, W), plus the
            injected noise tensor when ``return_noise`` is True.
        """
        feat = self.conv(x, style)

        injected = self.noise_injector(
            feat, noise=noise, return_noise=return_noise)
        if return_noise:
            feat, noise = injected
        else:
            feat = injected

        feat = self.activate(feat)
        return (feat, noise) if return_noise else feat
class ModulatedToRGB(nn.Module):
    """To RGB layer.

    This module is designed to output image tensor in StyleGAN2.

    Args:
        in_channels (int): Input channels.
        style_channels (int): Channels for the style codes.
        out_channels (int, optional): Output channels. Defaults to 3.
        upsample (bool, optional): Whether to adopt upsampling in features.
            Defaults to True.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        style_mod_cfg (dict, optional): Configs for style modulation module.
            Defaults to dict(bias_init=1.).
        style_bias (float, optional): Bias value for style code.
            Defaults to 0..
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        conv_clamp (float, optional): Clamp the convolutional layer results to
            avoid gradient overflow. Defaults to `256.0`.
        out_fp32 (bool, optional): Whether to convert the output feature map to
            `torch.float32`. Defaults to `True`.
    """

    def __init__(self,
                 in_channels,
                 style_channels,
                 out_channels=3,
                 upsample=True,
                 blur_kernel=[1, 3, 3, 1],
                 style_mod_cfg=dict(bias_init=1.),
                 style_bias=0.,
                 fp16_enabled=False,
                 conv_clamp=256,
                 out_fp32=True):
        super().__init__()

        if upsample:
            self.upsample = UpsampleUpFIRDn(blur_kernel)

        # add support for fp16
        self.fp16_enabled = fp16_enabled
        self.conv_clamp = float(conv_clamp)

        # 1x1 modulated conv without demodulation (per StyleGAN2 ToRGB)
        self.conv = ModulatedConv2d(
            in_channels,
            out_channels=out_channels,
            kernel_size=1,
            style_channels=style_channels,
            demodulate=False,
            style_mod_cfg=style_mod_cfg,
            style_bias=style_bias)

        # BUGFIX: the bias was hard-coded to 3 channels
        # (``torch.zeros(1, 3, 1, 1)``), which broke any ``out_channels``
        # other than 3; size it from `out_channels` instead. The default
        # (out_channels=3) is unchanged.
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

        # enforce the output to be fp32 (follow Tero's implementation)
        self.out_fp32 = out_fp32

    @auto_fp16(apply_to=('x', 'style'))
    def forward(self, x, style, skip=None):
        """Forward Function.

        Args:
            x (Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).
            skip (Tensor, optional): Tensor for skip link. Defaults to None.

        Returns:
            Tensor: Output features with shape of (N, C, H, W)
        """
        out = self.conv(x, style)
        out = out + self.bias.to(x.dtype)

        if self.fp16_enabled:
            # clamp to keep fp16 activations from overflowing
            out = torch.clamp(
                out, min=-self.conv_clamp, max=self.conv_clamp)

        # Here, Tero adopts FP16 at `skip`.
        if skip is not None:
            skip = self.upsample(skip)
            out = out + skip
        return out
class ConvDownLayer(nn.Sequential):
    """Convolution and Downsampling layer.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        kernel_size (int): Kernel size, same as :obj:`nn.Conv2d`.
        downsample (bool, optional): Whether to adopt downsampling in features.
            Defaults to False.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        bias (bool, optional): Whether to use bias parameter. Defaults to True.
        act_cfg (dict, optional): Activation configs.
            Defaults to dict(type='fused_bias').
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        conv_clamp (float, optional): Clamp the convolutional layer results to
            avoid gradient overflow. Defaults to `256.0`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 downsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 bias=True,
                 act_cfg=dict(type='fused_bias'),
                 fp16_enabled=False,
                 conv_clamp=256.):
        # NOTE: plain (non-module) attributes are set before
        # ``nn.Sequential.__init__`` is called at the bottom; this is safe
        # because they are not parameters/submodules.
        self.fp16_enabled = fp16_enabled
        self.conv_clamp = float(conv_clamp)
        layers = []
        if downsample:
            # blur-then-stride-2 anti-aliased downsampling (StyleGAN2)
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2

        # `fused_bias` activation folds the conv bias into the activation
        # layer, so the conv itself must then be bias-free.
        self.with_fused_bias = act_cfg is not None and act_cfg.get(
            'type') == 'fused_bias'
        if self.with_fused_bias:
            conv_act_cfg = None
        else:
            conv_act_cfg = act_cfg
        layers.append(
            EqualizedLRConvModule(
                in_channels,
                out_channels,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not self.with_fused_bias,
                norm_cfg=None,
                act_cfg=conv_act_cfg,
                equalized_lr_cfg=dict(mode='fan_in', gain=1.)))
        if self.with_fused_bias:
            layers.append(_FusedBiasLeakyReLU(out_channels))

        super(ConvDownLayer, self).__init__(*layers)

    @auto_fp16(apply_to=('x', ))
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        x = super().forward(x)
        if self.fp16_enabled:
            # clamp to keep fp16 activations from overflowing
            x = torch.clamp(x, min=-self.conv_clamp, max=self.conv_clamp)
        return x
class ResBlock(nn.Module):
    """Residual block used in the discriminator of StyleGAN2.

    Two 3x3 conv layers (the second one downsampling) plus a downsampled
    1x1 shortcut; the two paths are summed and rescaled by ``1/sqrt(2)``.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        blur_kernel (list[int], optional): Blurry kernel.
            Defaults to [1, 3, 3, 1].
        fp16_enabled (bool, optional): Whether to use fp16 training in this
            module. Defaults to False.
        convert_input_fp32 (bool, optional): Whether to convert input type to
            fp32 if not `fp16_enabled`. This argument is designed to deal with
            the cases where some modules are run in FP16 and others in FP32.
            Defaults to True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 blur_kernel=[1, 3, 3, 1],
                 fp16_enabled=False,
                 convert_input_fp32=True):
        super().__init__()
        self.fp16_enabled = fp16_enabled
        self.convert_input_fp32 = convert_input_fp32

        # main path: 3x3 conv, then a downsampling 3x3 conv
        self.conv1 = ConvDownLayer(
            in_channels, in_channels, 3, blur_kernel=blur_kernel)
        self.conv2 = ConvDownLayer(
            in_channels,
            out_channels,
            3,
            downsample=True,
            blur_kernel=blur_kernel)
        # shortcut: bias-free, activation-free 1x1 downsampling conv
        self.skip = ConvDownLayer(
            in_channels,
            out_channels,
            1,
            downsample=True,
            act_cfg=None,
            bias=False,
            blur_kernel=blur_kernel)

    @auto_fp16()
    def forward(self, input):
        """Forward function.

        Args:
            input (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map.
        """
        # TODO: study whether this explicit datatype transfer will harm the
        # apex training speed
        if self.convert_input_fp32 and not self.fp16_enabled:
            input = input.to(torch.float32)

        residual = self.conv2(self.conv1(input))
        shortcut = self.skip(input)
        # rescale so the variance of the sum matches a single branch
        return (residual + shortcut) / np.sqrt(2)
class ModMBStddevLayer(nn.Module):
    """Modified MiniBatch Stddev Layer.

    This layer is modified from ``MiniBatchStddevLayer`` used in PGGAN. In
    StyleGAN2, the authors add a new feature, `channel_groups`, into this
    layer.

    Note that to accelerate the training procedure, we also add a new feature
    of ``sync_std`` to achieve multi-nodes/machine training. This feature is
    still in beta version and we have tested it on 256 scales.

    Args:
        group_size (int, optional): The size of groups in batch dimension.
            Defaults to 4.
        channel_groups (int, optional): The size of groups in channel
            dimension. Defaults to 1.
        sync_std (bool, optional): Whether to use synchronized std feature.
            Defaults to False.
        sync_groups (int | None, optional): The size of groups in node
            dimension. Defaults to None.
        eps (float, optional): Epsilon value to avoid computation error.
            Defaults to 1e-8.
    """

    def __init__(self,
                 group_size=4,
                 channel_groups=1,
                 sync_std=False,
                 sync_groups=None,
                 eps=1e-8):
        super().__init__()
        self.group_size = group_size
        self.eps = eps
        self.channel_groups = channel_groups
        self.sync_std = sync_std
        # fall back to `group_size` when the node-dimension group size is
        # not given explicitly
        self.sync_groups = group_size if sync_groups is None else sync_groups

        if self.sync_std:
            assert torch.distributed.is_initialized(
            ), 'Only in distributed training can the sync_std be activated.'
            mmcv.print_log('Adopt synced minibatch stddev layer', 'mmgen')

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map with shape of (N, C+1, H, W).
        """
        if self.sync_std:
            # concatenate all features
            all_features = torch.cat(AllGatherLayer.apply(x), dim=0)
            # get the exact features we need in calculating std-dev
            rank, ws = get_dist_info()
            local_bs = all_features.shape[0] // ws
            start_idx = local_bs * rank
            # avoid the case where start idx near the tail of features
            if start_idx + self.sync_groups > all_features.shape[0]:
                start_idx = all_features.shape[0] - self.sync_groups
            end_idx = min(local_bs * rank + self.sync_groups,
                          all_features.shape[0])

            x = all_features[start_idx:end_idx]

        # batch size should be smaller than or equal to group size. Otherwise,
        # batch size should be divisible by the group size.
        # BUGFIX: the original message read "Batch size be smaller ..." and
        # was missing a space before "But got".
        assert x.shape[
            0] <= self.group_size or x.shape[0] % self.group_size == 0, (
                'Batch size should be smaller than or equal '
                'to group size. Otherwise,'
                ' batch size should be divisible by the group size.'
                f' But got batch size {x.shape[0]},'
                f' group size {self.group_size}')
        assert x.shape[1] % self.channel_groups == 0, (
            '"channel_groups" must be divided by the feature channels. '
            f'channel_groups: {self.channel_groups}, '
            f'feature channels: {x.shape[1]}')

        n, c, h, w = x.shape
        group_size = min(n, self.group_size)
        # reshape to [G, M, Gc, C', H, W]: G batch-groups of size `group_size`
        # and Gc channel-groups of size C' = c // channel_groups
        y = torch.reshape(x, (group_size, -1, self.channel_groups,
                              c // self.channel_groups, h, w))
        # std-dev over each batch group (biased variance, as in StyleGAN2)
        y = torch.var(y, dim=0, unbiased=False)
        y = torch.sqrt(y + self.eps)
        # average over channels and spatial dims -> [M, Gc, 1, 1]
        y = y.mean(dim=(2, 3, 4), keepdim=True).squeeze(2)
        # broadcast the statistic back over the batch and spatial dims
        y = y.repeat(group_size, 1, h, w)
        return torch.cat([x, y], dim=1)
build/lib/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
scipy
import
torch
import
torch.nn
as
nn
from
mmgen.models.builder
import
MODULES
from
mmgen.ops
import
bias_act
,
conv2d_gradfix
,
filtered_lrelu
def modulated_conv2d(
    x,
    w,
    s,
    demodulate=True,
    padding=0,
    input_gain=None,
):
    """Modulated Conv2d in StyleGANv3.

    Args:
        x (torch.Tensor): Input tensor with shape (batch_size, in_channels,
            height, width).
        w (torch.Tensor): Weight of modulated convolution with shape
            (out_channels, in_channels, kernel_height, kernel_width).
        s (torch.Tensor): Style tensor with shape (batch_size, in_channels).
        demodulate (bool): Whether apply weight demodulation. Defaults to True.
        padding (int or list[int]): Convolution padding. Defaults to 0.
        input_gain (list[int]): Scaling factors for input. Defaults to None.

    Returns:
        torch.Tensor: Convolution Output.
    """
    batch_size = int(x.shape[0])
    _, in_channels, kh, kw = w.shape

    # Pre-normalize inputs.
    if demodulate:
        # normalize each output filter and the style vector to unit RMS so
        # the demodulation below is numerically stable
        w = w * w.square().mean([1, 2, 3], keepdim=True).rsqrt()
        s = s * s.square().mean().rsqrt()

    # Modulate weights: broadcast style over output channels and the kernel.
    w = w.unsqueeze(0)  # [NOIkk]
    w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4)  # [NOIkk]

    # Demodulate weights.
    if demodulate:
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
        w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4)  # [NOIkk]

    # Apply input scaling.
    if input_gain is not None:
        input_gain = input_gain.expand(batch_size, in_channels)  # [NI]
        w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4)  # [NOIkk]

    # Execute as one fused op using grouped convolution: fold the batch dim
    # into channels and run with groups=batch_size so each sample gets its
    # own modulated filter bank.
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_gradfix.conv2d(
        input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    return x
class FullyConnectedLayer(nn.Module):
    """Fully connected layer used in StyleGANv3.

    Implements an equalized-learning-rate linear layer: parameters are
    stored pre-scaled by ``lr_multiplier`` and re-scaled at runtime.

    Args:
        in_features (int): Number of channels in the input feature.
        out_features (int): Number of channels in the out feature.
        activation (str, optional): Activation function with choices 'relu',
            'lrelu', 'linear'. 'linear' means no extra activation.
            Defaults to 'linear'.
        bias (bool, optional): Whether to use additive bias. Defaults to True.
        lr_multiplier (float, optional): Equalized learning rate multiplier.
            Defaults to 1..
        weight_init (float, optional): Weight multiplier for initialization.
            Defaults to 1..
        bias_init (float, optional): Initial bias. Defaults to 0..
    """

    def __init__(self,
                 in_features,
                 out_features,
                 activation='linear',
                 bias=True,
                 lr_multiplier=1.,
                 weight_init=1.,
                 bias_init=0.):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation
        # Parameters are stored divided by lr_multiplier; the runtime gains
        # below undo this, yielding the equalized learning rate.
        init_scale = weight_init / lr_multiplier
        self.weight = torch.nn.Parameter(
            torch.randn([out_features, in_features]) * init_scale)
        if bias:
            init_bias = np.broadcast_to(
                np.asarray(bias_init, dtype=np.float32), [out_features])
            self.bias = torch.nn.Parameter(
                torch.from_numpy(init_bias / lr_multiplier))
        else:
            self.bias = None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        """Apply the (equalized-lr) affine transform and optional activation."""
        weight = self.weight.to(x.dtype) * self.weight_gain
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
            if self.bias_gain != 1:
                bias = bias * self.bias_gain

        fused_linear = self.activation == 'linear' and bias is not None
        if fused_linear:
            # single fused matmul + bias add
            out = torch.addmm(bias.unsqueeze(0), x, weight.t())
        else:
            # matmul first, then bias + activation in one custom op
            out = bias_act.bias_act(
                x.matmul(weight.t()), bias, act=self.activation)
        return out
@MODULES.register_module()
class MappingNetwork(nn.Module):
    """Style mapping network used in StyleGAN3. The main difference between it
    and styleganv1,v2 is that mean latent is registered as a buffer and dynamic
    updated during training.

    Args:
        noise_size (int): Size of the input noise vector.
        style_channels (int): The number of channels for style code.
        num_ws (int): The repeat times of w latent.
        c_dim (int, optional): Size of the conditioning label vector; 0
            disables conditioning. Defaults to 0.
        num_layers (int, optional): The number of layers of mapping network.
            Defaults to 2.
        lr_multiplier (float, optional): Equalized learning rate multiplier.
            Defaults to 0.01.
        w_avg_beta (float, optional): The value used for update `w_avg`.
            Defaults to 0.998.
    """

    def __init__(self,
                 noise_size,
                 style_channels,
                 num_ws,
                 c_dim=0,
                 num_layers=2,
                 lr_multiplier=0.01,
                 w_avg_beta=0.998):
        super().__init__()
        self.noise_size = noise_size
        self.c_dim = c_dim
        self.style_channels = style_channels
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        # Construct layers. The label embedding is only built when
        # conditioning is enabled (c_dim > 0).
        self.embed = FullyConnectedLayer(
            self.c_dim, self.style_channels) if self.c_dim > 0 else None
        # Layer widths: the first layer takes z (plus the embedded label when
        # conditioned); all subsequent layers map style -> style.
        features = [
            self.noise_size + (self.style_channels if self.c_dim > 0 else 0)
        ] + [self.style_channels] * self.num_layers
        for idx, in_features, out_features in zip(
                range(num_layers), features[:-1], features[1:]):
            layer = FullyConnectedLayer(
                in_features,
                out_features,
                activation='lrelu',
                lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)
        # Running mean of w, updated during training and used for truncation.
        self.register_buffer('w_avg', torch.zeros([style_channels]))

    def forward(self,
                z,
                c=None,
                truncation=1,
                num_truncation_layer=None,
                update_emas=False):
        """Style mapping function.

        Args:
            z (torch.Tensor): Input noise tensor.
            c (torch.Tensor, optional): Input label tensor. Defaults to None.
            truncation (float, optional): Truncation factor. Give value less
                than 1., the truncation trick will be adopted. Defaults to 1.
            num_truncation_layer (int, optional): Number of layers use
                truncated latent. Defaults to None.
            update_emas (bool, optional): Whether update moving average of
                mean latent. Defaults to False.

        Returns:
            torch.Tensor: W-plus latent.
        """
        if num_truncation_layer is None:
            num_truncation_layer = self.num_ws

        # Embed, normalize, and concatenate inputs. Inputs are RMS-normalized
        # per sample before entering the MLP.
        x = z.to(torch.float32)
        x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
        if self.c_dim > 0:
            y = self.embed(c.to(torch.float32))
            y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
            # NOTE(review): `x` is always defined at this point, so the
            # `x is not None` guard looks vestigial — kept as-is.
            x = torch.cat([x, y], dim=1) if x is not None else y

        # Execute layers.
        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)

        # Update moving average of W (EMA toward the current batch mean).
        if update_emas:
            self.w_avg.copy_(x.detach().mean(
                dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast to W-plus and apply truncation to the leading layers only.
        x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
        if truncation != 1:
            x[:, :num_truncation_layer] = self.w_avg.lerp(
                x[:, :num_truncation_layer], truncation)
        return x
class SynthesisInput(nn.Module):
    """Module which generate input for synthesis layer.

    Args:
        style_channels (int): The number of channels for style code.
        channels (int): The number of output channel.
        size (int): The size of sampling grid.
        sampling_rate (int): Sampling rate for construct sampling grid.
        bandwidth (float): Bandwidth of random frequencies.
    """

    def __init__(self, style_channels, channels, size, sampling_rate,
                 bandwidth):
        super().__init__()
        self.style_channels = style_channels
        self.channels = channels
        # normalize `size` to a (width, height)-style pair
        self.size = np.broadcast_to(np.asarray(size), [2])
        self.sampling_rate = sampling_rate
        self.bandwidth = bandwidth

        # Draw random frequencies from uniform 2D disc.
        freqs = torch.randn([self.channels, 2])
        radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
        freqs /= radii * radii.square().exp().pow(0.25)
        freqs *= bandwidth
        # random per-channel phase offsets in [-0.5, 0.5)
        phases = torch.rand([self.channels]) - 0.5

        # Setup parameters and buffers.
        self.weight = torch.nn.Parameter(
            torch.randn([self.channels, self.channels]))
        # affine predicts t = (r_c, r_s, t_x, t_y); zero weight_init plus
        # bias_init [1, 0, 0, 0] makes it the identity transform initially
        self.affine = FullyConnectedLayer(
            style_channels, 4, weight_init=0, bias_init=[1, 0, 0, 0])
        self.register_buffer('transform', torch.eye(3, 3))
        # User-specified inverse transform wrt. resulting image.
        self.register_buffer('freqs', freqs)
        self.register_buffer('phases', phases)

    def forward(self, w):
        """Forward function."""
        # Introduce batch dimension.
        transforms = self.transform.unsqueeze(0)  # [batch, row, col]
        freqs = self.freqs.unsqueeze(0)  # [batch, channel, xy]
        phases = self.phases.unsqueeze(0)  # [batch, channel]

        # Apply learned transformation.
        t = self.affine(w)  # t = (r_c, r_s, t_x, t_y)
        # normalize the rotation part so (r'_c, r'_s) lies on the unit circle
        t = t / t[:, :2].norm(dim=1,
                              keepdim=True)  # t' = (r'_c, r'_s, t'_x, t'_y)
        m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat(
            [w.shape[0], 1, 1])
        # Inverse rotation wrt. resulting image.
        m_r[:, 0, 0] = t[:, 0]  # r'_c
        m_r[:, 0, 1] = -t[:, 1]  # r'_s
        m_r[:, 1, 0] = t[:, 1]  # r'_s
        m_r[:, 1, 1] = t[:, 0]  # r'_c
        m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat(
            [w.shape[0], 1, 1])
        # Inverse translation wrt. resulting image.
        m_t[:, 0, 2] = -t[:, 2]  # t'_x
        m_t[:, 1, 2] = -t[:, 3]  # t'_y
        # First rotate resulting image, then translate
        # and finally apply user-specified transform.
        transforms = m_r @ m_t @ transforms

        # Transform frequencies.
        phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
        freqs = freqs @ transforms[:, :2, :2]

        # Dampen out-of-band frequencies
        # that may occur due to the user-specified transform.
        amplitudes = (
            1 - (freqs.norm(dim=2) - self.bandwidth) /
            (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)

        # Construct sampling grid.
        theta = torch.eye(2, 3, device=w.device)
        theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
        theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
        grids = torch.nn.functional.affine_grid(
            theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]],
            align_corners=False)

        # Compute Fourier features: project grid coords onto each channel's
        # frequency vector, shift by the phase, and take a sine.
        x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(
            2)).squeeze(3)  # [batch, height, width, channel]
        x = x + phases.unsqueeze(1).unsqueeze(2)
        x = torch.sin(x * (np.pi * 2))
        x = x * amplitudes.unsqueeze(1).unsqueeze(2)

        # Apply trainable mapping (channel mixing, scaled for unit variance).
        weight = self.weight / np.sqrt(self.channels)
        x = x @ weight.t()

        # Ensure correct shape.
        x = x.permute(0, 3, 1, 2)  # [batch, channel, height, width]
        return x
class SynthesisLayer(nn.Module):
    """Layer of Synthesis network for stylegan3.

    Each layer runs a style-modulated convolution, then a filtered leaky
    ReLU: the signal is upsampled to a temporary sampling rate, the
    nonlinearity is applied, and the result is low-pass filtered and
    downsampled back (see ``filtered_lrelu`` call in ``forward``).

    Args:
        style_channels (int): The number of channels for style code.
        is_torgb (bool): Whether output of this layer is transformed to
            rgb image.
        is_critically_sampled (bool): Whether filter cutoff is set exactly
            at the bandlimit.
        use_fp16 (bool, optional): Whether to use fp16 training in this
            module. If this flag is `True`, the whole module will be wrapped
            with ``auto_fp16``.
        in_channels (int): The channel number of the input feature map.
        out_channels (int): The channel number of the output feature map.
        in_size (int): The input size of feature map.
        out_size (int): The output size of feature map.
        in_sampling_rate (int): Sampling rate for upsampling filter.
        out_sampling_rate (int): Sampling rate for downsampling filter.
        in_cutoff (float): Cutoff frequency for upsampling filter.
        out_cutoff (float): Cutoff frequency for downsampling filter.
        in_half_width (float): The approximate width of the transition region
            for upsampling filter.
        out_half_width (float): The approximate width of the transition region
            for downsampling filter.
        conv_kernel (int, optional): The kernel of modulated convolution.
            Defaults to 3.
        filter_size (int, optional): Base filter size. Defaults to 6.
        lrelu_upsampling (int, optional): Upsamling rate for
            `filtered_lrelu`. Defaults to 2.
        use_radial_filters (bool, optional): Whether use radially symmetric
            jinc-based filter in downsamping filter. Defaults to False.
        conv_clamp (int, optional): Clamp bound for convolution.
            Defaults to 256.
        magnitude_ema_beta (float, optional): Beta coefficient for
            calculating input magnitude ema. Defaults to 0.999.
    """

    def __init__(
        self,
        style_channels,
        is_torgb,
        is_critically_sampled,
        use_fp16,
        in_channels,
        out_channels,
        in_size,
        out_size,
        in_sampling_rate,
        out_sampling_rate,
        in_cutoff,
        out_cutoff,
        in_half_width,
        out_half_width,
        conv_kernel=3,
        filter_size=6,
        lrelu_upsampling=2,
        use_radial_filters=False,
        conv_clamp=256,
        magnitude_ema_beta=0.999,
    ):
        super().__init__()
        self.style_channels = style_channels
        self.is_torgb = is_torgb
        self.is_critically_sampled = is_critically_sampled
        self.use_fp16 = use_fp16
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Broadcast scalar sizes into length-2 (x, y) arrays.
        self.in_size = np.broadcast_to(np.asarray(in_size), [2])
        self.out_size = np.broadcast_to(np.asarray(out_size), [2])
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        # Temporary (internal) sampling rate used around the nonlinearity;
        # ToRGB layers skip the extra lrelu upsampling.
        self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (
            1 if is_torgb else lrelu_upsampling)
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width
        # ToRGB layers use a 1x1 kernel regardless of `conv_kernel`.
        self.conv_kernel = 1 if is_torgb else conv_kernel
        self.conv_clamp = conv_clamp
        self.magnitude_ema_beta = magnitude_ema_beta

        # Setup parameters and buffers.
        self.affine = FullyConnectedLayer(
            self.style_channels, self.in_channels, bias_init=1)
        self.weight = torch.nn.Parameter(
            torch.randn([
                self.out_channels, self.in_channels, self.conv_kernel,
                self.conv_kernel
            ]))
        self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
        # Scalar running estimate of the mean squared input magnitude.
        self.register_buffer('magnitude_ema', torch.ones([]))

        # Design upsampling filter.
        self.up_factor = int(
            np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
        assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
        # A single tap means an identity filter (no upsampling / ToRGB).
        self.up_taps = (
            filter_size * self.up_factor
            if self.up_factor > 1 and not self.is_torgb else 1)
        self.register_buffer(
            'up_filter',
            self.design_lowpass_filter(
                numtaps=self.up_taps,
                cutoff=self.in_cutoff,
                width=self.in_half_width * 2,
                fs=self.tmp_sampling_rate))

        # Design downsampling filter.
        self.down_factor = int(
            np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
        assert (self.out_sampling_rate * self.down_factor ==
                self.tmp_sampling_rate)
        self.down_taps = (
            filter_size * self.down_factor
            if self.down_factor > 1 and not self.is_torgb else 1)
        # Radial (jinc) filters are only used on non-critically-sampled
        # layers.
        self.down_radial = (use_radial_filters
                            and not self.is_critically_sampled)
        self.register_buffer(
            'down_filter',
            self.design_lowpass_filter(
                numtaps=self.down_taps,
                cutoff=self.out_cutoff,
                width=self.out_half_width * 2,
                fs=self.tmp_sampling_rate,
                radial=self.down_radial))

        # Compute padding. All arithmetic below is element-wise over the
        # length-2 size arrays.
        pad_total = (self.out_size - 1) * self.down_factor + 1
        # Desired output size before downsampling.
        pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor
        # Input size after upsampling.
        pad_total += self.up_taps + self.down_taps - 2
        pad_lo = (pad_total + self.up_factor) // 2
        pad_hi = pad_total - pad_lo
        self.padding = [
            int(pad_lo[0]),
            int(pad_hi[0]),
            int(pad_lo[1]),
            int(pad_hi[1])
        ]

    def forward(self, x, w, force_fp32=False, update_emas=False):
        """Forward function for synthesis layer.

        Args:
            x (torch.Tensor): Input feature map tensor.
            w (torch.Tensor): Input style tensor.
            force_fp32 (bool, optional): Force fp32 ignore the weights.
                Defaults to False.
            update_emas (bool, optional): Whether update moving average of
                input magnitude. Defaults to False.

        Returns:
            torch.Tensor: Output feature map tensor.
        """
        # Track input magnitude.
        if update_emas:
            with torch.autograd.profiler.record_function(
                    'update_magnitude_ema'):
                magnitude_cur = x.detach().to(torch.float32).square().mean()
                # Blend the new measurement towards the stored EMA with
                # weight `magnitude_ema_beta`.
                self.magnitude_ema.copy_(
                    magnitude_cur.lerp(self.magnitude_ema,
                                       self.magnitude_ema_beta))
        input_gain = self.magnitude_ema.rsqrt()

        # Execute affine layer.
        styles = self.affine(w)
        if self.is_torgb:
            # ToRGB disables demodulation below; scale styles by a
            # fan-in-based gain instead.
            weight_gain = 1 / np.sqrt(self.in_channels *
                                      (self.conv_kernel**2))
            styles = styles * weight_gain

        # Execute modulated conv2d. fp16 is only used on CUDA and when the
        # caller does not force fp32.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32
                                  and x.device.type == 'cuda'
                                  ) else torch.float32
        x = modulated_conv2d(
            x=x.to(dtype),
            w=self.weight,
            s=styles,
            padding=self.conv_kernel - 1,
            demodulate=(not self.is_torgb),
            input_gain=input_gain)

        # Execute bias, filtered leaky ReLU, and clamping. slope == 1 makes
        # the activation linear for ToRGB layers.
        gain = 1 if self.is_torgb else np.sqrt(2)
        slope = 1 if self.is_torgb else 0.2
        x = filtered_lrelu.filtered_lrelu(
            x=x,
            fu=self.up_filter,
            fd=self.down_filter,
            b=self.bias.to(x.dtype),
            up=self.up_factor,
            down=self.down_factor,
            padding=self.padding,
            gain=gain,
            slope=slope,
            clamp=self.conv_clamp)

        # Ensure correct shape and dtype.
        assert x.dtype == dtype
        return x

    @staticmethod
    def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
        """Design lowpass filter giving related arguments.

        Args:
            numtaps (int): Length of the filter. `numtaps` must be odd if a
                passband includes the Nyquist frequency.
            cutoff (float): Cutoff frequency of filter
            width (float): The approximate width of the transition region.
            fs (float): The sampling frequency of the signal.
            radial (bool, optional): Whether use radially symmetric
                jinc-based filter. Defaults to False.

        Returns:
            torch.Tensor | None: Kernel of lowpass filter. ``None`` is
                returned when ``numtaps == 1`` and denotes an identity
                (no-op) filter.
        """
        assert numtaps >= 1

        # Identity filter.
        if numtaps == 1:
            return None

        # Separable Kaiser low-pass filter.
        if not radial:
            f = scipy.signal.firwin(
                numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
            return torch.as_tensor(f, dtype=torch.float32)

        # Radially symmetric jinc-based filter.
        # NOTE(review): r == 0 divides by zero for odd `numtaps`; callers in
        # this file produce even tap counts (filter_size * factor with the
        # default filter_size=6) — confirm before using odd taps here.
        x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
        r = np.hypot(*np.meshgrid(x, x))
        f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
        # Window with a Kaiser window whose attenuation matches the
        # requested transition width.
        beta = scipy.signal.kaiser_beta(
            scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
        w = np.kaiser(numtaps, beta)
        f *= np.outer(w, w)
        # Normalize to unit DC gain.
        f /= np.sum(f)
        return torch.as_tensor(f, dtype=torch.float32)
@MODULES.register_module()
class SynthesisNetwork(nn.Module):
    """Synthesis network for stylegan3.

    Args:
        style_channels (int): The number of channels for style code.
        out_size (int): The resolution of output image.
        img_channels (int): The number of channels for output image.
        channel_base (int, optional): Overall multiplier for the number of
            channels. Defaults to 32768.
        channel_max (int, optional): Maximum number of channels in any layer.
            Defaults to 512.
        num_layers (int, optional): Total number of layers, excluding Fourier
            features and ToRGB. Defaults to 14.
        num_critical (int, optional): Number of critically sampled layers at
            the end. Defaults to 2.
        first_cutoff (int, optional): Cutoff frequency of the first layer.
            Defaults to 2.
        first_stopband (int, optional): Minimum stopband of the first layer.
            Defaults to 2**2.1.
        last_stopband_rel (float, optional): Minimum stopband of the last
            layer, expressed relative to the cutoff. Defaults to 2**0.3.
        margin_size (int, optional): Number of additional pixels outside the
            image. Defaults to 10.
        output_scale (float, optional): Scale factor for output value.
            Defaults to 0.25.
        num_fp16_res (int, optional): Number of first few layers use fp16.
            Defaults to 4.
    """

    def __init__(
        self,
        style_channels,
        out_size,
        img_channels,
        channel_base=32768,
        channel_max=512,
        num_layers=14,
        num_critical=2,
        first_cutoff=2,
        first_stopband=2**2.1,
        last_stopband_rel=2**0.3,
        margin_size=10,
        output_scale=0.25,
        num_fp16_res=4,
        **layer_kwargs,
    ):
        super().__init__()
        self.style_channels = style_channels
        # One w for the Fourier-feature input plus num_layers + 1 synthesis
        # layers (the extra one is the final ToRGB layer).
        self.num_ws = num_layers + 2
        self.out_size = out_size
        self.img_channels = img_channels
        self.num_layers = num_layers
        self.num_critical = num_critical
        self.margin_size = margin_size
        self.output_scale = output_scale
        self.num_fp16_res = num_fp16_res

        # Geometric progression of layer cutoffs and min. stopbands.
        last_cutoff = self.out_size / 2  # f_{c,N}
        last_stopband = last_cutoff * last_stopband_rel  # f_{t,N}
        # Ramp 0 -> 1 over the non-critical layers, clipped at 1 for the
        # final `num_critical` layers.
        exponents = np.minimum(
            np.arange(self.num_layers + 1) /
            (self.num_layers - self.num_critical), 1)
        cutoffs = first_cutoff * (
            last_cutoff / first_cutoff)**exponents  # f_c[i]
        stopbands = first_stopband * (
            last_stopband / first_stopband)**exponents  # f_t[i]

        # Compute remaining layer parameters.
        # Sampling rates are rounded up to powers of two, capped at out_size.
        sampling_rates = np.exp2(
            np.ceil(np.log2(np.minimum(stopbands * 2,
                                       self.out_size))))  # s[i]
        half_widths = np.maximum(stopbands,
                                 sampling_rates / 2) - cutoffs  # f_h[i]
        sizes = sampling_rates + self.margin_size * 2
        # The last two layers render at the exact output size (no margin).
        sizes[-2:] = self.out_size
        channels = np.rint(
            np.minimum((channel_base / 2) / cutoffs, channel_max))
        # Final layer emits the image directly.
        channels[-1] = self.img_channels

        # Construct layers.
        self.input = SynthesisInput(
            style_channels=self.style_channels,
            channels=int(channels[0]),
            size=int(sizes[0]),
            sampling_rate=sampling_rates[0],
            bandwidth=cutoffs[0])
        self.layer_names = []
        for idx in range(self.num_layers + 1):
            prev = max(idx - 1, 0)
            # The very last layer is the ToRGB mapping.
            is_torgb = (idx == self.num_layers)
            is_critically_sampled = (
                idx >= self.num_layers - self.num_critical)
            # fp16 for layers whose sampling rate is within a factor of
            # 2**num_fp16_res of the output size (the highest resolutions).
            use_fp16 = (sampling_rates[idx] * (2**self.num_fp16_res) >
                        self.out_size)
            layer = SynthesisLayer(
                style_channels=self.style_channels,
                is_torgb=is_torgb,
                is_critically_sampled=is_critically_sampled,
                use_fp16=use_fp16,
                in_channels=int(channels[prev]),
                out_channels=int(channels[idx]),
                in_size=int(sizes[prev]),
                out_size=int(sizes[idx]),
                in_sampling_rate=int(sampling_rates[prev]),
                out_sampling_rate=int(sampling_rates[idx]),
                in_cutoff=cutoffs[prev],
                out_cutoff=cutoffs[idx],
                in_half_width=half_widths[prev],
                out_half_width=half_widths[idx],
                **layer_kwargs)
            # Layer attribute name encodes index, spatial size and channels.
            name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
            setattr(self, name, layer)
            self.layer_names.append(name)

    def forward(self, ws, **layer_kwargs):
        """Forward function.

        Args:
            ws (torch.Tensor): Style codes of shape
                ``(batch, num_ws, style_channels)``; one code per layer.

        Returns:
            torch.Tensor: Synthesized image tensor in float32.
        """
        # Split per-layer style codes along dim 1.
        ws = ws.to(torch.float32).unbind(dim=1)

        # Execute layers: Fourier-feature input first, then each synthesis
        # layer with its own w.
        x = self.input(ws[0])
        for name, w in zip(self.layer_names, ws[1:]):
            x = getattr(self, name)(x, w, **layer_kwargs)
        if self.output_scale != 1:
            x = x * self.output_scale

        # Ensure correct shape and dtype.
        x = x.to(torch.float32)
        return x
build/lib/mmgen/models/architectures/stylegan/mspie.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
random
from
copy
import
deepcopy
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmgen.models.architectures
import
PixelNorm
from
mmgen.models.architectures.common
import
get_module_device
from
mmgen.models.builder
import
MODULES
,
build_module
from
.modules.styleganv2_modules
import
(
ConstantInput
,
ConvDownLayer
,
EqualLinearActModule
,
ModMBStddevLayer
,
ModulatedPEStyleConv
,
ModulatedToRGB
,
ResBlock
)
from
.utils
import
get_mean_latent
,
style_mixing
@MODULES.register_module()
class MSStyleGANv2Generator(nn.Module):
    """StyleGAN2 Generator.

    In StyleGAN2, we use a static architecture composing of a style mapping
    module and number of convolutional style blocks. More details can be found
    in: Analyzing and Improving the Image Quality of StyleGAN CVPR2020.

    Args:
        out_size (int): The output size of the StyleGAN2 generator.
        style_channels (int): The number of channels for style code.
        num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we defaultly adopt mixing style mode. However, in the
            evaluation, we use 'single' style mode. `['mix', 'single']` are
            currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The evaluation mode of style mixing.
            Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in range of [0, 1]. Defaults to 0.9.
        no_pad (bool, optional): Forwarded to the style convs; when ``True``
            the 4x4 constant input is enlarged to 6x6 to compensate for the
            missing padding. Defaults to False.
        deconv2conv (bool, optional): Forwarded to ``ModulatedPEStyleConv``;
            presumably replaces transposed-conv upsampling with
            interpolation + conv — confirm in ``styleganv2_modules``.
            Defaults to False.
        interp_pad (int | None, optional): Forwarded to the style convs.
            Defaults to None.
        up_config (dict, optional): Upsampling config forwarded to the style
            convs. Defaults to ``dict(scale_factor=2, mode='nearest')``.
        up_after_conv (bool, optional): Forwarded to the style convs.
            Defaults to False.
        head_pos_encoding (dict | None, optional): Config for a positional
            encoding module used in place of the constant input.
            Defaults to None.
        head_pos_size (tuple, optional): Base spatial size of the positional
            encoding head. Defaults to (4, 4).
        interp_head (bool, optional): If ``True``, the positional-encoding
            grid is built at ``head_pos_size`` and bilinearly resized to the
            chosen scale. Defaults to False.
    """

    def __init__(self,
                 out_size,
                 style_channels,
                 num_mlps=8,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9,
                 no_pad=False,
                 deconv2conv=False,
                 interp_pad=None,
                 up_config=dict(scale_factor=2, mode='nearest'),
                 up_after_conv=False,
                 head_pos_encoding=None,
                 head_pos_size=(4, 4),
                 interp_head=False):
        super().__init__()
        self.out_size = out_size
        self.style_channels = style_channels
        self.num_mlps = num_mlps
        self.channel_multiplier = channel_multiplier
        self.lr_mlp = lr_mlp
        # `_default_style_mode` keeps the training-time mode so `train()`
        # can toggle between it and `eval_style_mode`.
        self._default_style_mode = default_style_mode
        self.default_style_mode = default_style_mode
        self.eval_style_mode = eval_style_mode
        self.mix_prob = mix_prob
        self.no_pad = no_pad
        self.deconv2conv = deconv2conv
        self.interp_pad = interp_pad
        self.with_interp_pad = interp_pad is not None
        self.up_config = deepcopy(up_config)
        self.up_after_conv = up_after_conv
        self.head_pos_encoding = head_pos_encoding
        self.head_pos_size = head_pos_size
        self.interp_head = interp_head

        # define style mapping layers
        mapping_layers = [PixelNorm()]

        for _ in range(num_mlps):
            mapping_layers.append(
                EqualLinearActModule(
                    style_channels,
                    style_channels,
                    equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
                    act_cfg=dict(type='fused_bias')))

        self.style_mapping = nn.Sequential(*mapping_layers)

        # Channel number per resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        in_ch = self.channels[4]
        # constant input layer
        if self.head_pos_encoding:
            # Grid-style positional encodings produce a 2-channel input.
            if self.head_pos_encoding['type'] in [
                    'CatersianGrid', 'CSG', 'CSG2d'
            ]:
                in_ch = 2
            self.head_pos_enc = build_module(self.head_pos_encoding)
        else:
            size_ = 4
            if self.no_pad:
                size_ += 2
            self.constant_input = ConstantInput(
                self.channels[4], size=size_)

        # 4x4 stage
        self.conv1 = ModulatedPEStyleConv(
            in_ch,
            self.channels[4],
            kernel_size=3,
            style_channels=style_channels,
            blur_kernel=blur_kernel,
            deconv2conv=self.deconv2conv,
            no_pad=self.no_pad,
            up_config=self.up_config,
            interp_pad=self.interp_pad)
        self.to_rgb1 = ModulatedToRGB(
            self.channels[4], style_channels, upsample=False)

        # generator backbone (8x8 --> higher resolutions)
        self.log_size = int(np.log2(self.out_size))
        self.convs = nn.ModuleList()
        # NOTE(review): `upsamples` is never populated in this class.
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        in_channels_ = self.channels[4]

        # Two style convs (one upsampling, one plain) plus one ToRGB per
        # resolution from 8x8 up to `out_size`.
        for i in range(3, self.log_size + 1):
            out_channels_ = self.channels[2**i]

            self.convs.append(
                ModulatedPEStyleConv(
                    in_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=True,
                    blur_kernel=blur_kernel,
                    deconv2conv=self.deconv2conv,
                    no_pad=self.no_pad,
                    up_config=self.up_config,
                    interp_pad=self.interp_pad,
                    up_after_conv=self.up_after_conv))
            self.convs.append(
                ModulatedPEStyleConv(
                    out_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=False,
                    blur_kernel=blur_kernel,
                    deconv2conv=self.deconv2conv,
                    no_pad=self.no_pad,
                    up_config=self.up_config,
                    interp_pad=self.interp_pad,
                    up_after_conv=self.up_after_conv))
            self.to_rgbs.append(
                ModulatedToRGB(out_channels_, style_channels, upsample=True))

            in_channels_ = out_channels_

        self.num_latents = self.log_size * 2 - 2
        self.num_injected_noises = self.num_latents - 1

        # register buffer for injected noises
        noises = self.make_injected_noise()
        for layer_idx in range(self.num_injected_noises):
            self.register_buffer(f'injected_noise_{layer_idx}',
                                 noises[layer_idx])

    def train(self, mode=True):
        # Switch the style-mixing mode together with train/eval mode.
        if mode:
            if self.default_style_mode != self._default_style_mode:
                mmcv.print_log(
                    f'Switch to train style mode: '
                    f'{self._default_style_mode}', 'mmgen')
            self.default_style_mode = self._default_style_mode
        else:
            if self.default_style_mode != self.eval_style_mode:
                mmcv.print_log(
                    f'Switch to evaluation style mode: '
                    f'{self.eval_style_mode}', 'mmgen')
            self.default_style_mode = self.eval_style_mode
        return super(MSStyleGANv2Generator, self).train(mode)

    def make_injected_noise(self, chosen_scale=0):
        """Sample one single-channel noise map per noise-injection site.

        Args:
            chosen_scale (int, optional): Extra pixels added to the base
                4x4 scale (MS-PIE multi-scale setting). Defaults to 0.

        Returns:
            list[torch.Tensor]: Noise tensors of shape ``(1, 1, h, w)``.
        """
        device = get_module_device(self)

        base_scale = 2**2 + chosen_scale

        noises = [torch.randn(1, 1, base_scale, base_scale, device=device)]

        for i in range(3, self.log_size + 1):
            for n in range(2):
                _pad = 0
                # The first (upsampling) conv at each resolution needs a
                # 2-pixel-larger noise map in the no-pad setting.
                if self.no_pad and not self.up_after_conv and n == 0:
                    _pad = 2
                noises.append(
                    torch.randn(
                        1,
                        1,
                        base_scale * 2**(i - 2) + _pad,
                        base_scale * 2**(i - 2) + _pad,
                        device=device))

        return noises

    def get_mean_latent(self, num_samples=4096, **kwargs):
        """Get mean latent of W space in this generator.

        Args:
            num_samples (int, optional): Number of sample times. Defaults
                to 4096.

        Returns:
            Tensor: Mean latent of this generator.
        """
        return get_mean_latent(self, num_samples, **kwargs)

    def style_mixing(self,
                     n_source,
                     n_target,
                     inject_index=1,
                     truncation_latent=None,
                     truncation=0.7,
                     chosen_scale=0):
        # Delegates to the shared helper in `.utils`.
        return style_mixing(
            self,
            n_source=n_source,
            n_target=n_target,
            inject_index=inject_index,
            truncation_latent=truncation_latent,
            truncation=truncation,
            style_channels=self.style_channels,
            chosen_scale=chosen_scale)

    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True,
                chosen_scale=0):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide noise tensor or latent tensor. Given
                a list containing more than one noise or latent tensors, style
                mixing trick will be used in training. Of course, You can
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, the ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to -1.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. Give value less
                than 1., the truncation trick will be adopted. Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensor injected to the style conv
                block. Defaults to True.
            chosen_scale (int | tuple[int], optional): Extra pixels added to
                the 4x4 input (MS-PIE multi-scale sampling). Defaults to 0.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, we will adopt default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        # Map noise z to latent w unless the caller passed latents directly.
        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                injected_noise = [None] * self.num_injected_noises
            elif chosen_scale > 0:
                # Lazily build and cache noise buffers for this scale.
                if not hasattr(self, f'injected_noise_{chosen_scale}_0'):
                    noises_ = self.make_injected_noise(chosen_scale)
                    for i in range(self.num_injected_noises):
                        setattr(self, f'injected_noise_{chosen_scale}_{i}',
                                noises_[i])
                injected_noise = [
                    getattr(self, f'injected_noise_{chosen_scale}_{i}')
                    for i in range(self.num_injected_noises)
                ]
            else:
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]
        # use truncation trick
        if truncation < 1:
            style_t = []
            # calculate truncation latent on the fly
            if truncation_latent is None and not hasattr(
                    self, 'truncation_latent'):
                self.truncation_latent = self.get_mean_latent()
                truncation_latent = self.truncation_latent
            elif truncation_latent is None and hasattr(
                    self, 'truncation_latent'):
                truncation_latent = self.truncation_latent

            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t
        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        if isinstance(chosen_scale, int):
            chosen_scale = (chosen_scale, chosen_scale)

        # 4x4 stage
        if self.head_pos_encoding:
            if self.interp_head:
                # Build at the base size, then resize to the chosen scale.
                out = self.head_pos_enc.make_grid2d(self.head_pos_size[0],
                                                    self.head_pos_size[1],
                                                    latent.size(0))
                h_in = self.head_pos_size[0] + chosen_scale[0]
                w_in = self.head_pos_size[1] + chosen_scale[1]
                out = F.interpolate(
                    out,
                    size=(h_in, w_in),
                    mode='bilinear',
                    align_corners=True)
            else:
                out = self.head_pos_enc.make_grid2d(
                    self.head_pos_size[0] + chosen_scale[0],
                    self.head_pos_size[1] + chosen_scale[1], latent.size(0))
            out = out.to(latent)
        else:
            out = self.constant_input(latent)
            if chosen_scale[0] != 0 or chosen_scale[1] != 0:
                out = F.interpolate(
                    out,
                    size=(out.shape[2] + chosen_scale[0],
                          out.shape[3] + chosen_scale[1]),
                    mode='bilinear',
                    align_corners=True)

        out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        _index = 1

        # 8x8 ---> higher resolutions
        for up_conv, conv, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], injected_noise[1::2],
                injected_noise[2::2], self.to_rgbs):
            out = up_conv(out, latent[:, _index], noise=noise1)
            out = conv(out, latent[:, _index + 1], noise=noise2)
            skip = to_rgb(out, latent[:, _index + 2], skip)
            _index += 2

        img = skip

        if return_latents or return_noise:
            output_dict = dict(
                fake_img=img,
                latent=latent,
                inject_index=inject_index,
                noise_batch=noise_batch,
                injected_noise=injected_noise)
            return output_dict

        return img
@MODULES.register_module()
class MSStyleGAN2Discriminator(nn.Module):
    """StyleGAN2 Discriminator.

    The architecture of this discriminator is proposed in StyleGAN2. More
    details can be found in: Analyzing and Improving the Image Quality of
    StyleGAN CVPR2020.

    Args:
        in_size (int): The input size of images.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4, channel_groups=1).
        with_adaptive_pool (bool, optional): If ``True``, an adaptive
            average pooling of size ``pool_size`` is inserted before the
            final linear head so arbitrary input sizes are supported.
            Defaults to False.
        pool_size (tuple, optional): Output size of the adaptive pooling.
            Defaults to (2, 2).
    """

    def __init__(self,
                 in_size,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 mbstd_cfg=dict(group_size=4, channel_groups=1),
                 with_adaptive_pool=False,
                 pool_size=(2, 2)):
        super().__init__()
        self.with_adaptive_pool = with_adaptive_pool
        self.pool_size = pool_size

        # Channel width per resolution: 512 up to 32x32, then halved per
        # octave and scaled by `channel_multiplier`.
        channels = {
            res: (512 if res <= 32 else 1024 // res * 16 * channel_multiplier)
            for res in [4, 8, 16, 32, 64, 128, 256, 512, 1024]
        }

        log_size = int(np.log2(in_size))

        # Stem: 1x1 conv from RGB, then one downsampling residual block per
        # octave until the 4x4 resolution is reached.
        cur_channels = channels[in_size]
        blocks = [ConvDownLayer(3, channels[in_size], 1)]
        for octave in range(log_size, 2, -1):
            next_channels = channels[2**(octave - 1)]
            blocks.append(ResBlock(cur_channels, next_channels, blur_kernel))
            cur_channels = next_channels
        self.convs = nn.Sequential(*blocks)

        # Minibatch-stddev adds one channel before the final conv.
        self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)
        self.final_conv = ConvDownLayer(cur_channels + 1, channels[4], 3)

        if self.with_adaptive_pool:
            self.adaptive_pool = nn.AdaptiveAvgPool2d(pool_size)
            linear_in_channels = channels[4] * pool_size[0] * pool_size[1]
        else:
            # Without pooling the head expects a fixed 4x4 feature map.
            linear_in_channels = channels[4] * 4 * 4

        self.final_linear = nn.Sequential(
            EqualLinearActModule(
                linear_in_channels,
                channels[4],
                act_cfg=dict(type='fused_bias')),
            EqualLinearActModule(channels[4], 1),
        )

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input image tensor.

        Returns:
            torch.Tensor: Predict score for the input image.
        """
        feat = self.final_conv(self.mbstd_layer(self.convs(x)))
        if self.with_adaptive_pool:
            feat = self.adaptive_pool(feat)
        flat = feat.view(feat.shape[0], -1)
        return self.final_linear(flat)
build/lib/mmgen/models/architectures/stylegan/utils.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
from
..common
import
get_module_device
@torch.no_grad()
def get_mean_latent(generator, num_samples=4096, bs_per_repeat=1024):
    """Get mean latent of W space in Style-based GANs.

    Draws ``num_samples`` noise vectors in batches of ``bs_per_repeat``,
    maps them through ``generator.style_mapping`` and averages the results.

    Args:
        generator (nn.Module): Generator of a Style-based GAN.
        num_samples (int, optional): Number of sample times. Defaults to 4096.
        bs_per_repeat (int, optional): Batch size of noises per sample.
            Defaults to 1024.

    Returns:
        Tensor: Mean latent of this generator, shape ``(1, style_channels)``.
    """
    device = get_module_device(generator)
    n_repeat = num_samples // bs_per_repeat
    # num_samples must be an exact multiple of the per-batch size.
    assert n_repeat * bs_per_repeat == num_samples

    accumulated = None
    for _ in range(n_repeat):
        z = torch.randn(bs_per_repeat, generator.style_channels).to(device)
        batch_mean = generator.style_mapping(z).mean(0, keepdim=True)
        accumulated = (
            batch_mean if accumulated is None else accumulated + batch_mean)

    return accumulated / float(n_repeat)
@torch.no_grad()
def style_mixing(generator,
                 n_source,
                 n_target,
                 inject_index=1,
                 truncation_latent=None,
                 truncation=0.7,
                 style_channels=512,
                 **kwargs):
    """Build a style-mixing image grid.

    The grid has one placeholder cell, a row of source images, and — for
    every target code — the target image followed by the row of mixed
    images obtained by injecting the source codes at ``inject_index``.

    Args:
        generator (nn.Module): Style-based generator.
        n_source (int): Number of source (column) codes.
        n_target (int): Number of target (row) codes.
        inject_index (int, optional): Mixing injection index. Defaults to 1.
        truncation_latent (Tensor | None, optional): Mean truncation latent.
            Defaults to None.
        truncation (float, optional): Truncation factor. Defaults to 0.7.
        style_channels (int, optional): Style-code dimensionality.
            Defaults to 512.

    Returns:
        Tensor: All grid cells concatenated along the batch dimension.
    """
    device = get_module_device(generator)

    codes_src = torch.randn(n_source, style_channels).to(device)
    codes_tgt = torch.randn(n_target, style_channels).to(device)

    imgs_src = generator(
        codes_src,
        truncation_latent=truncation_latent,
        truncation=truncation,
        **kwargs)
    rows, cols = imgs_src.shape[-2:]
    # Top-left placeholder: a solid cell filled with -1.
    grid = [torch.ones(1, 3, rows, cols).to(device) * -1]

    imgs_tgt = generator(
        codes_tgt,
        truncation_latent=truncation_latent,
        truncation=truncation,
        **kwargs)
    grid.append(imgs_src)

    for row_idx in range(n_target):
        # Mix: the target code drives the first `inject_index` latents, the
        # source codes drive the rest.
        mixed = generator(
            [codes_tgt[row_idx].unsqueeze(0).repeat(n_source, 1), codes_src],
            truncation_latent=truncation_latent,
            truncation=truncation,
            inject_index=inject_index,
            **kwargs)
        grid.append(imgs_tgt[row_idx].unsqueeze(0))
        grid.append(mixed)

    return torch.cat(grid, 0)
build/lib/mmgen/models/architectures/wgan_gp/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.generator_discriminator
import
WGANGPDiscriminator
,
WGANGPGenerator
__all__
=
[
'WGANGPDiscriminator'
,
'WGANGPGenerator'
]
build/lib/mmgen/models/architectures/wgan_gp/generator_discriminator.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
ConvModule
from
mmcv.cnn.bricks.upsample
import
build_upsample_layer
from
mmgen.models.builder
import
MODULES
from
..common
import
get_module_device
from
.modules
import
ConvLNModule
,
WGANDecisionHead
,
WGANNoiseTo2DFeat
@MODULES.register_module()
class WGANGPGenerator(nn.Module):
    r"""Generator for WGANGP.

    Implementation Details for WGANGP generator the same as training
    configuration (a) described in PGGAN paper:
    PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION
    https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf  # noqa

    #. Adopt convolution architecture specified in appendix A.2;
    #. Use batchnorm in the generator except for the final output layer;
    #. Use ReLU in the generator except for the final output layer;
    #. Use Tanh in the last layer;
    #. Initialize all weights using He's initializer.

    Args:
        noise_size (int): Size of the input noise vector.
        out_scale (int): Output scale for the generated image.
        conv_module_cfg (dict, optional): Config for the convolution
            module used in this generator. Defaults to None.
        upsample_cfg (dict, optional): Config for the upsampling operation.
            Defaults to None.
    """

    # channel width for each output resolution
    _default_channels_per_scale = {
        '4': 512,
        '8': 512,
        '16': 256,
        '32': 128,
        '64': 64,
        '128': 32
    }

    _default_conv_module_cfg = dict(
        conv_cfg=None,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        act_cfg=dict(type='ReLU'),
        norm_cfg=dict(type='BN'),
        order=('conv', 'norm', 'act'))

    _default_upsample_cfg = dict(type='nearest', scale_factor=2)

    def __init__(self,
                 noise_size,
                 out_scale,
                 conv_module_cfg=None,
                 upsample_cfg=None):
        super().__init__()
        # set initial params
        self.noise_size = noise_size
        self.out_scale = out_scale
        self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
        if conv_module_cfg is not None:
            self.conv_module_cfg.update(conv_module_cfg)
        self.upsample_cfg = upsample_cfg if upsample_cfg else deepcopy(
            self._default_upsample_cfg)

        # set noise2feat head mapping the noise vector to a 4x4 feature
        base_channels = self._default_channels_per_scale['4']
        self.noise2feat = WGANNoiseTo2DFeat(self.noise_size, base_channels)
        # set conv_blocks; first conv works at the base 4x4 resolution
        # (was hard-coded to 512; now tied to the channel table)
        self.conv_blocks = nn.ModuleList()
        self.conv_blocks.append(
            ConvModule(base_channels, base_channels, **self.conv_module_cfg))
        log2scale = int(np.log2(self.out_scale))
        for i in range(3, log2scale + 1):
            # FIX: honor ``self.upsample_cfg`` — the original passed the
            # default config here, silently ignoring a user-provided one.
            self.conv_blocks.append(build_upsample_layer(self.upsample_cfg))
            self.conv_blocks.append(
                ConvModule(self._default_channels_per_scale[str(2**(i - 1))],
                           self._default_channels_per_scale[str(2**i)],
                           **self.conv_module_cfg))
            self.conv_blocks.append(
                ConvModule(self._default_channels_per_scale[str(2**i)],
                           self._default_channels_per_scale[str(2**i)],
                           **self.conv_module_cfg))
        # final 1x1 conv to RGB with Tanh activation
        self.to_rgb = ConvModule(
            self._default_channels_per_scale[str(self.out_scale)],
            kernel_size=1,
            out_channels=3,
            act_cfg=dict(type='Tanh'))

    def forward(self, noise, num_batches=0, return_noise=False):
        """Forward function.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size. Defaults
                to 0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.

        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output
                image will be returned. Otherwise, a dict contains
                ``fake_img`` and ``noise_batch`` will be returned.
        """
        # receive noise and conduct sanity check.
        if isinstance(noise, torch.Tensor):
            assert noise.shape[1] == self.noise_size
            assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
                                     f'but got {noise.shape}')
            noise_batch = noise
        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, self.noise_size))
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            noise_batch = torch.randn((num_batches, self.noise_size))

        # dirty code for putting data on the right device
        noise_batch = noise_batch.to(get_module_device(self))

        # noise vector to 2D feature
        x = self.noise2feat(noise_batch)
        for conv in self.conv_blocks:
            x = conv(x)
        out_img = self.to_rgb(x)

        if return_noise:
            output = dict(fake_img=out_img, noise_batch=noise_batch)
            return output
        return out_img
@MODULES.register_module()
class WGANGPDiscriminator(nn.Module):
    r"""Discriminator for WGANGP.

    Implementation Details for WGANGP discriminator the same as training
    configuration (a) described in PGGAN paper:
    PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION
    https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf  # noqa

    #. Adopt convolution architecture specified in appendix A.2;
    #. Add layer normalization to all conv3x3 and conv4x4 layers;
    #. Use LeakyReLU in the discriminator except for the final output layer;
    #. Initialize all weights using He's initializer.

    Args:
        in_channel (int): The channel number of the input image.
        in_scale (int): The scale of the input image.
        conv_module_cfg (dict, optional): Config for the convolution module
            used in this discriminator. Defaults to None.
    """

    # channel width for each input resolution
    _default_channels_per_scale = {
        '4': 512,
        '8': 512,
        '16': 256,
        '32': 128,
        '64': 64,
        '128': 32
    }

    _default_conv_module_cfg = dict(
        conv_cfg=None,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
        norm_cfg=dict(type='LN2d'),
        order=('conv', 'norm', 'act'))

    _default_upsample_cfg = dict(type='nearest', scale_factor=2)

    def __init__(self, in_channel, in_scale, conv_module_cfg=None):
        super().__init__()
        # set initial params
        self.in_channel = in_channel
        self.in_scale = in_scale
        self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
        if conv_module_cfg is not None:
            self.conv_module_cfg.update(conv_module_cfg)

        # set from_rgb head
        # FIX: respect ``in_channel`` — the input channel number was
        # hard-coded to 3, silently ignoring the constructor argument
        # (backward compatible for the common RGB case in_channel=3).
        self.from_rgb = ConvModule(
            self.in_channel,
            kernel_size=1,
            out_channels=self._default_channels_per_scale[str(self.in_scale)],
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2))

        # set conv_blocks: per resolution level, two LN convs followed by
        # 2x average-pool downsampling, from ``in_scale`` down to 8
        self.conv_blocks = nn.ModuleList()
        log2scale = int(np.log2(self.in_scale))
        for i in range(log2scale, 2, -1):
            self.conv_blocks.append(
                ConvLNModule(
                    self._default_channels_per_scale[str(2**i)],
                    self._default_channels_per_scale[str(2**i)],
                    feature_shape=(self._default_channels_per_scale[str(2**i)],
                                   2**i, 2**i),
                    **self.conv_module_cfg))
            self.conv_blocks.append(
                ConvLNModule(
                    self._default_channels_per_scale[str(2**i)],
                    self._default_channels_per_scale[str(2**(i - 1))],
                    feature_shape=(self._default_channels_per_scale[str(
                        2**(i - 1))], 2**i, 2**i),
                    **self.conv_module_cfg))
            self.conv_blocks.append(nn.AvgPool2d(kernel_size=2, stride=2))

        # decision head mapping the bottom 4x4 feature to a scalar score
        self.decision = WGANDecisionHead(
            self._default_channels_per_scale['4'],
            self._default_channels_per_scale['4'],
            1,
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
            norm_cfg=self.conv_module_cfg['norm_cfg'])

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Fake or real image tensor.

        Returns:
            torch.Tensor: Prediction for the reality of the input image.
        """
        x = self.from_rgb(x)
        for conv in self.conv_blocks:
            x = conv(x)
        x = self.decision(x)
        return x
build/lib/mmgen/models/architectures/wgan_gp/modules.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
import
torch
import
torch.nn
as
nn
from
mmcv.cnn
import
(
PLUGIN_LAYERS
,
ConvModule
,
build_activation_layer
,
build_norm_layer
,
constant_init
)
from
mmgen.models.builder
import
MODULES
@MODULES.register_module()
class WGANNoiseTo2DFeat(nn.Module):
    """Transform a 1D noise tensor of shape ``(N, C)`` into a 2D feature
    map of shape ``(N, C, 4, 4)``, as used in WGAN-GP.

    Args:
        noise_size (int): Size of the input noise vector.
        out_channels (int): The channel number of the output feature.
        act_cfg (dict, optional): Config for the activation layer. Defaults
            to dict(type='ReLU').
        norm_cfg (dict, optional): Config dict to build norm layer. Defaults
            to dict(type='BN').
        order (tuple, optional): The order in which the linear, activation
            and normalization steps are applied; must be a permutation of
            ('linear', 'act', 'norm'). Defaults to ('linear', 'act', 'norm').
    """

    def __init__(self,
                 noise_size,
                 out_channels,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN'),
                 order=('linear', 'act', 'norm')):
        super().__init__()
        self.noise_size = noise_size
        self.out_channels = out_channels
        self.with_activation = act_cfg is not None
        self.with_norm = norm_cfg is not None
        self.order = order
        assert len(order) == 3 and set(order) == {'linear', 'act', 'norm'}

        # w/o bias, because the bias is added after reshaping the tensor to
        # 2D feature
        self.linear = nn.Linear(noise_size, out_channels * 16, bias=False)

        if self.with_activation:
            self.activation = build_activation_layer(act_cfg)

        # add bias for reshaped 2D feature.
        self.register_parameter(
            'bias', nn.Parameter(torch.zeros(1, out_channels, 1, 1)))

        if self.with_norm:
            _, self.norm = build_norm_layer(norm_cfg, out_channels)

        self._init_weight()

    def forward(self, x):
        """Map a batch of noise vectors to 4x4 feature maps.

        Args:
            x (Tensor): Input noise tensor with shape (n, c).

        Returns:
            Tensor: Forward results with shape (n, c, 4, 4).
        """
        assert x.ndim == 2
        for step in self.order:
            if step == 'linear':
                x = self.linear(x)
                # reshape to [n, c, 4, 4], then add the per-channel bias
                x = torch.reshape(x, (-1, self.out_channels, 4, 4))
                x = x + self.bias
            elif step == 'act' and self.with_activation:
                x = self.activation(x)
            elif step == 'norm' and self.with_norm:
                x = self.norm(x)
        return x

    def _init_weight(self):
        """Initialize weights: N(0, 1) linear weight, zero bias, unit-scale
        normalization layer."""
        nn.init.normal_(self.linear.weight, 0., 1.)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.)
        if self.with_norm:
            constant_init(self.norm, 1, bias=0)
class WGANDecisionHead(nn.Module):
    """Module used in WGAN-GP to get the final prediction result with 4x4
    resolution input tensor in the bottom of the discriminator.

    Args:
        in_channels (int): Number of channels in input feature map.
        mid_channels (int): Number of channels in feature map after
            convolution.
        out_channels (int): The channel number of the final output layer.
        bias (bool, optional): Whether to use bias parameter. Defaults to
            True.
        act_cfg (dict, optional): Config for the activation layer. Defaults
            to dict(type='ReLU').
        out_act (dict, optional): Config for the activation layer of output
            layer. Defaults to None.
        norm_cfg (dict, optional): Config dict to build norm layer. Defaults
            to dict(type='LN2d').
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 bias=True,
                 act_cfg=dict(type='ReLU'),
                 out_act=None,
                 norm_cfg=dict(type='LN2d')):
        super().__init__()
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.out_channels = out_channels
        self.with_out_activation = out_act is not None

        # setup conv layer: the 4x4 kernel collapses the 4x4 spatial input
        # down to a 1x1 feature
        self.conv = ConvLNModule(
            in_channels,
            feature_shape=(mid_channels, 1, 1),
            kernel_size=4,
            out_channels=mid_channels,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            order=('conv', 'norm', 'act'))
        # setup linear layer producing the final prediction
        self.linear = nn.Linear(
            self.mid_channels, self.out_channels, bias=bias)

        if self.with_out_activation:
            self.out_activation = build_activation_layer(out_act)

        self._init_weight()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        feat = self.conv(x)
        flat = torch.reshape(feat, (feat.shape[0], -1))
        out = self.linear(flat)
        if self.with_out_activation:
            out = self.out_activation(out)
        return out

    def _init_weight(self):
        """Initialize weights for the model."""
        nn.init.normal_(self.linear.weight, 0., 1.)
        nn.init.constant_(self.linear.bias, 0.)
@PLUGIN_LAYERS.register_module()
class ConvLNModule(ConvModule):
    r"""ConvModule with Layer Normalization.

    In this module, we inherit default ``mmcv.cnn.ConvModule`` and deal with
    the situation that 'norm_cfg' is 'LN2d' or 'GN'. We adopt 'GN' as a
    replacement for layer normalization referring to:
    https://github.com/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Pytorch/blob/master/module.py  # noqa

    Args:
        feature_shape (tuple | None): The (C, H, W) shape of the feature
            map to be normalized. Required when ``norm_cfg['type']`` is
            'LN2d' (the full shape is used) or 'GN' (only
            ``feature_shape[0]`` is used); ignored otherwise. Defaults to
            None.
    """

    def __init__(self, *args, feature_shape=None, **kwargs):
        if 'norm_cfg' in kwargs and kwargs['norm_cfg'] is not None and kwargs[
                'norm_cfg']['type'] in ['LN2d', 'GN']:
            # build the conv itself without a norm layer, then attach our
            # own LayerNorm / GroupNorm module manually
            nkwargs = deepcopy(kwargs)
            nkwargs['norm_cfg'] = None
            super().__init__(*args, **nkwargs)
            self.with_norm = True
            self.norm_name = kwargs['norm_cfg']['type']
            # fail early with a clear message instead of a cryptic error
            # from nn.LayerNorm(None) / feature_shape[0]
            assert feature_shape is not None, (
                "'feature_shape' is required when norm_cfg type is 'LN2d' "
                "or 'GN'")
            if self.norm_name == 'LN2d':
                norm = nn.LayerNorm(feature_shape)
            else:
                # GroupNorm with a single group behaves as layer norm
                norm = nn.GroupNorm(1, feature_shape[0])
            self.add_module(self.norm_name, norm)
        else:
            super().__init__(*args, **kwargs)
build/lib/mmgen/models/builder.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch.nn
as
nn
from
mmcv.utils
import
Registry
,
build_from_cfg
MODELS
=
Registry
(
'model'
)
MODULES
=
Registry
(
'module'
)
def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules; it is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module. If ``cfg`` is a list, the built
        modules are wrapped in an ``nn.ModuleList``.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.ModuleList(modules)

    return build_from_cfg(cfg, registry, default_args)
def build_model(cfg, train_cfg=None, test_cfg=None):
    """Build model (GAN)."""
    default_args = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return build(cfg, MODELS, default_args)
def build_module(cfg, default_args=None):
    """Build a module or modules from a list."""
    return build(cfg, MODULES, default_args=default_args)
build/lib/mmgen/models/common/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.dist_utils
import
AllGatherLayer
from
.model_utils
import
GANImageBuffer
,
set_requires_grad
__all__
=
[
'set_requires_grad'
,
'AllGatherLayer'
,
'GANImageBuffer'
]
build/lib/mmgen/models/common/dist_utils.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.autograd
as
autograd
import
torch.distributed
as
dist
class AllGatherLayer(autograd.Function):
    """All gather layer with backward propagation path.

    Indeed, this module is to make ``dist.all_gather()`` in the backward
    graph. Such kind of operation has been widely used in Moco and other
    contrastive learning algorithms.
    """

    @staticmethod
    def forward(ctx, x):
        """Gather ``x`` from every rank.

        Returns:
            tuple: One tensor per rank, in rank order.
        """
        ctx.save_for_backward(x)
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Backward function.

        Only the gradient w.r.t. this rank's slice of the gathered output
        flows back to the local input.
        """
        # FIX: removed dead code — the original allocated
        # ``grad_out = torch.zeros_like(x)`` and immediately overwrote it
        # (and unpacked the saved tensor only for that dead allocation).
        return grad_outputs[dist.get_rank()]
build/lib/mmgen/models/common/model_utils.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
torch
def set_requires_grad(nets, requires_grad=False):
    """Enable or disable gradients for every parameter of the given nets.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network; ``None`` entries are silently skipped.
        requires_grad (bool): Whether the networks require gradients or not.
    """
    networks = nets if isinstance(nets, list) else [nets]
    for network in networks:
        if network is None:
            continue
        for parameter in network.parameters():
            parameter.requires_grad = requires_grad
class GANImageBuffer:
    """This class implements an image buffer that stores previously generated
    images.

    This buffer allows us to update the discriminator using a history of
    generated images rather than the ones produced by the latest generator
    to reduce model oscillation.

    Args:
        buffer_size (int): The size of image buffer. If buffer_size = 0,
            no buffer will be created.
        buffer_ratio (float): The chance / possibility to use the images
            previously stored in the buffer.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        # create an empty buffer
        if self.buffer_size > 0:
            self.img_num = 0
            self.image_buffer = []
        self.buffer_ratio = buffer_ratio

    def query(self, images):
        """Query current image batch using a history of generated images.

        Args:
            images (Tensor): Current image batch without history information.
        """
        # a zero-sized buffer is a no-op: pass the batch straight through
        if self.buffer_size == 0:
            return images

        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            # while the buffer is not full, always store and pass through
            if self.img_num < self.buffer_size:
                self.img_num = self.img_num + 1
                self.image_buffer.append(image)
                return_images.append(image)
                continue
            # buffer full: with probability ``buffer_ratio`` return a
            # previously stored image and replace it with the current one;
            # otherwise return the current image unchanged
            use_buffer = np.random.random() < self.buffer_ratio
            if use_buffer:
                random_id = np.random.randint(0, self.buffer_size)
                stored_image = self.image_buffer[random_id].clone()
                self.image_buffer[random_id] = image
                return_images.append(stored_image)
            else:
                return_images.append(image)
        # collect all the images and return
        return torch.cat(return_images, 0)
build/lib/mmgen/models/diffusions/__init__.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
from
.base_diffusion
import
BasicGaussianDiffusion
from
.sampler
import
UniformTimeStepSampler
__all__
=
[
'BasicGaussianDiffusion'
,
'UniformTimeStepSampler'
]
build/lib/mmgen/models/diffusions/base_diffusion.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
sys
from
abc
import
ABCMeta
from
collections
import
OrderedDict
,
defaultdict
from
copy
import
deepcopy
from
functools
import
partial
import
mmcv
import
numpy
as
np
import
torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
from
torch.nn.parallel.distributed
import
_find_tensors
from
..architectures.common
import
get_module_device
from
..builder
import
MODELS
,
build_module
from
.utils
import
_get_label_batch
,
_get_noise_batch
,
var_to_tensor
@
MODELS
.
register_module
()
class
BasicGaussianDiffusion
(
nn
.
Module
,
metaclass
=
ABCMeta
):
"""Basic module for gaussian Diffusion Denoising Probabilistic Models. A
diffusion probabilistic model (which we will call a 'diffusion model' for
brevity) is a parameterized Markov chain trained using variational
inference to produce samples matching the data after finite time.
The design of this module implements DDPM and improve-DDPM according to
"Denoising Diffusion Probabilistic Models" (2020) and "Improved Denoising
Diffusion Probabilistic Models" (2021).
Args:
denoising (dict): Config for denoising model.
ddpm_loss (dict): Config for losses of DDPM.
betas_cfg (dict): Config for betas in diffusion process.
num_timesteps (int, optional): The number of timesteps of the diffusion
process. Defaults to 1000.
num_classes (int | None, optional): The number of conditional classes.
Defaults to None.
sample_method (string, optional): Sample method for the denoising
process. Support 'DDPM' and 'DDIM'. Defaults to 'DDPM'.
timesteps_sampler (string, optional): How to sample timesteps in
training process. Defaults to `UniformTimeStepSampler`.
train_cfg (dict | None, optional): Config for training schedule.
Defaults to None.
test_cfg (dict | None, optional): Config for testing schedule. Defaults
to None.
"""
    def __init__(self,
                 denoising,
                 ddpm_loss,
                 betas_cfg,
                 num_timesteps=1000,
                 num_classes=0,
                 sample_method='DDPM',
                 timestep_sampler='UniformTimeStepSampler',
                 train_cfg=None,
                 test_cfg=None):
        """Initialize the diffusion model.

        Builds the denoising network, the timestep sampler and the DDPM
        loss modules from their configs, records the image shape read from
        the denoising network, parses train/test configs, and precomputes
        the diffusion variables via ``self.prepare_diffusion_vars()``.
        """
        super().__init__()
        self.fp16_enable = False
        # build denoising module in this function
        self.num_classes = num_classes
        self.num_timesteps = num_timesteps
        self.sample_method = sample_method
        # keep the raw config around: ``in_channels`` is read from it below
        self._denoising_cfg = deepcopy(denoising)
        self.denoising = build_module(
            denoising,
            default_args=dict(
                num_classes=num_classes, num_timesteps=num_timesteps))

        # get output-related configs from denoising
        self.denoising_var_mode = self.denoising.var_mode
        self.denoising_mean_mode = self.denoising.mean_mode

        # output_channels in denoising may be double, therefore we
        # get number of channels from config
        image_channels = self._denoising_cfg['in_channels']
        # image_size should be the attribute of denoising network
        image_size = self.denoising.image_size
        image_shape = torch.Size([image_channels, image_size, image_size])
        self.image_shape = image_shape
        # partially-bound helpers to sample noise / label batches with the
        # model's fixed image shape and timestep count
        self.get_noise = partial(
            _get_noise_batch,
            image_shape=image_shape,
            num_timesteps=self.num_timesteps)
        self.get_label = partial(
            _get_label_batch, num_timesteps=self.num_timesteps)

        # build sampler
        if timestep_sampler is not None:
            self.sampler = build_module(
                timestep_sampler,
                default_args=dict(num_timesteps=num_timesteps))
        else:
            self.sampler = None

        # build losses; normalize a single loss to a ModuleList so callers
        # can always iterate over ``self.ddpm_loss``
        if ddpm_loss is not None:
            self.ddpm_loss = build_module(
                ddpm_loss, default_args=dict(sampler=self.sampler))
            if not isinstance(self.ddpm_loss, nn.ModuleList):
                self.ddpm_loss = nn.ModuleList([self.ddpm_loss])
        else:
            self.ddpm_loss = None

        self.betas_cfg = deepcopy(betas_cfg)

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        # NOTE: _parse_train_cfg is always called (it tolerates None),
        # _parse_test_cfg only when a test_cfg is provided
        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

        # precompute betas/alphas etc.; defined elsewhere on this class
        self.prepare_diffusion_vars()
def
_parse_train_cfg
(
self
):
"""Parsing train config and set some attributes for training."""
if
self
.
train_cfg
is
None
:
self
.
train_cfg
=
dict
()
self
.
use_ema
=
self
.
train_cfg
.
get
(
'use_ema'
,
False
)
if
self
.
use_ema
:
self
.
denoising_ema
=
deepcopy
(
self
.
denoising
)
self
.
real_img_key
=
self
.
train_cfg
.
get
(
'real_img_key'
,
'real_img'
)
def
_parse_test_cfg
(
self
):
"""Parsing test config and set some attributes for testing."""
if
self
.
test_cfg
is
None
:
self
.
test_cfg
=
dict
()
# whether to use exponential moving average for testing
self
.
use_ema
=
self
.
test_cfg
.
get
(
'use_ema'
,
False
)
if
self
.
use_ema
:
self
.
denoising_ema
=
deepcopy
(
self
.
denoising
)
def
_get_loss
(
self
,
outputs_dict
):
losses_dict
=
{}
# forward losses
for
loss_fn
in
self
.
ddpm_loss
:
losses_dict
[
loss_fn
.
loss_name
()]
=
loss_fn
(
outputs_dict
)
loss
,
log_vars
=
self
.
_parse_losses
(
losses_dict
)
# update collected log_var from loss_fn
for
loss_fn
in
self
.
ddpm_loss
:
if
hasattr
(
loss_fn
,
'log_vars'
):
log_vars
.
update
(
loss_fn
.
log_vars
)
return
loss
,
log_vars
def
_parse_losses
(
self
,
losses
):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary information.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
\
which may be a weighted sum of all losses, log_vars contains
\
all the variables to be sent to the logger.
"""
log_vars
=
OrderedDict
()
for
loss_name
,
loss_value
in
losses
.
items
():
if
isinstance
(
loss_value
,
torch
.
Tensor
):
log_vars
[
loss_name
]
=
loss_value
.
mean
()
elif
isinstance
(
loss_value
,
list
):
log_vars
[
loss_name
]
=
sum
(
_loss
.
mean
()
for
_loss
in
loss_value
)
else
:
raise
TypeError
(
f
'
{
loss_name
}
is not a tensor or list of tensor'
)
loss
=
sum
(
_value
for
_key
,
_value
in
log_vars
.
items
()
if
'loss'
in
_key
)
log_vars
[
'loss'
]
=
loss
for
loss_name
,
loss_value
in
log_vars
.
items
():
if
dist
.
is_available
()
and
dist
.
is_initialized
():
loss_value
=
loss_value
.
data
.
clone
()
dist
.
all_reduce
(
loss_value
.
div_
(
dist
.
get_world_size
()))
log_vars
[
loss_name
]
=
loss_value
.
item
()
return
loss
,
log_vars
    def train_step(self,
                   data,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """The iteration step during training.

        This method defines an iteration step during training. Different from
        other repo in **MM** series, we allow the back propagation and
        optimizer updating to directly follow the iterative training schedule
        of DDPMs. Of course, we will show that you can also move the back
        propagation outside of this method, and then optimize the parameters
        in the optimizer hook. But this will cause extra GPU memory cost as a
        result of retaining computational graph. Otherwise, the training
        schedule should be modified in the detailed implementation.

        Args:
            data (dict): Input data batch; real images are fetched with
                ``self.real_img_key``.
            optimizer (dict): Dict contains optimizer for denoising network.
            ddp_reducer (Reducer | None, optional): DDP reducer used to
                prepare backward for a dynamic graph. Defaults to None.
            loss_scaler (GradScaler | None, optional): Gradient scaler for
                fp16 training. Defaults to None.
            use_apex_amp (bool, optional): Whether apex AMP is enabled.
                Defaults to False.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: ``log_vars``, ``num_samples`` and visualization
            ``results``.
        """
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty walkround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        real_imgs = data[self.real_img_key]
        # denoising training
        optimizer['denoising'].zero_grad()
        # timesteps are drawn by the configured sampler inside
        # reconstruction_step; return_noise=True yields all intermediates
        # that the losses need
        denoising_dict_ = self.reconstruction_step(
            data,
            timesteps=self.sampler,
            sample_model='orig',
            return_noise=True)
        denoising_dict_['iteration'] = curr_iter
        denoising_dict_['real_imgs'] = real_imgs
        denoising_dict_['loss_scaler'] = loss_scaler
        loss, log_vars = self._get_loss(denoising_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss))

        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss, optimizer['denoising'],
                    loss_id=0) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['denoising'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['denoising'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['denoising'].step()

        # image used for vislization
        results = dict(
            real_imgs=real_imgs,
            x_0_pred=denoising_dict_['x_0_pred'],
            x_t=denoising_dict_['diffusion_batches'],
            x_t_1=denoising_dict_['fake_img'])
        outputs = dict(
            log_vars=log_vars,
            num_samples=real_imgs.shape[0],
            results=results)

        # only advance the fallback counter when it is in use
        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
    def reconstruction_step(self,
                            data_batch,
                            noise=None,
                            label=None,
                            timesteps=None,
                            sample_model='orig',
                            return_noise=False,
                            **kwargs):
        """Reconstruction step at corresponding `timestep`. To be noted that,
        denoising target ``x_t`` for each timestep are all generated from real
        images, but not the denoising result from denoising network.

        ``sample_from_noise`` focus on generate samples start from **random
        (or given) noise**. Therefore, we design this function to realize a
        reconstruction process for the given images.

        If `timesteps` is None, automatically perform reconstruction at all
        timesteps.

        Args:
            data_batch (dict): Input data from dataloader.
            noise (torch.Tensor | callable | None): Noise used in diffusion
                process. You can directly give a batch of noise through a
                ``torch.Tensor`` or offer a callable function to sample a
                batch of noise data. Otherwise, the ``None`` indicates to use
                the default noise sampler. Defaults to None.
            label (torch.Tensor | None , optional): The conditional label of
                the input image. Defaults to None.
            timesteps (int | list | torch.Tensor | callable | None): Target
                timestep to perform reconstruction.
            sample_model (str, optional): Use which model to sample fake
                images. Defaults to `'orig'`.
            return_noise (bool, optional): If True, ``noise_batch``, ``label``
                and all other intermedia variables will be returned together
                with ``fake_img`` in a dict. Defaults to False.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with required
                data , including generated images, will be returned.
        """
        assert sample_model in ['orig', 'ema'], (
            'We only support \'orig\' and \'ema\' for '
            f'\'reconstruction_step\', but receive \'{sample_model}\'.')
        denoising_model = self.denoising if sample_model == 'orig' \
            else self.denoising_ema

        # 0. prepare for timestep, noise and label
        device = get_module_device(self)
        real_imgs = data_batch[self.real_img_key]
        num_batches = real_imgs.shape[0]

        if timesteps is None:
            # default to performing the whole reconstruction process:
            # one row per timestep, each repeated across the batch
            timesteps = torch.LongTensor([
                t for t in range(self.num_timesteps)
            ]).view(self.num_timesteps, 1)
            timesteps = timesteps.repeat([1, num_batches])
        if isinstance(timesteps, (int, list)):
            timesteps = torch.LongTensor(timesteps)
        elif callable(timesteps):
            timestep_generator = timesteps
            timesteps = timestep_generator(num_batches)
        else:
            assert isinstance(timesteps, torch.Tensor), (
                'we only support int list tensor or a callable function')
        if timesteps.ndim == 1:
            # normalize to 2D so the loop below iterates timestep rows
            timesteps = timesteps.unsqueeze(0)
        timesteps = timesteps.to(get_module_device(self))

        # noise may come from the data batch or the arguments, not both
        if noise is not None:
            assert 'noise' not in data_batch, (
                'Receive \'noise\' in both data_batch and passed arguments.')
        if noise is None:
            noise = data_batch['noise'] if 'noise' in data_batch else None

        if self.num_classes > 0:
            # same exclusivity rule for the conditional label
            if label is not None:
                assert 'label' not in data_batch, (
                    'Receive \'label\' in both data_batch '
                    'and passed arguments.')
            if label is None:
                label = data_batch['label'] if 'label' in data_batch else None
            label_batches = self.get_label(
                label, num_batches=num_batches).to(device)
        else:
            label_batches = None

        output_dict = defaultdict(list)
        # loop all timesteps
        for timestep in timesteps:
            # 1. get diffusion results and parameters
            noise_batches = self.get_noise(
                noise, num_batches=num_batches).to(device)

            diffusion_batches = self.q_sample(real_imgs, timestep,
                                              noise_batches)
            # 2. get denoising results.
            denoising_batches = self.denoising_step(
                denoising_model,
                diffusion_batches,
                timestep,
                label=label_batches,
                return_noise=return_noise,
                clip_denoised=not self.training)
            # 3. get ground truth by q_posterior
            target_batches = self.q_posterior_mean_variance(
                real_imgs, diffusion_batches, timestep, logvar=True)

            if return_noise:
                output_dict_ = dict(
                    timesteps=timestep,
                    noise=noise_batches,
                    diffusion_batches=diffusion_batches)
                if self.num_classes > 0:
                    output_dict_['label'] = label_batches
                output_dict_.update(denoising_batches)
                output_dict_.update(target_batches)
            else:
                output_dict_ = dict(fake_img=denoising_batches)

            # update output of `timestep` to output_dict
            for k, v in output_dict_.items():
                if k in output_dict:
                    output_dict[k].append(v)
                else:
                    output_dict[k] = [v]

        # 4. concentrate list to tensor
        for k, v in output_dict.items():
            output_dict[k] = torch.cat(v, dim=0)

        # 5. return results
        if return_noise:
            return output_dict
        return output_dict['fake_img']
def sample_from_noise(self,
                      noise,
                      num_batches=0,
                      sample_model='ema/orig',
                      label=None,
                      **kwargs):
    """Sample images from noises by using Denoising model.

    Args:
        noise (torch.Tensor | callable | None): You can directly give a
            batch of noise through a ``torch.Tensor`` or offer a callable
            function to sample a batch of noise data. Otherwise, the
            ``None`` indicates to use the default noise sampler.
        num_batches (int, optional): The number of batch size.
            Defaults to 0.
        sample_model (str, optional): The model to sample. If ``ema/orig``
            is passed, this method will try to sample from ema (if
            ``self.use_ema == True``) and orig model. Defaults to
            'ema/orig'.
        label (torch.Tensor | None , optional): The conditional label.
            Defaults to None.

    Returns:
        torch.Tensor | dict: The output may be the direct synthesized
            images in ``torch.Tensor``. Otherwise, a dict with queried
            data, including generated images, will be returned.
    """
    # resolve the concrete sampler by name, e.g. 'DDPM_sample' when
    # ``self.sample_method`` is 'ddpm'
    sample_fn_name = f'{self.sample_method.upper()}_sample'
    if not hasattr(self, sample_fn_name):
        raise AttributeError(
            f'Cannot find sample method [{sample_fn_name}] correspond '
            f'to [{self.sample_method}].')
    sample_fn = getattr(self, sample_fn_name)

    # pick the denoising network to sample from; 'ema/orig' prefers the
    # ema weights when available and falls back to the orig model
    if sample_model == 'ema':
        assert self.use_ema
        _model = self.denoising_ema
    elif sample_model == 'ema/orig' and self.use_ema:
        _model = self.denoising_ema
    else:
        _model = self.denoising

    outputs = sample_fn(
        _model, noise=noise, num_batches=num_batches, label=label, **kwargs)

    if isinstance(outputs, dict) and 'noise_batch' in outputs:
        # return_noise is True: capture the exact noise/label used so the
        # optional second (orig-model) pass below reuses the same inputs
        noise = outputs['x_t']
        label = outputs['label']
        kwargs['timesteps_noise'] = outputs['noise_batch']
        fake_img = outputs['fake_img']
    else:
        fake_img = outputs

    # in 'ema/orig' mode, additionally sample from the orig model and
    # concatenate the two results along the batch dimension
    if sample_model == 'ema/orig' and self.use_ema:
        _model = self.denoising
        outputs_ = sample_fn(
            _model, noise=noise, num_batches=num_batches, **kwargs)
        if isinstance(outputs_, dict) and 'noise_batch' in outputs_:
            # return_noise is True
            fake_img_ = outputs_['fake_img']
        else:
            fake_img_ = outputs_

        if isinstance(fake_img, dict):
            # save_intermedia is True: results are {timestep: tensor},
            # concatenate per timestep
            fake_img = {
                k: torch.cat([fake_img[k], fake_img_[k]], dim=0)
                for k in fake_img.keys()
            }
        else:
            fake_img = torch.cat([fake_img, fake_img_], dim=0)
    return fake_img
@torch.no_grad()
def DDPM_sample(self,
                model,
                noise=None,
                num_batches=0,
                label=None,
                save_intermedia=False,
                timesteps_noise=None,
                return_noise=False,
                show_pbar=False,
                **kwargs):
    """DDPM sample from random noise.

    Args:
        model (torch.nn.Module): Denoising model used to sample images.
        noise (torch.Tensor | callable | None): You can directly give a
            batch of noise through a ``torch.Tensor`` or offer a callable
            function to sample a batch of noise data. Otherwise, the
            ``None`` indicates to use the default noise sampler.
        num_batches (int, optional): The number of batch size.
            Defaults to 0.
        label (torch.Tensor | None , optional): The conditional label.
            Defaults to None.
        save_intermedia (bool, optional): Whether to save denoising result
            of intermedia timesteps. If set as True, will return a dict
            which key and value are denoising timestep and denoising
            result. Otherwise, only the final denoising result will be
            returned. Defaults to False.
        timesteps_noise (torch.Tensor, optional): Noise term used in each
            denoising timestep. If given, the input noise will be shaped to
            [num_timesteps, b, c, h, w]. If set as None, noise of each
            denoising timestep will be randomly sampled. Default as None.
        return_noise (bool, optional): If True, a dict contains
            ``noise_batch``, ``x_t`` and ``label`` will be returned
            together with the denoising results, and the key of denoising
            results is ``fake_img``. To be noted that ``noise_batches``
            will shape as [num_timesteps, b, c, h, w]. Defaults to False.
        show_pbar (bool, optional): If True, a progress bar will be
            displayed. Defaults to False.

    Returns:
        torch.Tensor | dict: If ``save_intermedia``, a dict contains
            denoising results of each timestep will be returned.
            Otherwise, only the final denoising result will be returned.
    """
    device = get_module_device(self)
    noise = self.get_noise(noise, num_batches=num_batches).to(device)
    # x_t starts as pure noise and is progressively denoised in-place
    # through the loop below; keep ``noise`` untouched for the return dict
    x_t = noise.clone()
    if save_intermedia:
        # save input
        intermedia = {self.num_timesteps: x_t.clone()}

    # use timesteps noise if defined: expanded to
    # [num_timesteps, b, c, h, w] so each step can index its own slice
    if timesteps_noise is not None:
        timesteps_noise = self.get_noise(
            timesteps_noise, num_batches=num_batches,
            timesteps_noise=True).to(device)

    # iterate t from (num_timesteps - 1) down to 0
    batched_timesteps = torch.arange(self.num_timesteps - 1, -1,
                                     -1).long().to(device)
    if show_pbar:
        pbar = mmcv.ProgressBar(self.num_timesteps)
    for t in batched_timesteps:
        # replicate the scalar timestep across the batch
        batched_t = t.expand(x_t.shape[0])
        step_noise = timesteps_noise[t, ...] \
            if timesteps_noise is not None else None
        x_t = self.denoising_step(
            model, x_t, batched_t, noise=step_noise, label=label, **kwargs)
        if save_intermedia:
            # store on cpu to avoid holding every timestep on the gpu
            intermedia[int(t)] = x_t.cpu().clone()
        if show_pbar:
            pbar.update()
    denoising_results = intermedia if save_intermedia else x_t
    if show_pbar:
        sys.stdout.write('\n')
    if return_noise:
        # NOTE: the 'x_t' key carries the *initial* noise the chain
        # started from, not the final denoised tensor
        return dict(
            noise_batch=timesteps_noise,
            x_t=noise,
            label=label,
            fake_img=denoising_results)
    return denoising_results
def prepare_diffusion_vars(self):
    """Prepare the constant arrays used in the diffusion process.

    All quantities are derived from the beta schedule returned by
    ``self.get_betas()`` and stored on ``self`` as ``np.ndarray``
    attributes, later indexed per-timestep via ``var_to_tensor``.
    """
    self.betas = self.get_betas()
    self.alphas = 1.0 - self.betas
    # fix: ``np.cumproduct`` is a deprecated alias removed in NumPy 2.0;
    # ``np.cumprod`` is the canonical, identical spelling.
    self.alphas_bar = np.cumprod(self.alphas, axis=0)
    self.alphas_bar_prev = np.append(1.0, self.alphas_bar[:-1])
    self.alphas_bar_next = np.append(self.alphas_bar[1:], 0.0)

    # calculations for diffusion q(x_t | x_0) and others
    self.sqrt_alphas_bar = np.sqrt(self.alphas_bar)
    self.sqrt_one_minus_alphas_bar = np.sqrt(1.0 - self.alphas_bar)
    self.log_one_minus_alphas_bar = np.log(1.0 - self.alphas_bar)
    # NOTE: the attribute name keeps the historical 'alplas' typo because
    # other methods (e.g. ``pred_x_0_from_eps``) reference it by this name.
    self.sqrt_recip_alplas_bar = np.sqrt(1.0 / self.alphas_bar)
    self.sqrt_recipm1_alphas_bar = np.sqrt(1.0 / self.alphas_bar - 1)

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    self.tilde_betas_t = self.betas * (1 - self.alphas_bar_prev) / (
        1 - self.alphas_bar)
    # clip log var for tilde_betas_0 = 0 (log(0) would be -inf)
    self.log_tilde_betas_t_clipped = np.log(
        np.append(self.tilde_betas_t[1], self.tilde_betas_t[1:]))
    self.tilde_mu_t_coef1 = np.sqrt(
        self.alphas_bar_prev) / (1 - self.alphas_bar) * self.betas
    self.tilde_mu_t_coef2 = np.sqrt(
        self.alphas) * (1 - self.alphas_bar_prev) / (1 - self.alphas_bar)
def get_betas(self):
    """Get betas by the schedule method defined in ``self.betas_cfg``.

    NOTE: this pops the ``'type'`` key out of ``self.betas_cfg`` (the
    config dict is mutated) and caches the schedule name as
    ``self.betas_schedule``.

    Returns:
        np.ndarray: Betas used in the diffusion process.

    Raises:
        AttributeError: If the schedule type is unknown.
    """
    self.betas_schedule = self.betas_cfg.pop('type')
    if self.betas_schedule == 'linear':
        return self.linear_beta_schedule(self.num_timesteps,
                                         **self.betas_cfg)
    elif self.betas_schedule == 'cosine':
        return self.cosine_beta_schedule(self.num_timesteps,
                                         **self.betas_cfg)
    else:
        # fix: the error path previously read the non-existent attribute
        # ``self.beta_schedule`` (the attribute set above is
        # ``betas_schedule``), so the intended message was never shown.
        raise AttributeError(f'Unknown method name {self.betas_schedule} '
                             'for beta schedule.')
@staticmethod
def linear_beta_schedule(diffusion_timesteps, beta_0=1e-4, beta_T=2e-2):
    r"""Linear beta schedule from Ho et al., rescaled so it works for an
    arbitrary number of diffusion steps.

    Args:
        diffusion_timesteps (int): The number of betas to produce.
        beta_0 (float, optional): `\beta` at timestep 0. Defaults to 1e-4.
        beta_T (float, optional): `\beta` at the final diffusion timestep
            `T`. Defaults to 2e-2.

    Returns:
        np.ndarray: Betas used in diffusion process.
    """
    # the reference endpoints were tuned for 1000 steps; rescale them
    # proportionally for other step counts
    rescale = 1000 / diffusion_timesteps
    return np.linspace(
        rescale * beta_0,
        rescale * beta_T,
        diffusion_timesteps,
        dtype=np.float64)
@staticmethod
def cosine_beta_schedule(diffusion_timesteps, max_beta=0.999, s=0.008):
    r"""Cosine beta schedule.

    Discretizes a cosine-shaped alpha_t_bar function, which defines the
    cumulative product of `(1-\beta)` over time from `t = [0, 1]`.

    Args:
        diffusion_timesteps (int): The number of betas to produce.
        max_beta (float, optional): The maximum beta to use; use values
            lower than 1 to prevent singularities. Defaults to 0.999.
        s (float, optional): Small offset to prevent `\beta` from being too
            small near `t = 0` Defaults to 0.008.

    Returns:
        np.ndarray: Betas used in diffusion process.
    """

    def alpha_bar(step):
        # \bar{\alpha}(t) = cos^2(((t/T + s) / (1 + s)) * pi / 2)
        return np.cos(
            (step / diffusion_timesteps + s) / (1 + s) * np.pi / 2)**2

    # beta_t = 1 - alpha_bar(t+1) / alpha_bar(t), capped at max_beta
    return np.array([
        min(1 - alpha_bar(t + 1) / alpha_bar(t), max_beta)
        for t in range(diffusion_timesteps)
    ])
def q_sample(self, x_0, t, noise=None):
    r"""Diffuse ``x_0`` to timestep `t` through `q(x_t | x_0)`.

    Uses the reparameterization
    `x_t = \sqrt{\bar\alpha_t} x_0 + \sqrt{1 - \bar\alpha_t} \epsilon`.

    Args:
        x_0 (torch.Tensor): Original image without diffusion.
        t (torch.Tensor): Target diffusion timestep.
        noise (torch.Tensor, optional): Noise used in reparameteration
            trick. Default to None.

    Returns:
        torch.Tensor: Diffused image `x_t`.
    """
    device = get_module_device(self)
    tar_shape = x_0.shape
    noise = self.get_noise(noise, num_batches=tar_shape[0])
    coef_mean = var_to_tensor(self.sqrt_alphas_bar, t, tar_shape, device)
    coef_std = var_to_tensor(self.sqrt_one_minus_alphas_bar, t, tar_shape,
                             device)
    return x_0 * coef_mean + noise * coef_std
def q_mean_log_variance(self, x_0, t):
    r"""Compute mean and log variance of the diffusion process
    `q(x_t | x_0)`.

    Args:
        x_0 (torch.Tensor): The original image before diffusion, shape as
            [bz, ch, H, W].
        t (torch.Tensor): Target timestep, shape as [bz, ].

    Returns:
        Tuple(torch.Tensor): Tuple contains mean and log variance.
    """
    device = get_module_device(self)
    shape = x_0.shape
    mu = x_0 * var_to_tensor(self.sqrt_alphas_bar, t, shape, device)
    log_var = var_to_tensor(self.log_one_minus_alphas_bar, t, shape, device)
    return mu, log_var
def q_posterior_mean_variance(self,
                              x_0,
                              x_t,
                              t,
                              need_var=True,
                              logvar=False):
    r"""Compute mean and variance of the diffusion posterior
    `q(x_{t-1} | x_t, x_0)`.

    Args:
        x_0 (torch.Tensor): The original image before diffusion, shape as
            [bz, ch, H, W].
        x_t (torch.Tensor): Diffused image at timestep ``t``.
        t (torch.Tensor): Target timestep, shape as [bz, ].
        need_var (bool, optional): If ``True``, a dict containing the
            posterior variance is returned. Otherwise only the mean is
            returned and ``logvar`` is ignored. Defaults to True.
        logvar (bool, optional): If ``True`` (and ``need_var`` is True),
            the returned dict additionally contains the posterior log
            variance. Defaults to False.

    Returns:
        torch.Tensor | dict: The posterior mean, or a dict with keys
            ``mean_posterior``, ``var_posterior`` and (optionally)
            ``logvar_posterior``.
    """
    device = get_module_device(self)
    shape = x_0.shape
    coef1 = var_to_tensor(self.tilde_mu_t_coef1, t, shape, device)
    coef2 = var_to_tensor(self.tilde_mu_t_coef2, t, shape, device)
    mean = coef1 * x_0 + coef2 * x_t

    # early exit when the caller only needs the mean
    if not need_var:
        return mean

    out = dict(
        mean_posterior=mean,
        var_posterior=var_to_tensor(self.tilde_betas_t, t, shape, device))
    if logvar:
        out['logvar_posterior'] = var_to_tensor(
            self.log_tilde_betas_t_clipped, t, shape, device)
    return out
def p_mean_variance(self,
                    denoising_output,
                    x_t,
                    t,
                    clip_denoised=True,
                    denoised_fn=None):
    r"""Get mean, variance, log variance of denoising process
    `p(x_{t-1} | x_{t})` and predicted `x_0`.

    Args:
        denoising_output (dict[torch.Tensor]): The output from denoising
            model.
        x_t (torch.Tensor): Diffused image at timestep `t` to denoising.
        t (torch.Tensor): Current timestep.
        clip_denoised (bool, optional): Whether cliped sample results into
            [-1, 1]. Defaults to True.
        denoised_fn (callable, optional): If not None, a function which
            applies to the predicted ``x_0`` before it is passed to the
            following sampling procedure. Noted that this function will be
            applies before ``clip_denoised``. Defaults to None.

    Returns:
        dict: A dict contains ``var_pred``, ``logvar_pred``, ``mean_pred``
            and ``x_0_pred``, excluding any key already present in
            ``denoising_output``.
    """
    target_shape = x_t.shape
    device = get_module_device(self)
    # prepare for var and logvar -- four variance parameterizations are
    # supported, selected by ``self.denoising_var_mode``
    if self.denoising_var_mode.upper() == 'LEARNED':
        # NOTE: the output actually LEARNED_LOG_VAR
        logvar_pred = denoising_output['logvar']
        varpred = torch.exp(logvar_pred)
    elif self.denoising_var_mode.upper() == 'LEARNED_RANGE':
        # NOTE: the output actually LEARNED_FACTOR; the predicted factor
        # interpolates (in log space) between the posterior lower bound
        # and the beta upper bound
        var_factor = denoising_output['factor']
        lower_bound_logvar = var_to_tensor(
            self.log_tilde_betas_t_clipped, t, target_shape, device)
        upper_bound_logvar = var_to_tensor(
            np.log(self.betas), t, target_shape, device)
        logvar_pred = var_factor * upper_bound_logvar + (
            1 - var_factor) * lower_bound_logvar
        varpred = torch.exp(logvar_pred)
    elif self.denoising_var_mode.upper() == 'FIXED_LARGE':
        # use betas as var (prepended with tilde_betas_t[1] for t == 0)
        varpred = var_to_tensor(
            np.append(self.tilde_betas_t[1], self.betas), t, target_shape,
            device)
        logvar_pred = torch.log(varpred)
    elif self.denoising_var_mode.upper() == 'FIXED_SMALL':
        # use posterior (tilde_betas) as var
        varpred = var_to_tensor(self.tilde_betas_t, t, target_shape,
                                device)
        logvar_pred = var_to_tensor(self.log_tilde_betas_t_clipped, t,
                                    target_shape, device)
    else:
        raise AttributeError('Unknown denoising var output type '
                             f'[{self.denoising_var_mode}].')

    def process_x_0(x):
        # optional user hook runs before clipping
        if denoised_fn is not None and callable(denoised_fn):
            x = denoised_fn(x)
        return x.clamp(-1, 1) if clip_denoised else x

    # prepare for mean and x_0
    if self.denoising_mean_mode.upper() == 'EPS':
        eps_pred = denoising_output['eps_t_pred']
        # We can get x_{t-1} with eps in two following approaches:
        # 1. eps --(Eq 15)--> \hat{x_0} --(Eq 7)--> \tilde_mu --> x_{t-1}
        # 2. eps --(Eq 11)--> \mu_{\theta} --(Eq 7)--> x_{t-1}
        # We can verify \tilde_mu in method 1 and \mu_{\theta} in method 2
        # are almost same (error of 1e-4) with the same eps input.
        # In our implementation, we use method (1) to consistent with
        # the official ones.
        # If you want to calculate \mu_{\theta} with method 2, you can
        # use the following code:
        # coef1 = var_to_tensor(
        #     np.sqrt(1.0 / self.alphas), t, tar_shape)
        # coef2 = var_to_tensor(
        #     self.betas / self.sqrt_one_minus_alphas_bar, t, tar_shape)
        # mu_theta = coef1 * (x_t - coef2 * eps)
        x_0_pred = process_x_0(self.pred_x_0_from_eps(eps_pred, x_t, t))
        mean_pred = self.q_posterior_mean_variance(
            x_0_pred, x_t, t, need_var=False)
    elif self.denoising_mean_mode.upper() == 'START_X':
        # the model directly predicts x_0
        x_0_pred = process_x_0(denoising_output['x_0_pred'])
        mean_pred = self.q_posterior_mean_variance(
            x_0_pred, x_t, t, need_var=False)
    elif self.denoising_mean_mode.upper() == 'PREVIOUS_X':
        # NOTE: the output actually PREVIOUS_X_MEAN (MU_THETA)
        # because this actually predict \mu_{\theta}
        mean_pred = denoising_output['x_tm1_pred']
        x_0_pred = process_x_0(self.pred_x_0_from_x_tm1(mean_pred, x_t, t))
    else:
        raise AttributeError('Unknown denoising mean output type '
                             f'[{self.denoising_mean_mode}].')

    output_dict = dict(
        var_pred=varpred,
        logvar_pred=logvar_pred,
        mean_pred=mean_pred,
        x_0_pred=x_0_pred)
    # avoid return duplicate variables
    return {
        k: output_dict[k]
        for k in output_dict.keys() if k not in denoising_output
    }
def denoising_step(self,
                   model,
                   x_t,
                   t,
                   noise=None,
                   label=None,
                   clip_denoised=True,
                   denoised_fn=None,
                   model_kwargs=None,
                   return_noise=False):
    """Single denoising step. Get `x_{t-1}` from ``x_t`` and ``t``.

    Args:
        model (torch.nn.Module): Denoising model used to sample images.
        x_t (torch.Tensor): Input diffused image.
        t (torch.Tensor): Current timestep.
        noise (torch.Tensor | callable | None): Noise for
            reparameterization trick. You can directly give a batch of
            noise through a ``torch.Tensor`` or offer a callable function
            to sample a batch of noise data. Otherwise, the ``None``
            indicates to use the default noise sampler.
        label (torch.Tensor | callable | None): You can directly give a
            batch of label through a ``torch.Tensor`` or offer a callable
            function to sample a batch of label data. Otherwise, the
            ``None`` indicates to use the default label sampler.
        clip_denoised (bool, optional): Whether to clip sample results into
            [-1, 1]. Defaults to True.
        denoised_fn (callable, optional): If not None, a function which
            applies to the predicted ``x_0`` prediction before it is used
            to sample. Applies before ``clip_denoised``. Defaults to None.
        model_kwargs (dict, optional): Arguments passed to denoising model.
            Defaults to None.
        return_noise (bool, optional): If True, ``noise_batch``, outputs
            from denoising model and ``p_mean_variance`` will be returned
            in a dict with ``fake_img``. Defaults to False.

    Return:
        torch.Tensor | dict: If not ``return_noise``, only the denoising
            image will be returned. Otherwise, the dict contains
            ``fake_image``, ``noise_batch`` and outputs from denoising
            model and ``p_mean_variance`` will be returned.
    """
    # init model_kwargs as dict if not passed
    if model_kwargs is None:
        model_kwargs = dict()
    model_kwargs.update(dict(return_noise=return_noise))

    # run the denoising network, then convert its raw output into the
    # mean/variance of p(x_{t-1} | x_t)
    denoising_output = model(x_t, t, label=label, **model_kwargs)
    p_output = self.p_mean_variance(denoising_output, x_t, t, clip_denoised,
                                    denoised_fn)
    mean_pred = p_output['mean_pred']
    var_pred = p_output['var_pred']

    num_batches = x_t.shape[0]
    device = get_module_device(self)
    # get noise for reparameterization
    noise = self.get_noise(noise, num_batches=num_batches).to(device)
    # zero the noise term for samples at t == 0, so the final step
    # deterministically returns the predicted mean
    nonzero_mask = ((t != 0).float().view(-1,
                                          *([1] * (len(x_t.shape) - 1))))
    # Here we directly use var_pred instead logvar_pred,
    # only error of 1e-12.
    # logvar_pred = p_output['logvar_pred']
    # sample = mean_pred + \
    #     nonzero_mask * torch.exp(0.5 * logvar_pred) * noise
    sample = mean_pred + nonzero_mask * torch.sqrt(var_pred) * noise
    if return_noise:
        return dict(
            fake_img=sample,
            noise_repar=noise,
            **denoising_output,
            **p_output)
    return sample
def pred_x_0_from_eps(self, eps, x_t, t):
    r"""Predict ``x_0`` from ``eps`` by Equ 15 in DDPM paper:

    .. math::
        x_0 = \frac{(x_t - \sqrt{(1-\bar{\alpha}_t)} * eps)}
        {\sqrt{\bar{\alpha}_t}}

    Args:
        eps (torch.Tensor): Predicted noise term.
        x_t (torch.Tensor): Diffused image at timestep ``t``.
        t (torch.Tensor): Current timestep.

    Returns:
        torch.Tensor: Predicted ``x_0``.
    """
    device = get_module_device(self)
    shape = x_t.shape
    recip_coef = var_to_tensor(self.sqrt_recip_alplas_bar, t, shape, device)
    recipm1_coef = var_to_tensor(self.sqrt_recipm1_alphas_bar, t, shape,
                                 device)
    return x_t * recip_coef - eps * recipm1_coef
def pred_x_0_from_x_tm1(self, x_tm1, x_t, t):
    r"""Predict `x_0` from `x_{t-1}` (actually from `\mu_{\theta}`).

    Computes `(\mu_{\theta} - coef2 * x_t) / coef1`, where `coef1` and
    `coef2` are from Eq 6 of the DDPM paper.

    NOTE: This function actually predict ``x_0`` from ``mu_theta`` (mean
    of ``x_{t-1}``).

    Args:
        x_tm1 (torch.Tensor): `x_{t-1}` used to predict `x_0`.
        x_t (torch.Tensor): `x_{t}` used to predict `x_0`.
        t (torch.Tensor): Current timestep.

    Returns:
        torch.Tensor: Predicted `x_0`.
    """
    device = get_module_device(self)
    shape = x_t.shape
    mu_coef = var_to_tensor(self.tilde_mu_t_coef1, t, shape, device)
    xt_coef = var_to_tensor(self.tilde_mu_t_coef2, t, shape, device)
    return (x_tm1 - xt_coef * x_t) / mu_coef
def forward_train(self, data, **kwargs):
    """Deprecated forward function in training.

    Args:
        data (torch.Tensor | dict | None): Unused input data.

    Raises:
        NotImplementedError: Always; use ``train_step`` for training.
    """
    # fix: the implicitly-concatenated string literals were missing a
    # separating space, producing '...users to callthis function...'.
    raise NotImplementedError(
        'In MMGeneration, we do NOT recommend users to call '
        'this function, because the train_step function is designed for '
        'the training process.')
def forward_test(self, data, **kwargs):
    """Testing function for Diffusion Denosing Probability Models.

    Args:
        data (torch.Tensor | dict | None): Input data. This data will be
            passed to different methods.

    Returns:
        torch.Tensor | dict: Output of the selected testing mode.

    Raises:
        NotImplementedError: If ``mode`` is neither ``'sampling'`` nor
            ``'reconstruction'``.
    """
    mode = kwargs.pop('mode', 'sampling')
    if mode == 'sampling':
        return self.sample_from_noise(data, **kwargs)
    if mode == 'reconstruction':
        # this mode is designed for evaluating likelihood metrics
        return self.reconstruction_step(data, **kwargs)
    raise NotImplementedError('Other specific testing functions should'
                              ' be implemented by the sub-classes.')
def forward(self, data, return_loss=False, **kwargs):
    """Forward function.

    Dispatches to :meth:`forward_train` when ``return_loss`` is True and
    to :meth:`forward_test` otherwise.

    Args:
        data (dict | torch.Tensor): Input data dictionary.
        return_loss (bool, optional): Whether in training or testing.
            Defaults to False.

    Returns:
        dict: Output dictionary.
    """
    handler = self.forward_train if return_loss else self.forward_test
    return handler(data, **kwargs)
build/lib/mmgen/models/diffusions/sampler.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
import
torch
from
..builder
import
MODULES
@MODULES.register_module()
class UniformTimeStepSampler:
    """Timestep sampler for DDPM-based models.

    Every timestep is sampled with the same probability.

    Args:
        num_timesteps (int): Total timesteps of the diffusion process.
    """

    def __init__(self, num_timesteps):
        self.num_timesteps = num_timesteps
        # uniform probability mass over all timesteps
        self.prob = [1 / num_timesteps] * num_timesteps

    def sample(self, batch_size):
        """Sample timesteps.

        Args:
            batch_size (int): The desired batch size of the sampled
                timesteps.

        Returns:
            torch.Tensor: Sampled timesteps.
        """
        # use numpy to make sure our implementation is consistent with the
        # official ones.
        indices = np.random.choice(
            self.num_timesteps, size=(batch_size, ), p=self.prob)
        return torch.from_numpy(indices).long()

    def __call__(self, batch_size):
        """Return sampled results."""
        return self.sample(batch_size)
build/lib/mmgen/models/diffusions/utils.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
def
_get_noise_batch
(
noise
,
image_shape
,
num_timesteps
=
0
,
num_batches
=
0
,
timesteps_noise
=
False
):
"""Get noise batch. Support get sequeue of noise along timesteps.
We support the following use cases ('bz' denotes ```num_batches`` and 'n'
denotes ``num_timesteps``):
If timesteps_noise is True, we output noise which dimension is 5.
- Input is [bz, c, h, w]: Expand to [n, bz, c, h, w]
- Input is [n, c, h, w]: Expand to [n, bz, c, h, w]
- Input is [n*bz, c, h, w]: View to [n, bz, c, h, w]
- Dim of the input is 5: Return the input, ignore ``num_batches`` and
``num_timesteps``
- Callable or None: Generate noise shape as [n, bz, c, h, w]
- Otherwise: Raise error
If timestep_noise is False, we output noise which dimension is 4 and
ignore ``num_timesteps``.
- Dim of the input is 3: Unsqueeze to [1, c, h, w], ignore ``num_batches``
- Dim of the input is 4: Return input, ignore ``num_batches``
- Callable or None: Generate noise shape as [bz, c, h, w]
- Otherwise: Raise error
It's to be noted that, we do not move the generated label to target device
in this function because we can not get which device the noise should move
to.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
image_shape (torch.Size): Size of images in the diffusion process.
num_timesteps (int, optional): Total timestpes of the diffusion and
denoising process. Defaults to 0.
num_batches (int, optional): The number of batch size. To be noted that
this argument only work when the input ``noise`` is callable or
``None``. Defaults to 0.
timesteps_noise (bool, optional): If True, returned noise will shape
as [n, bz, c, h, w], otherwise shape as [bz, c, h, w].
Defaults to False.
device (str, optional): If not ``None``, move the generated noise to
corresponding device.
Returns:
torch.Tensor: Generated noise with desired shape.
"""
if
isinstance
(
noise
,
torch
.
Tensor
):
# conduct sanity check for the last three dimension
assert
noise
.
shape
[
-
3
:]
==
image_shape
if
timesteps_noise
:
if
noise
.
ndim
==
4
:
assert
num_batches
>
0
and
num_timesteps
>
0
# noise shape as [n, c, h, w], expand to [n, bz, c, h, w]
if
noise
.
shape
[
0
]
==
num_timesteps
:
noise_batch
=
noise
.
view
(
num_timesteps
,
1
,
*
image_shape
)
noise_batch
=
noise_batch
.
expand
(
-
1
,
num_batches
,
-
1
,
-
1
,
-
1
)
# noise shape as [bz, c, h, w], expand to [n, bz, c, h, w]
elif
noise
.
shape
[
0
]
==
num_batches
:
noise_batch
=
noise
.
view
(
1
,
num_batches
,
*
image_shape
)
noise_batch
=
noise_batch
.
expand
(
num_timesteps
,
-
1
,
-
1
,
-
1
,
-
1
)
# noise shape as [n*bz, c, h, w], reshape to [b, bz, c, h, w]
elif
noise
.
shape
[
0
]
==
num_timesteps
*
num_batches
:
noise_batch
=
noise
.
view
(
num_timesteps
,
-
1
,
*
image_shape
)
else
:
raise
ValueError
(
'The timesteps noise should be in shape of '
'(n, c, h, w), (bz, c, h, w), (n*bz, c, h, w) or '
f
'(n, bz, c, h, w). But receive
{
noise
.
shape
}
.'
)
elif
noise
.
ndim
==
5
:
# direct return noise
noise_batch
=
noise
else
:
raise
ValueError
(
'The timesteps noise should be in shape of '
'(n, c, h, w), (bz, c, h, w), (n*bz, c, h, w) or '
f
'(n, bz, c, h, w). But receive
{
noise
.
shape
}
.'
)
else
:
if
noise
.
ndim
==
3
:
# reshape noise to [1, c, h, w]
noise_batch
=
noise
[
None
,
...]
elif
noise
.
ndim
==
4
:
# do nothing
noise_batch
=
noise
else
:
raise
ValueError
(
'The noise should be in shape of (n, c, h, w) or'
f
'(c, h, w), but got
{
noise
.
shape
}
'
)
# receive a noise generator and sample noise.
elif
callable
(
noise
):
assert
num_batches
>
0
noise_generator
=
noise
if
timesteps_noise
:
assert
num_timesteps
>
0
# generate noise shape as [n, bz, c, h, w]
noise_batch
=
noise_generator
(
(
num_timesteps
,
num_batches
,
*
image_shape
))
else
:
# generate noise shape as [bz, c, h, w]
noise_batch
=
noise_generator
((
num_batches
,
*
image_shape
))
# otherwise, we will adopt default noise sampler.
else
:
assert
num_batches
>
0
if
timesteps_noise
:
assert
num_timesteps
>
0
# generate noise shape as [n, bz, c, h, w]
noise_batch
=
torch
.
randn
(
(
num_timesteps
,
num_batches
,
*
image_shape
))
else
:
# generate noise shape as [bz, c, h, w]
noise_batch
=
torch
.
randn
((
num_batches
,
*
image_shape
))
return
noise_batch
def
_get_label_batch
(
label
,
num_timesteps
=
0
,
num_classes
=
0
,
num_batches
=
0
,
timesteps_noise
=
False
):
"""Get label batch. Support get sequeue of label along timesteps.
We support the following use cases ('bz' denotes ```num_batches`` and 'n'
denotes ``num_timesteps``):
If num_classes <= 0, return None.
If timesteps_noise is True, we output label which dimension is 2.
- Input is [bz, ]: Expand to [n, bz]
- Input is [n, ]: Expand to [n, bz]
- Input is [n*bz, ]: View to [n, bz]
- Dim of the input is 2: Return the input, ignore ``num_batches`` and
``num_timesteps``
- Callable or None: Generate label shape as [n, bz]
- Otherwise: Raise error
If timesteps_noise is False, we output label which dimension is 1 and
ignore ``num_timesteps``.
- Dim of the input is 1: Unsqueeze to [1, ], ignore ``num_batches``
- Dim of the input is 2: Return the input. ignore ``num_batches``
- Callable or None: Generate label shape as [bz, ]
- Otherwise: Raise error
It's to be noted that, we do not move the generated label to target device
in this function because we can not get which device the noise should move
to.
Args:
label (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_timesteps (int, optional): Total timestpes of the diffusion and
denoising process. Defaults to 0.
num_batches (int, optional): The number of batch size. To be noted that
this argument only work when the input ``noise`` is callable or
``None``. Defaults to 0.
timesteps_noise (bool, optional): If True, returned noise will shape
as [n, bz, c, h, w], otherwise shape as [bz, c, h, w].
Defaults to False.
Returns:
torch.Tensor: Generated label with desired shape.
"""
# no labels output if num_classes is 0
if
num_classes
==
0
:
assert
label
is
None
,
(
'
\'
label
\'
should be None '
'if
\'
num_classes == 0
\'
.'
)
return
None
# receive label and conduct sanity check.
if
isinstance
(
label
,
torch
.
Tensor
):
if
timesteps_noise
:
if
label
.
ndim
==
1
:
assert
num_batches
>
0
and
num_timesteps
>
0
# [n, ] to [n, bz]
if
label
.
shape
[
0
]
==
num_timesteps
:
label_batch
=
label
.
view
(
num_timesteps
,
1
)
label_batch
=
label_batch
.
expand
(
-
1
,
num_batches
)
# [bz, ] to [n, bz]
elif
label
.
shape
[
0
]
==
num_batches
:
label_batch
=
label
.
view
(
1
,
num_batches
)
label_batch
=
label_batch
.
expand
(
num_timesteps
,
-
1
)
# [n*bz, ] to [n, bz]
elif
label
.
shape
[
0
]
==
num_timesteps
*
num_batches
:
label_batch
=
label
.
view
(
num_timesteps
,
-
1
)
else
:
raise
ValueError
(
'The timesteps label should be in shape of '
'(n, ), (bz,), (n*bz, ) or (n, bz, ). But receive '
f
'
{
label
.
shape
}
.'
)
elif
label
.
ndim
==
2
:
# dimension is 2, direct return
label_batch
=
label
else
:
raise
ValueError
(
'The timesteps label should be in shape of '
'(n, ), (bz,), (n*bz, ) or (n, bz, ). But receive '
f
'
{
label
.
shape
}
.'
)
else
:
# dimension is 0, expand to [1, ]
if
label
.
ndim
==
0
:
label_batch
=
label
[
None
,
...]
# dimension is 1, do nothing
elif
label
.
ndim
==
1
:
label_batch
=
label
else
:
raise
ValueError
(
'The label should be in shape of (bz, ) or'
f
'zero-dimension tensor, but got
{
label
.
shape
}
'
)
# receive a noise generator and sample noise.
elif
callable
(
label
):
assert
num_batches
>
0
label_generator
=
label
if
timesteps_noise
:
assert
num_timesteps
>
0
# generate label shape as [n, bz]
label_batch
=
label_generator
((
num_timesteps
,
num_batches
))
else
:
# generate label shape as [bz, ]
label_batch
=
label_generator
((
num_batches
,
))
# otherwise, we will adopt default label sampler.
else
:
assert
num_batches
>
0
if
timesteps_noise
:
assert
num_timesteps
>
0
# generate label shape as [n, bz]
label_batch
=
torch
.
randint
(
0
,
num_classes
,
(
num_timesteps
,
num_batches
))
else
:
# generate label shape as [bz, ]
label_batch
=
torch
.
randint
(
0
,
num_classes
,
(
num_batches
,
))
return
label_batch
def var_to_tensor(var, index, target_shape=None, device=None):
    """Extract variables by the given index and convert them into a tensor
    that broadcasts against ``target_shape``.

    Args:
        var (np.ndarray): Variables to be extracted.
        index (torch.Tensor): Target index to extract.
        target_shape (torch.Size, optional): If given, the indexed
            variable gets trailing singleton dimensions appended until its
            rank matches, so it broadcasts against tensors of this shape.
            Defaults to None.
        device (str, optional): If given, the indexed variable will move
            to the target device. Otherwise, indexed variable will on cpu.
            Defaults to None.

    Returns:
        torch.Tensor: Converted variable.
    """
    # ``var`` is an ndarray, so indexing must happen with a cpu index
    var_indexed = torch.from_numpy(var)[index.cpu()].float()

    if device is not None:
        var_indexed = var_indexed.to(device)

    # fix: the documented default ``target_shape=None`` previously crashed
    # with ``TypeError: len(None)``; skip the reshape when no target shape
    # is given.
    if target_shape is not None:
        while len(var_indexed.shape) < len(target_shape):
            var_indexed = var_indexed[..., None]
    return var_indexed
Prev
1
…
11
12
13
14
15
16
17
18
19
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment