ModelZoo / stylegan2_mmcv / Commits / 1401de15

Commit 1401de15, authored Jun 28, 2024 by dongchy920

    stylegan2_mmcv

Pipeline #1274 canceled with stages · Changes: 463 · Pipelines: 1
Showing 20 changed files with 5825 additions and 0 deletions.
build/lib/mmgen/models/gans/__init__.py  (+12, -0)
build/lib/mmgen/models/gans/base_gan.py  (+261, -0)
build/lib/mmgen/models/gans/basic_conditional_gan.py  (+368, -0)
build/lib/mmgen/models/gans/mspie_stylegan2.py  (+210, -0)
build/lib/mmgen/models/gans/progressive_growing_unconditional_gan.py  (+488, -0)
build/lib/mmgen/models/gans/singan.py  (+461, -0)
build/lib/mmgen/models/gans/static_unconditional_gan.py  (+310, -0)
build/lib/mmgen/models/losses/__init__.py  (+22, -0)
build/lib/mmgen/models/losses/ddpm_loss.py  (+568, -0)
build/lib/mmgen/models/losses/disc_auxiliary_loss.py  (+552, -0)
build/lib/mmgen/models/losses/gan_loss.py  (+116, -0)
build/lib/mmgen/models/losses/gen_auxiliary_loss.py  (+867, -0)
build/lib/mmgen/models/losses/pixelwise_loss.py  (+718, -0)
build/lib/mmgen/models/losses/utils.py  (+114, -0)
build/lib/mmgen/models/misc.py  (+72, -0)
build/lib/mmgen/models/translation_models/__init__.py  (+9, -0)
build/lib/mmgen/models/translation_models/base_translation_model.py  (+139, -0)
build/lib/mmgen/models/translation_models/cyclegan.py  (+211, -0)
build/lib/mmgen/models/translation_models/pix2pix.py  (+184, -0)
build/lib/mmgen/models/translation_models/static_translation_gan.py  (+143, -0)

build/lib/mmgen/models/gans/__init__.py (new file @ 1401de15, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from .base_gan import BaseGAN
from .basic_conditional_gan import BasicConditionalGAN
from .mspie_stylegan2 import MSPIEStyleGAN2
from .progressive_growing_unconditional_gan import ProgressiveGrowingGAN
from .singan import PESinGAN, SinGAN
from .static_unconditional_gan import StaticUnconditionalGAN

__all__ = [
    'BaseGAN', 'StaticUnconditionalGAN', 'ProgressiveGrowingGAN', 'SinGAN',
    'MSPIEStyleGAN2', 'PESinGAN', 'BasicConditionalGAN'
]
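
These re-exports define the package's public surface. A quick sanity check of the class hierarchy, as a sketch assuming the standard mmgen package layout shown in this commit:

# Hypothetical downstream usage of the re-exported classes.
from mmgen.models.gans import BaseGAN, MSPIEStyleGAN2, SinGAN

# MSPIEStyleGAN2 extends StaticUnconditionalGAN, which extends BaseGAN.
assert issubclass(MSPIEStyleGAN2, BaseGAN)
assert issubclass(SinGAN, BaseGAN)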

build/lib/mmgen/models/gans/base_gan.py (new file @ 1401de15, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import torch
import torch.distributed as dist
import torch.nn as nn


class BaseGAN(nn.Module, metaclass=ABCMeta):
    """BaseGAN Module."""

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False

    @property
    def with_disc(self):
        """Whether the GAN owns a discriminator."""
        return hasattr(self,
                       'discriminator') and self.discriminator is not None

    @property
    def with_ema_gen(self):
        """bool: whether the GAN adopts exponential moving average."""
        return hasattr(self, 'gen_ema') and self.gen_ema is not None

    @property
    def with_gen_auxiliary_loss(self):
        """bool: whether the GAN adopts auxiliary loss in the generator."""
        return hasattr(self, 'gen_auxiliary_losses') and (
            self.gen_auxiliary_losses is not None)

    @property
    def with_disc_auxiliary_loss(self):
        """bool: whether the GAN adopts auxiliary loss in the
        discriminator."""
        return hasattr(self, 'disc_auxiliary_losses') and (
            self.disc_auxiliary_losses is not None)

    def _get_disc_loss(self, outputs_dict):
        # Construct the losses dict. If you hope some items to be included in
        # the computational graph, you have to add 'loss' to their names.
        # Otherwise, items without 'loss' in their names will only be used to
        # print information.
        losses_dict = {}
        # gan loss
        losses_dict['loss_disc_fake'] = self.gan_loss(
            outputs_dict['disc_pred_fake'], target_is_real=False, is_disc=True)
        losses_dict['loss_disc_real'] = self.gan_loss(
            outputs_dict['disc_pred_real'], target_is_real=True, is_disc=True)

        # disc auxiliary loss
        if self.with_disc_auxiliary_loss:
            for loss_module in self.disc_auxiliary_losses:
                loss_ = loss_module(outputs_dict)
                if loss_ is None:
                    continue

                # the `loss_name()` function returns a name like 'loss_xxx'
                if loss_module.loss_name() in losses_dict:
                    losses_dict[loss_module.loss_name(
                    )] = losses_dict[loss_module.loss_name()] + loss_
                else:
                    losses_dict[loss_module.loss_name()] = loss_
        loss, log_var = self._parse_losses(losses_dict)

        return loss, log_var

    def _get_gen_loss(self, outputs_dict):
        # Construct the losses dict. If you hope some items to be included in
        # the computational graph, you have to add 'loss' to their names.
        # Otherwise, items without 'loss' in their names will only be used to
        # print information.
        losses_dict = {}
        # gan loss
        losses_dict['loss_disc_fake_g'] = self.gan_loss(
            outputs_dict['disc_pred_fake_g'],
            target_is_real=True,
            is_disc=False)

        # gen auxiliary loss
        if self.with_gen_auxiliary_loss:
            for loss_module in self.gen_auxiliary_losses:
                loss_ = loss_module(outputs_dict)
                if loss_ is None:
                    continue

                # the `loss_name()` function returns a name like 'loss_xxx'
                if loss_module.loss_name() in losses_dict:
                    losses_dict[loss_module.loss_name(
                    )] = losses_dict[loss_module.loss_name()] + loss_
                else:
                    losses_dict[loss_module.loss_name()] = loss_
        loss, log_var = self._parse_losses(losses_dict)

        return loss, log_var

    @abstractmethod
    def train_step(self, data, optimizer, ddp_reducer=None):
        """The iteration step during training.

        This method defines an iteration step during training. Different from
        other repos in the **MM** series, we allow back propagation and
        optimizer updating to directly follow the iterative training schedule
        of GANs. Of course, we will show that you can also move the back
        propagation outside of this method and then optimize the parameters
        in the optimizer hook, but this causes extra GPU memory cost as a
        result of retaining the computational graph. Otherwise, the training
        schedule should be modified in the detailed implementation.

        TODO: Give an example of removing bp outside ``train_step``.
        TODO: Try the synchronized back propagation.

        Args:
            data (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                the runner is passed to ``train_step()``. This argument is
                unused and reserved.
            ddp_reducer (:obj:`Reducer` | None, optional): This reducer is
                used to dynamically collect used parameters in distributed
                training. If given an initialized ``Reducer``, we will call
                its ``prepare_for_backward()`` function just before calling
                ``.backward()``.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
                ``num_samples``.

                - ``loss`` is a tensor for back propagation, which can be a \
                weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to the
                logger.
                - ``num_samples`` indicates the batch size (when the model is \
                DDP, it means the batch size on each GPU), which is used for \
                averaging the logs.
        """

    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          sample_model='ema/orig',
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with queried
                data, including generated images, will be returned.
        """
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator

        outputs = _model(noise, num_batches=num_batches, **kwargs)

        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            noise = outputs['noise_batch']

        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator
            outputs_ = _model(noise, num_batches=num_batches, **kwargs)

            if isinstance(outputs_, dict):
                outputs['fake_img'] = torch.cat(
                    [outputs['fake_img'], outputs_['fake_img']], dim=0)
            else:
                outputs = torch.cat([outputs, outputs_], dim=0)

        return outputs

    def forward_train(self, data, **kwargs):
        """Deprecated forward function in training."""
        raise NotImplementedError(
            'In MMGeneration, we do NOT recommend users to call '
            'this function, because the train_step function is designed for '
            'the training process.')

    def forward_test(self, data, **kwargs):
        """Testing function for GANs.

        Args:
            data (torch.Tensor | dict | None): Input data. This data will be
                passed to different methods.
        """
        if kwargs.pop('mode', 'sampling') == 'sampling':
            return self.sample_from_noise(data, **kwargs)

        raise NotImplementedError('Other specific testing functions should '
                                  'be implemented by the sub-classes.')

    def forward(self, data, return_loss=False, **kwargs):
        """Forward function.

        Args:
            data (dict | torch.Tensor): Input data dictionary.
            return_loss (bool, optional): Whether in training or testing.
                Defaults to False.

        Returns:
            dict: Output dictionary.
        """
        if return_loss:
            return self.forward_train(data, **kwargs)

        return self.forward_test(data, **kwargs)

    def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
                which may be a weighted sum of all losses, log_vars contains \
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            # Allow setting None for some loss item.
            # This is to support dynamic loss modules, where the loss is
            # calculated with a fixed frequency.
            elif loss_value is None:
                continue
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        # Note that you have to add 'loss' to the names of the items that
        # will be included in back propagation.
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars
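
The `_parse_losses` contract above (only entries whose key contains 'loss' enter back propagation; everything else is logging-only) can be exercised in isolation. A minimal sketch, assuming a single-process, non-distributed run and the module path shown in this commit:

import torch

from mmgen.models.gans.base_gan import BaseGAN


class DummyGAN(BaseGAN):
    """Minimal concrete subclass; ``train_step`` is abstract in BaseGAN."""

    def train_step(self, data, optimizer, ddp_reducer=None):
        return dict(loss=None, log_vars={}, num_samples=0)


gan = DummyGAN()
losses = dict(
    loss_disc_fake=torch.tensor(0.3),
    loss_disc_real=torch.tensor(0.7),
    fake_score=torch.tensor(0.1),  # no 'loss' in the key: logged only
)
loss, log_vars = gan._parse_losses(losses)
assert torch.isclose(loss, torch.tensor(1.0))  # 0.3 + 0.7; 'fake_score' excluded
print(log_vars)  # {'loss_disc_fake': 0.3, 'loss_disc_real': 0.7, 'fake_score': 0.1, 'loss': 1.0}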

build/lib/mmgen/models/gans/basic_conditional_gan.py (new file @ 1401de15, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import torch
import torch.nn as nn
from torch.nn.parallel.distributed import _find_tensors

from ..builder import MODELS, build_module
from ..common import set_requires_grad
from .base_gan import BaseGAN


@MODELS.register_module('BasiccGAN')
@MODELS.register_module()
class BasicConditionalGAN(BaseGAN):
    """Basic conditional GANs.

    This is the conditional GAN model containing a standard adversarial
    training schedule. To fulfill the requirements of various GAN algorithms,
    ``disc_auxiliary_loss`` and ``gen_auxiliary_loss`` are provided to
    customize auxiliary losses, e.g., gradient penalty loss and discriminator
    shift loss. In addition, ``train_cfg`` and ``test_cfg`` aim at setting up
    the training schedule.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule.
            Defaults to None.
        num_classes (int | None, optional): The number of conditional classes.
            Defaults to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None,
                 num_classes=None):
        super().__init__()
        self.num_classes = num_classes
        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(
            generator, default_args=dict(num_classes=num_classes))

        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(
                discriminator, default_args=dict(num_classes=num_classes))
        else:
            self.discriminator = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            self.disc_auxiliary_losses = None

        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)
        self.gen_steps = self.train_cfg.get('gen_steps', 1)

        # add support for accumulating gradients within multiple steps. This
        # feature aims to simulate large `batch_sizes` (but may have some
        # detailed differences in BN). Note that `self.disc_steps` should be
        # set according to the batch accumulation strategy.
        # In addition, in the detailed implementation, there is a difference
        # between the batch accumulation in the generator and discriminator.
        self.batch_accumulation_steps = self.train_cfg.get(
            'batch_accumulation_steps', 1)

        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()

        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)

        # whether to use exponential moving average for testing
        self.use_ema = self.test_cfg.get('use_ema', False)
        # TODO: finish ema part

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we
        first update the discriminator, and then compute the loss for the
        generator with the newly updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in the current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
                The loss/gradient scaler used for auto mixed-precision
                training. Defaults to ``None``.
            use_apex_amp (bool, optional): Whether to use apex.amp. Defaults
                to ``False``.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['img']
        # get the ground-truth label, torch.Tensor (N, )
        gt_label = data_batch['gt_label']
        # If you adopt ddp, this batch size is the local batch size for each
        # GPU. If you adopt dp, this batch size is the global batch size as
        # usual.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # disc training
        set_requires_grad(self.discriminator, True)
        # do not `zero_grad` during batch accumulation
        if curr_iter % self.batch_accumulation_steps == 0:
            optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_data = self.generator(
                None, num_batches=batch_size, label=None, return_noise=True)
        # fake_label should be in the same data type as the gt_label
        fake_imgs, fake_label = fake_data['fake_img'], fake_data['label']

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs, label=fake_label)
        disc_pred_real = self.discriminator(real_imgs, label=gt_label)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size,
            gt_label=gt_label,
            fake_label=fake_label,
            loss_scaler=loss_scaler)

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)
        loss_disc = loss_disc / float(self.batch_accumulation_steps)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in the current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss_disc).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss_disc, optimizer['discriminator'],
                    loss_id=0) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss_disc.backward()

        if (curr_iter + 1) % self.batch_accumulation_steps == 0:
            if loss_scaler:
                loss_scaler.unscale_(optimizer['discriminator'])
                # note that we do not contain clip_grad procedure
                loss_scaler.step(optimizer['discriminator'])
                # loss_scaler.update will be called in runner.train()
            else:
                optimizer['discriminator'].step()

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)

        # allow for training the generator with multiple steps
        for _ in range(self.gen_steps):
            optimizer['generator'].zero_grad()

            for _ in range(self.batch_accumulation_steps):
                # TODO: add noise sampler to customize noise sampling
                fake_data = self.generator(
                    None, num_batches=batch_size, return_noise=True)
                # fake_label should be in the same data type as the gt_label
                fake_imgs, fake_label = fake_data['fake_img'], fake_data[
                    'label']
                disc_pred_fake_g = self.discriminator(
                    fake_imgs, label=fake_label)

                data_dict_ = dict(
                    gen=self.generator,
                    disc=self.discriminator,
                    fake_imgs=fake_imgs,
                    disc_pred_fake_g=disc_pred_fake_g,
                    iteration=curr_iter,
                    batch_size=batch_size,
                    fake_label=fake_label,
                    loss_scaler=loss_scaler)
                loss_gen, log_vars_g = self._get_gen_loss(data_dict_)
                loss_gen = loss_gen / float(self.batch_accumulation_steps)

                # prepare for backward in ddp. If you do not call this
                # function before back propagation, the ddp will not
                # dynamically find the used params in current computation.
                if ddp_reducer is not None:
                    ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

                if loss_scaler:
                    loss_scaler.scale(loss_gen).backward()
                elif use_apex_amp:
                    from apex import amp
                    with amp.scale_loss(
                            loss_gen, optimizer['generator'],
                            loss_id=1) as scaled_loss_disc:
                        scaled_loss_disc.backward()
                else:
                    loss_gen.backward()

            if loss_scaler:
                loss_scaler.unscale_(optimizer['generator'])
                # note that we do not contain clip_grad procedure
                loss_scaler.step(optimizer['generator'])
                # loss_scaler.update will be called in runner.train()
            else:
                optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs

    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          sample_model='ema/orig',
                          label=None,
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.
            sample_model (str, optional): Which model to use to sample fake
                images. Defaults to `'ema/orig'`.
            label (torch.Tensor | None, optional): The conditional label.
                Defaults to None.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized
                images in ``torch.Tensor``. Otherwise, a dict with queried
                data, including generated images, will be returned.
        """
        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator

        outputs = _model(
            noise, num_batches=num_batches, label=label, **kwargs)

        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            noise = outputs['noise_batch']
            label = outputs['label']
        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator
            outputs_ = _model(
                noise, num_batches=num_batches, label=label, **kwargs)

            if isinstance(outputs_, dict):
                outputs['fake_img'] = torch.cat(
                    [outputs['fake_img'], outputs_['fake_img']], dim=0)
            else:
                outputs = torch.cat([outputs, outputs_], dim=0)

        return outputs
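
A note on the gating logic in `train_step` above: the discriminator's `zero_grad()` fires only when `curr_iter % batch_accumulation_steps == 0` and its `step()` only when `(curr_iter + 1) % batch_accumulation_steps == 0`, so gradients accumulate across the window in between; the generator additionally trains only every `disc_steps` iterations. A standalone sketch of that gating with hypothetical counter values (no real model involved):

# Hypothetical walk-through of the iteration gating in train_step above.
batch_accumulation_steps = 4
disc_steps = 4  # generator updates only when (curr_iter + 1) % disc_steps == 0

for curr_iter in range(8):
    events = []
    if curr_iter % batch_accumulation_steps == 0:
        events.append('disc.zero_grad')    # accumulation window opens
    events.append('disc.backward')         # every iteration adds gradients
    if (curr_iter + 1) % batch_accumulation_steps == 0:
        events.append('disc.step')         # window closes, weights update
    if (curr_iter + 1) % disc_steps == 0:
        events.append('generator update')  # full generator pass runs
    print(curr_iter, events)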

build/lib/mmgen/models/gans/mspie_stylegan2.py (new file @ 1401de15, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import logging
from functools import partial

import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel.distributed import _find_tensors

from ..builder import MODELS
from ..common import set_requires_grad
from .static_unconditional_gan import StaticUnconditionalGAN


@MODELS.register_module()
class MSPIEStyleGAN2(StaticUnconditionalGAN):
    """MS-PIE StyleGAN2.

    In this GAN, we adopt the MS-PIE training schedule so that multi-scale
    images can be generated with a single generator. Details can be found in:
    Positional Encoding as Spatial Inductive Bias in GANs, CVPR 2021.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule.
            Defaults to None.
    """

    def _parse_train_cfg(self):
        super(MSPIEStyleGAN2, self)._parse_train_cfg()

        # set the number of upsampling blocks. This value will be used to
        # calculate the current result size according to the size of the
        # input feature map, e.g., positional encoding map
        self.num_upblocks = self.train_cfg.get('num_upblocks', 6)

        # multiple input scales (a list of int) that will be added to the
        # original starting scale.
        self.multi_input_scales = self.train_cfg.get('multi_input_scales')
        self.multi_scale_probability = self.train_cfg.get(
            'multi_scale_probability')

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we
        first update the discriminator, and then compute the loss for the
        generator with the newly updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in the current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['real_img']
        # If you adopt ddp, this batch size is the local batch size for each
        # GPU. If you adopt dp, this batch size is the global batch size as
        # usual.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        if dist.is_initialized():
            # randomly sample a scale for current training iteration
            chosen_scale = np.random.choice(
                self.multi_input_scales, 1,
                p=self.multi_scale_probability)[0]

            chosen_scale = torch.tensor(chosen_scale, dtype=torch.int).cuda()
            dist.broadcast(chosen_scale, 0)
            chosen_scale = int(chosen_scale.item())
        else:
            mmcv.print_log(
                'Distributed training has not been initialized. Degrade to '
                'the standard stylegan2',
                logger='mmgen',
                level=logging.WARN)
            chosen_scale = 0

        curr_size = (4 + chosen_scale) * (2**self.num_upblocks)
        # adjust the shape of images
        if real_imgs.shape[-2:] != (curr_size, curr_size):
            real_imgs = F.interpolate(
                real_imgs,
                size=(curr_size, curr_size),
                mode='bilinear',
                align_corners=True)

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_imgs = self.generator(
                None, num_batches=batch_size, chosen_scale=chosen_scale)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size,
            gen_partial=partial(self.generator, chosen_scale=chosen_scale))

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in the current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

        loss_disc.backward()
        optimizer['discriminator'].step()

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            log_vars_disc['curr_size'] = curr_size
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(
            None, num_batches=batch_size, chosen_scale=chosen_scale)
        disc_pred_fake_g = self.discriminator(fake_imgs)

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g,
            iteration=curr_iter,
            batch_size=batch_size,
            gen_partial=partial(self.generator, chosen_scale=chosen_scale))

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in the current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        loss_gen.backward()
        optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        log_vars['curr_size'] = curr_size

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
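
In this schedule the training resolution is a pure function of the sampled input scale: `curr_size = (4 + chosen_scale) * 2 ** num_upblocks`. A quick numeric check with the default `num_upblocks=6` (the scale values below are illustrative, not taken from a shipped config):

# Illustrative MS-PIE resolution arithmetic from train_step above.
num_upblocks = 6  # default in _parse_train_cfg
for chosen_scale in (0, 2, 4):
    curr_size = (4 + chosen_scale) * (2 ** num_upblocks)
    print(f'chosen_scale={chosen_scale} -> {curr_size}x{curr_size}')
    # 0 -> 256x256, 2 -> 384x384, 4 -> 512x512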

build/lib/mmgen/models/gans/progressive_growing_unconditional_gan.py (new file @ 1401de15, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial

import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel.distributed import _find_tensors

from mmgen.core.optimizer import build_optimizers
from mmgen.models.builder import MODELS, build_module
from ..common import set_requires_grad
from .base_gan import BaseGAN


@MODELS.register_module('StyleGANV1')
@MODELS.register_module('PGGAN')
@MODELS.register_module()
class ProgressiveGrowingGAN(BaseGAN):
    """Progressive Growing Unconditional GAN.

    In this GAN model, we implement the progressive growing training
    schedule, which is proposed in Progressive Growing of GANs for Improved
    Quality, Stability and Variation, ICLR 2018.

    We highly recommend using ``GrowScaleImgDataset`` to save computational
    load in data pre-processing.

    Notes for **using PGGAN**:

    #. In the official implementation, Tero uses gradient penalty with
       ``norm_mode="HWC"``.
    #. We do not implement ``minibatch_repeats``, which has been used in the
       official TensorFlow implementation.

    Notes for resuming progressive growing GANs: users should specify the
    ``prev_stage`` in ``train_cfg``. Otherwise, the model may reset the
    optimizer status, which will bring inferior performance. For example, if
    your model is resumed from the `256` stage, you should set
    ``train_cfg=dict(prev_stage=256)``.

    Args:
        generator (dict): Config for generator.
        discriminator (dict): Config for discriminator.
        gan_loss (dict): Config for generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule.
            Defaults to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(generator)

        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(discriminator)
        else:
            self.discriminator = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            self.disc_auxiliary_losses = None

        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        # register necessary training status
        self.register_buffer('shown_nkimg', torch.tensor(0.))
        self.register_buffer('_curr_transition_weight', torch.tensor(1.))

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        self._parse_train_cfg()

        # this buffer is used to resume model easily
        self.register_buffer(
            '_next_scale_int',
            torch.tensor(self.scales[0][0], dtype=torch.int32))
        # TODO: init it with the same value as `_next_scale_int`
        # a dirty workaround for testing
        self.register_buffer(
            '_curr_scale_int',
            torch.tensor(self.scales[-1][0], dtype=torch.int32))

        if test_cfg is not None:
            self._parse_test_cfg()

    def _parse_train_cfg(self):
        """Parsing train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)
        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)

        # setup interpolation operation at the beginning of training iter
        interp_real_cfg = deepcopy(self.train_cfg.get('interp_real', None))
        if interp_real_cfg is None:
            interp_real_cfg = dict(mode='bilinear', align_corners=True)

        self.interp_real_to = partial(F.interpolate, **interp_real_cfg)

        # parsing the training schedule: scales : kimg
        assert isinstance(self.train_cfg['nkimgs_per_scale'], dict), (
            'Please provide "nkimgs_per_scale" to schedule the training '
            'procedure.')

        nkimgs_per_scale = deepcopy(self.train_cfg['nkimgs_per_scale'])
        self.scales = []
        self.nkimgs = []
        for k, v in nkimgs_per_scale.items():
            # support for different data types
            if isinstance(k, str):
                k = (int(k), int(k))
            elif isinstance(k, int):
                k = (k, k)
            else:
                assert mmcv.is_tuple_of(k, int)

            # sanity check for the order of scales
            assert len(self.scales) == 0 or k[0] > self.scales[-1][0]
            self.scales.append(k)
            self.nkimgs.append(v)

        self.cum_nkimgs = np.cumsum(self.nkimgs)
        self.curr_stage = 0
        self.prev_stage = 0
        # nkimgs actually shown at the end of each training stage
        self._actual_nkimgs = []
        # In each scale, transit from the previous torgb layer to the newer
        # torgb layer with `transition_kimgs` imgs
        self.transition_kimgs = self.train_cfg.get('transition_kimgs', 600)

        # setup optimizer
        self.optimizer = build_optimizers(
            self, deepcopy(self.train_cfg['optimizer_cfg']))

        # get lr schedule
        self.g_lr_base = self.train_cfg['g_lr_base']
        self.d_lr_base = self.train_cfg['d_lr_base']
        # example for lr schedule: {'32': 0.001, '64': 0.0001}
        self.g_lr_schedule = self.train_cfg.get('g_lr_schedule', dict())
        self.d_lr_schedule = self.train_cfg.get('d_lr_schedule', dict())

        # reset the states for optimizers, e.g. momentum in Adam
        self.reset_optim_for_new_scale = self.train_cfg.get(
            'reset_optim_for_new_scale', True)

        # dirty workaround for avoiding optimizer bug in resuming
        self.prev_stage = self.train_cfg.get('prev_stage', self.prev_stage)

    def _parse_test_cfg(self):
        """Parsing test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()

        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)

        # whether to use exponential moving average for testing
        self.use_ema = self.test_cfg.get('use_ema', False)
        # TODO: finish ema part

    def sample_from_noise(self,
                          noise,
                          num_batches=0,
                          curr_scale=None,
                          transition_weight=None,
                          sample_model='ema/orig',
                          **kwargs):
        """Sample images from noises by using the generator.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            num_batches (int, optional): The number of batch size.
                Defaults to 0.

        Returns:
            torch.Tensor | dict: The output may be the direct synthesized \
                images in ``torch.Tensor``. Otherwise, a dict with queried \
                data, including generated images, will be returned.
        """
        # use `self.curr_scale` if curr_scale is None
        if curr_scale is None:
            # in training, 'curr_scale' will be set as an attribute
            if hasattr(self, 'curr_scale'):
                curr_scale = self.curr_scale[0]
            # in testing, adopt '_curr_scale_int' from buffer as testing scale
            else:
                curr_scale = self._curr_scale_int.item()
        # use `self._curr_transition_weight` if `transition_weight` is None
        if transition_weight is None:
            transition_weight = self._curr_transition_weight.item()

        if sample_model == 'ema':
            assert self.use_ema
            _model = self.generator_ema
        elif sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator_ema
        else:
            _model = self.generator

        outputs = _model(
            noise,
            num_batches=num_batches,
            curr_scale=curr_scale,
            transition_weight=transition_weight,
            **kwargs)

        if isinstance(outputs, dict) and 'noise_batch' in outputs:
            noise = outputs['noise_batch']

        if sample_model == 'ema/orig' and self.use_ema:
            _model = self.generator
            outputs_ = _model(
                noise,
                num_batches=num_batches,
                curr_scale=curr_scale,
                transition_weight=transition_weight,
                **kwargs)

            if isinstance(outputs_, dict):
                outputs['fake_img'] = torch.cat(
                    [outputs['fake_img'], outputs_['fake_img']], dim=0)
            else:
                outputs = torch.cat([outputs, outputs_], dim=0)

        return outputs

    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we
        first update the discriminator, and then compute the loss for the
        generator with the newly updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in the current computational graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['real_img']
        # If you adopt ddp, this batch size is the local batch size for each
        # GPU.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # check if optimizer from model
        if hasattr(self, 'optimizer'):
            optimizer = self.optimizer

        # update current stage
        self.curr_stage = int(
            min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        self.curr_scale = self.scales[self.curr_stage]
        self._curr_scale_int = self._next_scale_int.clone()

        # add new scale and update training status
        if self.curr_stage != self.prev_stage:
            self.prev_stage = self.curr_stage
            self._actual_nkimgs.append(self.shown_nkimg.item())

            # reset optimizer
            if self.reset_optim_for_new_scale:
                optim_cfg = deepcopy(self.train_cfg['optimizer_cfg'])
                optim_cfg['generator']['lr'] = self.g_lr_schedule.get(
                    str(self.curr_scale[0]), self.g_lr_base)
                optim_cfg['discriminator']['lr'] = self.d_lr_schedule.get(
                    str(self.curr_scale[0]), self.d_lr_base)
                self.optimizer = build_optimizers(self, optim_cfg)
                optimizer = self.optimizer
                mmcv.print_log('Reset optimizer for new scale', logger='mmgen')

        # update training configs, like transition weight for torgb layers.
        # get current transition weight for interpolating two torgb layers
        if self.curr_stage == 0:
            transition_weight = 1.
        else:
            transition_weight = (
                self.shown_nkimg.item() -
                self._actual_nkimgs[-1]) / self.transition_kimgs
            # clip to [0, 1]
            transition_weight = min(max(transition_weight, 0.), 1.)
        self._curr_transition_weight = torch.tensor(transition_weight).to(
            self._curr_transition_weight)

        # resize real image to target scale
        if real_imgs.shape[2:] == self.curr_scale:
            pass
        elif real_imgs.shape[2] >= self.curr_scale[0] and real_imgs.shape[
                3] >= self.curr_scale[1]:
            real_imgs = self.interp_real_to(real_imgs, size=self.curr_scale)
        else:
            raise RuntimeError(
                f'The scale of real image {real_imgs.shape[2:]} is smaller '
                f'than current scale {self.curr_scale}.')

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_imgs = self.generator(
                None,
                num_batches=batch_size,
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        disc_pred_real = self.discriminator(
            real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight,
            gen_partial=partial(
                self.generator,
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight),
            disc_partial=partial(
                self.discriminator,
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight))

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in the current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

        loss_disc.backward()
        optimizer['discriminator'].step()

        # update training log status
        if dist.is_initialized():
            _batch_size = batch_size * dist.get_world_size()
        else:
            if 'batch_size' not in running_status:
                raise RuntimeError(
                    'You should offer "batch_size" in running status for '
                    'PGGAN')
            _batch_size = running_status['batch_size']

        self.shown_nkimg += (_batch_size / 1000.)

        log_vars_disc.update(
            dict(
                shown_nkimg=self.shown_nkimg.item(),
                curr_scale=self.curr_scale[0],
                transition_weight=transition_weight))

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(
            None,
            num_batches=batch_size,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        disc_pred_fake_g = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)

        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g)

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in the current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        loss_gen.backward()
        optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        log_vars.update({'batch_size': batch_size})

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1

        # check if a new scale will be added in the next iteration
        _curr_stage = int(
            min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        # in the next iteration, we will switch to a new scale
        if _curr_stage != self.curr_stage:
            # `self._next_scale_int` is updated at the end of `train_step`
            self._next_scale_int = self._next_scale_int * 2

        return outputs
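
Both the stage index and the to-RGB transition weight above are pure functions of the `shown_nkimg` counter; a self-contained sketch of that schedule (the per-scale kimg values below are invented for illustration):

import numpy as np

# Hypothetical schedule: 600 kimgs at the first scale, then 1200, then 1200.
nkimgs = [600, 1200, 1200]
transition_kimgs = 600
cum_nkimgs = np.cumsum(nkimgs)  # [600, 1800, 3000]
actual_nkimgs = [600.0]         # shown_nkimg recorded when stage 1 started

shown_nkimg = 900.0             # 900 kimgs shown so far
curr_stage = int(min(np.sum(cum_nkimgs <= shown_nkimg), len(nkimgs) - 1))
assert curr_stage == 1          # past 600, not yet at 1800

# linear fade-in of the new to-RGB branch, clipped to [0, 1]
transition_weight = (shown_nkimg - actual_nkimgs[-1]) / transition_kimgs
transition_weight = min(max(transition_weight, 0.), 1.)
assert transition_weight == 0.5  # (900 - 600) / 600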

build/lib/mmgen/models/gans/singan.py (new file @ 1401de15, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import
pickle
from
copy
import
deepcopy
from
functools
import
partial
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.nn.parallel
import
DataParallel
,
DistributedDataParallel
from
torch.nn.parallel.distributed
import
_find_tensors
from
mmgen.models.architectures.common
import
get_module_device
from
mmgen.models.builder
import
MODELS
,
build_module
from
mmgen.models.gans.base_gan
import
BaseGAN
from
..common
import
set_requires_grad
@
MODELS
.
register_module
()
class
SinGAN
(
BaseGAN
):
"""SinGAN.
This model implement the single image generative adversarial model proposed
in: Singan: Learning a Generative Model from a Single Natural Image,
ICCV'19.
Notes for training:
- This model should be trained with our dataset ``SinGANDataset``.
- In training, the ``total_iters`` arguments is related to the number of
scales in the image pyramid and ``iters_per_scale`` in the ``train_cfg``.
You should set it carefully in the training config file.
Notes for model architectures:
- The generator and discriminator need ``num_scales`` in initialization.
However, this arguments is generated by ``create_real_pyramid`` function
from the ``singan_dataset.py``. The last element in the returned list
(``stop_scale``) is the value for ``num_scales``. Pay attention that this
scale is counted from zero. Please see our tutorial for SinGAN to obtain
more details or our standard config for reference.
Args:
generator (dict): Config for generator.
discriminator (dict): Config for discriminator.
gan_loss (dict): Config for generative adversarial loss.
disc_auxiliary_loss (dict): Config for auxiliary loss to
discriminator.
gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
to generator. Defaults to None.
train_cfg (dict | None, optional): Config for training schedule.
Defaults to None.
test_cfg (dict | None, optional): Config for testing schedule. Defaults
to None.
"""
def
__init__
(
self
,
generator
,
discriminator
,
gan_loss
,
disc_auxiliary_loss
,
gen_auxiliary_loss
=
None
,
train_cfg
=
None
,
test_cfg
=
None
):
super
().
__init__
()
self
.
_gen_cfg
=
deepcopy
(
generator
)
self
.
generator
=
build_module
(
generator
)
# support no discriminator in testing
if
discriminator
is
not
None
:
self
.
discriminator
=
build_module
(
discriminator
)
else
:
self
.
discriminator
=
None
# support no gan_loss in testing
if
gan_loss
is
not
None
:
self
.
gan_loss
=
build_module
(
gan_loss
)
else
:
self
.
gan_loss
=
None
if
disc_auxiliary_loss
:
self
.
disc_auxiliary_losses
=
build_module
(
disc_auxiliary_loss
)
if
not
isinstance
(
self
.
disc_auxiliary_losses
,
nn
.
ModuleList
):
self
.
disc_auxiliary_losses
=
nn
.
ModuleList
(
[
self
.
disc_auxiliary_losses
])
else
:
self
.
disc_auxiliary_losses
=
None
if
gen_auxiliary_loss
:
self
.
gen_auxiliary_losses
=
build_module
(
gen_auxiliary_loss
)
if
not
isinstance
(
self
.
gen_auxiliary_losses
,
nn
.
ModuleList
):
self
.
gen_auxiliary_losses
=
nn
.
ModuleList
(
[
self
.
gen_auxiliary_losses
])
else
:
self
.
gen_auxiliary_losses
=
None
# register necessary training status
self
.
curr_stage
=
-
1
self
.
noise_weights
=
[
1
]
self
.
fixed_noises
=
[]
self
.
reals
=
[]
self
.
train_cfg
=
deepcopy
(
train_cfg
)
if
train_cfg
else
None
self
.
test_cfg
=
deepcopy
(
test_cfg
)
if
test_cfg
else
None
self
.
_parse_train_cfg
()
if
test_cfg
is
not
None
:
self
.
_parse_test_cfg
()
def
_parse_train_cfg
(
self
):
"""Parsing train config and set some attributes for training."""
if
self
.
train_cfg
is
None
:
self
.
train_cfg
=
dict
()
# whether to use exponential moving average for training
self
.
use_ema
=
self
.
train_cfg
.
get
(
'use_ema'
,
False
)
if
self
.
use_ema
:
# use deepcopy to guarantee the consistency
self
.
generator_ema
=
deepcopy
(
self
.
generator
)
def
_parse_test_cfg
(
self
):
if
self
.
test_cfg
.
get
(
'pkl_data'
,
None
)
is
not
None
:
with
open
(
self
.
test_cfg
.
pkl_data
,
'rb'
)
as
f
:
data
=
pickle
.
load
(
f
)
self
.
fixed_noises
=
self
.
_from_numpy
(
data
[
'fixed_noises'
])
self
.
noise_weights
=
self
.
_from_numpy
(
data
[
'noise_weights'
])
self
.
curr_stage
=
data
[
'curr_stage'
]
mmcv
.
print_log
(
f
'Load pkl data from
{
self
.
test_cfg
.
pkl_data
}
'
,
'mmgen'
)
def
_from_numpy
(
self
,
data
):
if
isinstance
(
data
,
list
):
return
[
self
.
_from_numpy
(
x
)
for
x
in
data
]
if
isinstance
(
data
,
np
.
ndarray
):
data
=
torch
.
from_numpy
(
data
)
device
=
get_module_device
(
self
.
generator
)
data
=
data
.
to
(
device
)
return
data
return
data
def
get_module
(
self
,
model
,
module_name
):
"""Get an inner module from model.
Since we will wrapper DDP for some model, we have to judge whether the
module can be indexed directly.
Args:
model (nn.Module): This model may wrapped with DDP or not.
module_name (str): The name of specific module.
Return:
nn.Module: Returned sub module.
"""
if
isinstance
(
model
,
(
DataParallel
,
DistributedDataParallel
)):
return
getattr
(
model
.
module
,
module_name
)
return
getattr
(
model
,
module_name
)
def
sample_from_noise
(
self
,
noise
,
num_batches
=
0
,
curr_scale
=
None
,
sample_model
=
'ema/orig'
,
**
kwargs
):
"""Sample images from noises by using the generator.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
num_batches (int, optional): The number of batch size.
Defaults to 0.
Returns:
torch.Tensor | dict: The output may be the direct synthesized
\
images in ``torch.Tensor``. Otherwise, a dict with queried
\
data, including generated images, will be returned.
"""
# use `self.curr_scale` if curr_scale is None
if
curr_scale
is
None
:
curr_scale
=
self
.
curr_stage
if
sample_model
==
'ema'
:
assert
self
.
use_ema
_model
=
self
.
generator_ema
elif
sample_model
==
'ema/orig'
and
self
.
use_ema
:
_model
=
self
.
generator_ema
else
:
_model
=
self
.
generator
if
not
self
.
fixed_noises
[
0
].
is_cuda
and
torch
.
cuda
.
is_available
():
self
.
fixed_noises
=
[
x
.
to
(
get_module_device
(
self
))
for
x
in
self
.
fixed_noises
]
outputs
=
_model
(
None
,
fixed_noises
=
self
.
fixed_noises
,
noise_weights
=
self
.
noise_weights
,
rand_mode
=
'rand'
,
num_batches
=
num_batches
,
curr_scale
=
curr_scale
,
**
kwargs
)
return
outputs
def
construct_fixed_noises
(
self
):
"""Construct the fixed noises list used in SinGAN."""
for
i
,
real
in
enumerate
(
self
.
reals
):
h
,
w
=
real
.
shape
[
-
2
:]
if
i
==
0
:
noise
=
torch
.
randn
(
1
,
1
,
h
,
w
).
to
(
real
)
self
.
fixed_noises
.
append
(
noise
)
else
:
noise
=
torch
.
zeros_like
(
real
)
self
.
fixed_noises
.
append
(
noise
)
def
train_step
(
self
,
data_batch
,
optimizer
,
ddp_reducer
=
None
,
running_status
=
None
):
"""Train step function.
This function implements the standard training iteration for
asynchronous adversarial training. Namely, in each iteration, we first
update discriminator and then compute loss for generator with the newly
updated discriminator.
As for distributed training, we use the ``reducer`` from ddp to
synchronize the necessary params in current computational graph.
Args:
data_batch (dict): Input data from dataloader.
optimizer (dict): Dict contains optimizer for generator and
discriminator.
ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
It is used to prepare for ``backward()`` in ddp. Defaults to
None.
running_status (dict | None, optional): Contains necessary basic
information for training, e.g., iteration number. Defaults to
None.
Returns:
dict: Contains 'log_vars', 'num_samples', and 'results'.
"""
# get running status
if
running_status
is
not
None
:
curr_iter
=
running_status
[
'iteration'
]
else
:
# dirty walkround for not providing running status
if
not
hasattr
(
self
,
'iteration'
):
self
.
iteration
=
0
curr_iter
=
self
.
iteration
# init each scale
if
curr_iter
%
self
.
train_cfg
[
'iters_per_scale'
]
==
0
:
self
.
curr_stage
+=
1
# load weights from prev scale
self
.
get_module
(
self
.
generator
,
'check_and_load_prev_weight'
)(
self
.
curr_stage
)
self
.
get_module
(
self
.
discriminator
,
'check_and_load_prev_weight'
)(
self
.
curr_stage
)
# build optimizer for each scale
g_module
=
self
.
get_module
(
self
.
generator
,
'blocks'
)
param_list
=
g_module
[
self
.
curr_stage
].
parameters
()
self
.
g_optim
=
torch
.
optim
.
Adam
(
param_list
,
lr
=
self
.
train_cfg
[
'lr_g'
],
betas
=
(
0.5
,
0.999
))
d_module
=
self
.
get_module
(
self
.
discriminator
,
'blocks'
)
self
.
d_optim
=
torch
.
optim
.
Adam
(
d_module
[
self
.
curr_stage
].
parameters
(),
lr
=
self
.
train_cfg
[
'lr_d'
],
betas
=
(
0.5
,
0.999
))
self
.
optimizer
=
dict
(
generator
=
self
.
g_optim
,
discriminator
=
self
.
d_optim
)
self
.
g_scheduler
=
torch
.
optim
.
lr_scheduler
.
MultiStepLR
(
optimizer
=
self
.
g_optim
,
**
self
.
train_cfg
[
'lr_scheduler_args'
])
self
.
d_scheduler
=
torch
.
optim
.
lr_scheduler
.
MultiStepLR
(
optimizer
=
self
.
d_optim
,
**
self
.
train_cfg
[
'lr_scheduler_args'
])
optimizer
=
self
.
optimizer
        # setup fixed noises and reals pyramid
        if curr_iter == 0 or len(self.reals) == 0:
            keys = [k for k in data_batch.keys() if 'real_scale' in k]
            scales = len(keys)
            self.reals = [data_batch[f'real_scale{s}'] for s in range(scales)]

            # here we do not pad fixed noises
            self.construct_fixed_noises()
        # disc training
        set_requires_grad(self.discriminator, True)
        for _ in range(self.train_cfg['disc_steps']):
            optimizer['discriminator'].zero_grad()
            # TODO: add noise sampler to customize noise sampling
            with torch.no_grad():
                fake_imgs = self.generator(
                    data_batch['input_sample'],
                    self.fixed_noises,
                    self.noise_weights,
                    rand_mode='rand',
                    curr_scale=self.curr_stage)

            # disc pred for fake imgs and real_imgs
            disc_pred_fake = self.discriminator(fake_imgs.detach(),
                                                self.curr_stage)
            disc_pred_real = self.discriminator(self.reals[self.curr_stage],
                                                self.curr_stage)
            # get data dict to compute losses for disc
            data_dict_ = dict(
                iteration=curr_iter,
                gen=self.generator,
                disc=self.discriminator,
                disc_pred_fake=disc_pred_fake,
                disc_pred_real=disc_pred_real,
                fake_imgs=fake_imgs,
                real_imgs=self.reals[self.curr_stage],
                disc_partial=partial(
                    self.discriminator, curr_scale=self.curr_stage))

            loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

            loss_disc.backward()
            optimizer['discriminator'].step()

        log_vars_disc.update(dict(curr_stage=self.curr_stage))
        # generator training
        set_requires_grad(self.discriminator, False)
        for _ in range(self.train_cfg['generator_steps']):
            optimizer['generator'].zero_grad()

            # TODO: add noise sampler to customize noise sampling
            fake_imgs = self.generator(
                data_batch['input_sample'],
                self.fixed_noises,
                self.noise_weights,
                rand_mode='rand',
                curr_scale=self.curr_stage)
            disc_pred_fake_g = self.discriminator(
                fake_imgs, curr_scale=self.curr_stage)
            recon_imgs = self.generator(
                data_batch['input_sample'],
                self.fixed_noises,
                self.noise_weights,
                rand_mode='recon',
                curr_scale=self.curr_stage)

            data_dict_ = dict(
                iteration=curr_iter,
                gen=self.generator,
                disc=self.discriminator,
                fake_imgs=fake_imgs,
                recon_imgs=recon_imgs,
                real_imgs=self.reals[self.curr_stage],
                disc_pred_fake_g=disc_pred_fake_g)

            loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

            loss_gen.backward()
            optimizer['generator'].step()
        # end of each scale
        # calculate noise weight for next scale
        if (curr_iter % self.train_cfg['iters_per_scale'] == 0) and (
                self.curr_stage < len(self.reals) - 1):
            with torch.no_grad():
                g_recon = self.generator(
                    data_batch['input_sample'],
                    self.fixed_noises,
                    self.noise_weights,
                    rand_mode='recon',
                    curr_scale=self.curr_stage)
                if isinstance(g_recon, dict):
                    g_recon = g_recon['fake_img']
                g_recon = F.interpolate(
                    g_recon, self.reals[self.curr_stage + 1].shape[-2:])

            mse = F.mse_loss(g_recon.detach(),
                             self.reals[self.curr_stage + 1])
            rmse = torch.sqrt(mse)
            self.noise_weights.append(
                self.train_cfg.get('noise_weight_init', 0.1) * rmse.item())

            # try to release GPU memory.
            torch.cuda.empty_cache()
        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)

        results = dict(
            fake_imgs=fake_imgs.cpu(),
            real_imgs=self.reals[self.curr_stage].cpu(),
            recon_imgs=recon_imgs.cpu(),
            curr_stage=self.curr_stage,
            fixed_noises=self.fixed_noises,
            noise_weights=self.noise_weights)
        outputs = dict(log_vars=log_vars, num_samples=1, results=results)

        # update lr scheduler
        self.d_scheduler.step()
        self.g_scheduler.step()

        if hasattr(self, 'iteration'):
            self.iteration += 1

        return outputs
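
# A minimal sketch (not part of the original module) of the noise-weight rule
# applied at the end of each scale in ``train_step`` above: the weight
# appended for the next scale is the reconstruction RMSE at that scale times
# ``noise_weight_init`` (0.1 by default). All tensors below are dummy
# stand-ins.
if __name__ == '__main__':
    _recon = torch.rand(1, 3, 32, 32)  # upsampled reconstruction (dummy)
    _real = torch.rand(1, 3, 32, 32)  # real image at the next scale (dummy)
    _rmse = torch.sqrt(F.mse_loss(_recon, _real))
    print('noise weight for next scale:', 0.1 * _rmse.item())
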
@MODELS.register_module()
class PESinGAN(SinGAN):
    """Positional Encoding in SinGAN.

    This modified SinGAN is used to reimplement the experiments in: Positional
    Encoding as Spatial Inductive Bias in GANs, CVPR2021.
    """

    def _parse_train_cfg(self):
        super(PESinGAN, self)._parse_train_cfg()
        self.fixed_noise_with_pad = self.train_cfg.get(
            'fixed_noise_with_pad', False)
        self.first_fixed_noises_ch = self.train_cfg.get(
            'first_fixed_noises_ch', 1)
    def construct_fixed_noises(self):
        """Construct the fixed noises list used in SinGAN."""
        for i, real in enumerate(self.reals):
            h, w = real.shape[-2:]
            if self.fixed_noise_with_pad:
                pad_ = self.get_module(self, 'generator').pad_head
                h += 2 * pad_
                w += 2 * pad_
            if i == 0:
                noise = torch.randn(1, self.first_fixed_noises_ch, h,
                                    w).to(real)
                self.fixed_noises.append(noise)
            else:
                noise = torch.zeros((1, 1, h, w)).to(real)
                self.fixed_noises.append(noise)
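
# Config sketch (hypothetical values, shown only to illustrate the
# ``train_cfg`` keys consumed by ``SinGAN.train_step`` and
# ``PESinGAN._parse_train_cfg`` above; real values live in the repo's config
# files):
if __name__ == '__main__':
    _train_cfg = dict(
        iters_per_scale=2000,
        disc_steps=3,
        generator_steps=3,
        lr_g=5e-4,
        lr_d=5e-4,
        lr_scheduler_args=dict(milestones=[1600], gamma=0.1),
        noise_weight_init=0.1,
        fixed_noise_with_pad=True,  # PESinGAN only
        first_fixed_noises_ch=3)  # PESinGAN only
    print(_train_cfg)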
build/lib/mmgen/models/gans/static_unconditional_gan.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import torch
import torch.nn as nn
from torch.nn.parallel.distributed import _find_tensors

from ..builder import MODELS, build_module
from ..common import set_requires_grad
from .base_gan import BaseGAN

# _SUPPORT_METHODS_ = ['DCGAN', 'STYLEGANv2']


# @MODELS.register_module(_SUPPORT_METHODS_)
@MODELS.register_module()
class StaticUnconditionalGAN(BaseGAN):
    """Unconditional GANs with static architecture in training.

    This is the standard GAN model containing a standard adversarial training
    schedule. To fulfill the requirements of various GAN algorithms,
    ``disc_auxiliary_loss`` and ``gen_auxiliary_loss`` are provided to
    customize auxiliary losses, e.g., gradient penalty loss and discriminator
    shift loss. In addition, ``train_cfg`` and ``test_cfg`` aim at setting up
    the training and testing schedules.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator.
        gan_loss (dict): Config for the generative adversarial loss.
        disc_auxiliary_loss (dict): Config for auxiliary loss to
            discriminator.
        gen_auxiliary_loss (dict | None, optional): Config for auxiliary loss
            to generator. Defaults to None.
        train_cfg (dict | None, optional): Config for training schedule.
            Defaults to None.
        test_cfg (dict | None, optional): Config for testing schedule.
            Defaults to None.
    """
    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(generator)

        # support no discriminator in testing
        if discriminator is not None:
            self.discriminator = build_module(discriminator)
        else:
            self.discriminator = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            self.disc_auxiliary_losses = None

        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()
    def _parse_train_cfg(self):
        """Parse the train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the work flow in train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)

        # whether to use exponential moving average for training
        self.use_ema = self.train_cfg.get('use_ema', False)
        if self.use_ema:
            # use deepcopy to guarantee the consistency
            self.generator_ema = deepcopy(self.generator)

        self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')

    def _parse_test_cfg(self):
        """Parse the test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()

        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)

        # whether to use exponential moving average for testing
        self.use_ema = self.test_cfg.get('use_ema', False)
        # TODO: finish ema part
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update the discriminator and then compute the loss for the generator
        with the newly updated discriminator.

        As for distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary params in the current computational graph.

        Args:
            data_batch (dict): Input data from the dataloader.
            optimizer (dict): Dict containing the optimizers for the generator
                and the discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
                The loss/gradient scaler used for auto mixed-precision
                training. Defaults to ``None``.
            use_apex_amp (bool, optional): Whether to use apex.amp. Defaults
                to ``False``.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch[self.real_img_key]
        # If you adopt ddp, this batch size is the local batch size for each
        # GPU. If you adopt dp, this batch size is the global batch size as
        # usual.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for not providing running status
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration
        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling

        # pass model specific training kwargs
        g_training_kwargs = {}
        if hasattr(self.generator, 'get_training_kwargs'):
            g_training_kwargs.update(
                self.generator.get_training_kwargs(phase='disc'))

        with torch.no_grad():
            fake_imgs = self.generator(
                None, num_batches=batch_size, **g_training_kwargs)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size,
            loss_scaler=loss_scaler)

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss_disc).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss_disc, optimizer['discriminator'],
                    loss_id=0) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss_disc.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['discriminator'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['discriminator'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['discriminator'].step()
        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(
                fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
            outputs = dict(
                log_vars=log_vars_disc,
                num_samples=batch_size,
                results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs
        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        # pass model specific training kwargs
        g_training_kwargs = {}
        if hasattr(self.generator, 'get_training_kwargs'):
            g_training_kwargs.update(
                self.generator.get_training_kwargs(phase='gen'))

        fake_imgs = self.generator(
            None, num_batches=batch_size, **g_training_kwargs)
        disc_pred_fake_g = self.discriminator(fake_imgs)

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g,
            iteration=curr_iter,
            batch_size=batch_size,
            loss_scaler=loss_scaler)

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)
        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        if loss_scaler:
            loss_scaler.scale(loss_gen).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(
                    loss_gen, optimizer['generator'],
                    loss_id=1) as scaled_loss_gen:
                scaled_loss_gen.backward()
        else:
            loss_gen.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['generator'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['generator'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['generator'].step()
        # update ada p
        if hasattr(self.discriminator,
                   'with_ada') and self.discriminator.with_ada:
            self.discriminator.ada_aug.log_buffer[0] += batch_size
            self.discriminator.ada_aug.log_buffer[1] += disc_pred_real.sign(
            ).sum()
            self.discriminator.ada_aug.update(
                iteration=curr_iter, num_batches=batch_size)
            log_vars_disc['augment'] = (
                self.discriminator.ada_aug.aug_pipeline.p.data.cpu())

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(
            log_vars=log_vars, num_samples=batch_size, results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
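
# Config sketch (type names and parameters here are assumptions, shown only
# to illustrate how the constructor arguments of ``StaticUnconditionalGAN``
# map to a config dict; nothing is built or run):
if __name__ == '__main__':
    _model_cfg = dict(
        generator=dict(type='DCGANGenerator', output_scale=64),
        discriminator=dict(type='DCGANDiscriminator', input_scale=64),
        gan_loss=dict(type='GANLoss', gan_type='vanilla'),
        train_cfg=dict(disc_steps=1, use_ema=True, real_img_key='real_img'),
        test_cfg=dict(batch_size=4, use_ema=True))
    print(_model_cfg)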
build/lib/mmgen/models/losses/__init__.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from .ddpm_loss import DDPMVLBLoss
from .disc_auxiliary_loss import (DiscShiftLoss, GradientPenaltyLoss,
                                  R1GradientPenalty, disc_shift_loss,
                                  gradient_penalty_loss,
                                  r1_gradient_penalty_loss)
from .gan_loss import GANLoss
from .gen_auxiliary_loss import (CLIPLoss, FaceIdLoss,
                                 GeneratorPathRegularizer, PerceptualLoss,
                                 gen_path_regularizer)
from .pixelwise_loss import (DiscretizedGaussianLogLikelihoodLoss,
                             GaussianKLDLoss, L1Loss, MSELoss,
                             discretized_gaussian_log_likelihood,
                             gaussian_kld)

__all__ = [
    'GANLoss', 'DiscShiftLoss', 'disc_shift_loss', 'gradient_penalty_loss',
    'GradientPenaltyLoss', 'R1GradientPenalty', 'r1_gradient_penalty_loss',
    'GeneratorPathRegularizer', 'gen_path_regularizer', 'MSELoss', 'L1Loss',
    'gaussian_kld', 'GaussianKLDLoss', 'DiscretizedGaussianLogLikelihoodLoss',
    'DDPMVLBLoss', 'discretized_gaussian_log_likelihood', 'FaceIdLoss',
    'CLIPLoss', 'PerceptualLoss'
]
build/lib/mmgen/models/losses/ddpm_loss.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from copy import deepcopy
from functools import partial

import mmcv
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.utils import digit_version

from mmgen.models.builder import MODULES
from .pixelwise_loss import (DiscretizedGaussianLogLikelihoodLoss,
                             GaussianKLDLoss, _reduction_modes, mse_loss)
from .utils import reduce_loss
class DDPMLoss(nn.Module):
    """Base module for DDPM losses. We support loss weight rescale and log
    collection for DDPM models in this module.

    We support two kinds of loss rescale methods, which can be
    controlled by ``rescale_mode`` and ``rescale_cfg``:

    1. ``rescale_mode == 'constant'``: ``constant_rescale`` would be called,
       and ``rescale_cfg`` should be passed as ``dict(scale=SCALE)``,
       e.g., ``dict(scale=1.2)``. Then, all loss terms would be rescaled by
       multiplying with ``SCALE``.
    2. ``rescale_mode == 'timestep_weight'``: ``timestep_weight_rescale``
       would be called, and ``weight`` or a ``sampler`` which contains a
       weight attribute must be passed. Then, the loss at timestep ``t``
       would be multiplied with ``weight[t]``. We also support users further
       applying a constant rescale factor to all loss terms, e.g.,
       ``rescale_cfg=dict(scale=SCALE)``. The overall rescale function for
       the loss at timestep ``t`` can be formulated as
       ``loss[t] := weight[t] * loss[t] * SCALE``. To be noted that,
       ``weight`` or ``sampler.weight`` would be inplace modified in the
       outer code, e.g.,

    .. code-block:: python
        :linenos:

        # 1. define weight
        weight = torch.randn(10, )

        # 2. define loss function
        loss_fn = DDPMLoss(rescale_mode='timestep_weight', weight=weight)

        # 3. update weight
        # wrong usage: `weight` in `loss_fn` is not accessible from now on
        # because you assign a new tensor to the variable `weight`
        # weight = torch.randn(10, )

        # correct usage: update `weight` inplace
        weight[2] = 2

    If ``rescale_mode`` is not passed, ``rescale_cfg`` would be ignored, and
    all loss terms would not be rescaled.

    For loss log collection, we support users passing a list of (or a single)
    config by the ``log_cfgs`` argument to define how they want to collect
    loss terms and show them in the log. Each log collection returns a dict
    whose keys and values are the names and values of collected loss terms.
    The dict will be merged into ``log_vars`` after the loss used for
    parameter optimization is calculated. The log updating process for the
    class which uses ddpm_loss can be referred to the following pseudo-code:

    .. code-block:: python
        :linenos:

        # 1. loss dict for parameter optimization
        losses_dict = {}

        # 2. calculate losses
        for loss_fn in self.ddpm_loss:
            losses_dict[loss_fn.loss_name()] = loss_fn(outputs_dict)

        # 3. init log_vars
        log_vars = OrderedDict()

        # 4. update log_vars with loss terms used for parameter optimization
        for loss_name, loss_value in losses_dict.items():
            log_vars[loss_name] = loss_value.mean()

        # 5. sum all loss terms used for backward
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        # 6. update log_vars with log collection functions
        for loss_fn in self.ddpm_loss:
            if hasattr(loss_fn, 'log_vars'):
                log_vars.update(loss_fn.log_vars)

    Each log config must contain the ``type`` keyword, and may contain the
    ``prefix`` and ``reduction`` keywords.

    ``type``: Used to get the corresponding collection function. Functions
        would be named as ``f'{type}_log_collect'``. In ``DDPMLoss``, we only
        support ``type=='quartile'``, but users may define their own log
        collection functions and use them in this way.

    ``prefix``: This keyword is set to avoid the names of displayed loss
        terms being too long. The name of each loss term will be set as
        ``'{prefix}_{log_coll_fn_spec_name}'``, where
        ``{log_coll_fn_spec_name}`` is a name specific to the log collection
        function. If passed, it must start with ``'loss_'``. If not passed,
        ``'loss_'`` would be used.

    ``reduction``: Controls the reduction method of the collected loss terms.

    We implement ``quartile_log_collection`` in this module. In detail, we
    divide the total timesteps into four parts and collect the loss in the
    corresponding timestep intervals.

    To use those collection methods, users may pass ``log_cfgs`` as in the
    following example:

    .. code-block:: python
        :linenos:

        log_cfgs = [
            dict(type='quartile', reduction=REDUCTION, prefix_name=PREFIX),
            ...
        ]

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
        log_cfgs (list[dict] | dict | None, optional): Configs to collect
            logs. Defaults to None.
        sampler (object): Weight sampler. Defaults to None.
        weight (torch.Tensor, optional): Weight used to rescale losses.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        loss_name (str, optional): Name of the loss item. Defaults to None.
    """
    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 log_cfgs=None,
                 weight=None,
                 sampler=None,
                 reduction='mean',
                 loss_name=None):
        super().__init__()

        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.reduction = reduction
        self._loss_name = loss_name

        self.log_fn_list = []

        log_cfgs_ = deepcopy(log_cfgs)
        if log_cfgs_ is not None:
            if not isinstance(log_cfgs_, list):
                log_cfgs_ = [log_cfgs_]
            assert mmcv.is_list_of(log_cfgs_, dict)
            for log_cfg_ in log_cfgs_:
                log_type = log_cfg_.pop('type')
                log_collect_fn = f'{log_type}_log_collect'
                assert hasattr(self, log_collect_fn)
                log_collect_fn = getattr(self, log_collect_fn)

                log_cfg_.setdefault('prefix_name', 'loss')
                assert log_cfg_['prefix_name'].startswith('loss')
                log_cfg_.setdefault('reduction', reduction)

                self.log_fn_list.append(partial(log_collect_fn, **log_cfg_))
        self.log_vars = dict()

        # handle rescale mode
        if not rescale_mode:
            self.rescale_fn = lambda loss, t: loss
        else:
            rescale_fn_name = f'{rescale_mode}_rescale'
            assert hasattr(self, rescale_fn_name)
            if rescale_mode == 'timestep_weight':
                if sampler is not None and hasattr(sampler, 'weight'):
                    weight = sampler.weight
                else:
                    assert weight is not None and isinstance(
                        weight, torch.Tensor), (
                            '\'weight\' or a \'sampler\' containing a weight '
                            'attribute must be \'torch.Tensor\' for '
                            '\'timestep_weight\' rescale_mode.')
                mmcv.print_log(
                    'Apply \'timestep_weight\' rescale_mode for '
                    f'{self._loss_name}. Please make sure the passed weight '
                    'can be updated by external functions.', 'mmgen')

                rescale_cfg = dict(weight=weight)
            self.rescale_fn = partial(
                getattr(self, rescale_fn_name), **rescale_cfg)
    @staticmethod
    def constant_rescale(loss, timesteps, scale):
        """Rescale losses at all timesteps with a constant factor.

        Args:
            loss (torch.Tensor): Losses to rescale.
            timesteps (torch.Tensor): Timesteps of each loss item.
            scale (int): Rescale factor.

        Returns:
            torch.Tensor: Rescaled losses.
        """
        return loss * scale

    @staticmethod
    def timestep_weight_rescale(loss, timesteps, weight, scale=1):
        """Rescale losses corresponding to timestep.

        Args:
            loss (torch.Tensor): Losses to rescale.
            timesteps (torch.Tensor): Timesteps of each loss item.
            weight (torch.Tensor): Weight corresponding to each timestep.
            scale (int): Rescale factor.

        Returns:
            torch.Tensor: Rescaled losses.
        """
        return loss * weight[timesteps] * scale
    @torch.no_grad()
    def collect_log(self, loss, timesteps):
        """Collect logs.

        Args:
            loss (torch.Tensor): Losses to collect.
            timesteps (torch.Tensor): Timesteps of each loss item.
        """
        if not self.log_fn_list:
            return

        if dist.is_initialized():
            ws = dist.get_world_size()
            placeholder_l = [torch.zeros_like(loss) for _ in range(ws)]
            placeholder_t = [torch.zeros_like(timesteps) for _ in range(ws)]
            dist.all_gather(placeholder_l, loss)
            dist.all_gather(placeholder_t, timesteps)
            loss = torch.cat(placeholder_l, dim=0)
            timesteps = torch.cat(placeholder_t, dim=0)

        log_vars = dict()

        if (dist.is_initialized()
                and dist.get_rank() == 0) or not dist.is_initialized():
            for log_fn in self.log_fn_list:
                log_vars.update(log_fn(loss, timesteps))
        self.log_vars = log_vars
    @torch.no_grad()
    def quartile_log_collect(self,
                             loss,
                             timesteps,
                             total_timesteps,
                             prefix_name,
                             reduction='mean'):
        """Collect loss logs by quartile timesteps.

        Args:
            loss (torch.Tensor): Loss value of each input. Each loss tensor
                should be in the shape of [bz, ].
            timesteps (torch.Tensor): Timesteps corresponding to each loss.
                Each tensor should be in the shape of [bz, ].
            total_timesteps (int): Total timesteps of the diffusion process.
            prefix_name (str): Prefix to show in the logs.
            reduction (str, optional): Specifies the reduction to apply to the
                output losses. Defaults to `mean`.

        Returns:
            dict: Collected log variables.
        """
        if digit_version(torch.__version__) <= digit_version('1.6.0'):
            # use true_divide in older torch version
            quartile = torch.true_divide(timesteps, total_timesteps) * 4
        else:
            quartile = (timesteps / total_timesteps * 4)
        quartile = quartile.type(torch.LongTensor)

        log_vars = dict()

        for idx in range(4):
            if not (quartile == idx).any():
                loss_quartile = torch.zeros((1, ))
            else:
                loss_quartile = reduce_loss(loss[quartile == idx], reduction)
            log_vars[f'{prefix_name}_quartile_{idx}'] = loss_quartile.item()

        return log_vars
    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function, ``mse_loss``.
        """
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            output_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyworded arguments, no'
                ' further non-keyworded arguments should be offered.')
            output_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parse your arguments passed to this loss module.'
                ' Please check the usage of this module')

        # check keys in output_dict
        assert 'timesteps' in output_dict, (
            '\'timesteps\' is required for DDPM-based losses, but found '
            f'{output_dict.keys()} in \'output_dict\'')

        timesteps = output_dict['timesteps']
        loss = self._forward_loss(output_dict)

        # update log_vars of this class
        self.collect_log(loss, timesteps=timesteps)

        loss_rescaled = self.rescale_fn(loss, timesteps)
        return reduce_loss(loss_rescaled, self.reduction)
    @abstractmethod
    def _forward_loss(self, output_dict):
        """Forward function for loss calculation. This method should be
        implemented by each subclass.

        Args:
            output_dict (dict): Outputs of the model used to calculate losses.

        Returns:
            torch.Tensor: Calculated loss.
        """
        raise NotImplementedError(
            '\'self._forward_loss\' must be implemented.')

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
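
# A minimal sketch (dummy tensors, not part of the original module) of the
# two rescale modes implemented by the static methods above:
# ``constant_rescale`` multiplies every loss item by ``scale``, while
# ``timestep_weight_rescale`` looks the weight up by timestep.
if __name__ == '__main__':
    _loss = torch.ones(4)
    _timesteps = torch.tensor([0, 10, 20, 30])
    _weight = torch.linspace(0.5, 1.5, steps=1000)
    print(DDPMLoss.constant_rescale(_loss, _timesteps, scale=1.2))
    print(DDPMLoss.timestep_weight_rescale(_loss, _timesteps, _weight))
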
@MODULES.register_module()
class DDPMVLBLoss(DDPMLoss):
    """Variational lower-bound loss for DDPM-based models.

    In this loss, we calculate the VLB at different timesteps with different
    methods. In detail, ``DiscretizedGaussianLogLikelihoodLoss`` is used at
    timestep 0 and ``GaussianKLDLoss`` at the other timesteps.

    To control the data flow for loss calculation, users should define
    ``data_info`` and ``data_info_t_0`` for ``GaussianKLDLoss`` and
    ``DiscretizedGaussianLogLikelihoodLoss`` respectively. If not passed,
    ``_default_data_info`` and ``_default_data_info_t_0`` would be used.

    To be noted that, we only penalize 'variance' in this loss term, and
    tensors in the output dict corresponding to 'mean' would be detached.

    Additionally, we support another log collection function called
    ``name_log_collection``. In this collection method, we would directly
    collect the loss terms calculated by the different methods.

    To use this collection method, users may pass ``log_cfgs`` as in the
    following example:

    .. code-block:: python
        :linenos:

        log_cfgs = [
            dict(type='name', reduction=REDUCTION, prefix_name=PREFIX),
            ...
        ]

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
        sampler (object): Weight sampler. Defaults to None.
        weight (torch.Tensor, optional): Weight used to rescale losses.
            Defaults to None.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary for ``timesteps != 0``.
            Defaults to None.
        data_info_t_0 (dict, optional): Dictionary containing the mapping
            between loss input args and the data dictionary for
            ``timesteps == 0``. Defaults to None.
        log_cfgs (list[dict] | dict | None, optional): Configs to collect
            logs. Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        loss_name (str, optional): Name of the loss item. Defaults to
            'loss_ddpm_vlb'.
    """
    _default_data_info = dict(
        mean_pred='mean_pred',
        mean_target='mean_target',
        logvar_pred='logvar_pred',
        logvar_target='logvar_target')
    _default_data_info_t_0 = dict(
        x='real_imgs', mean='mean_pred', logvar='logvar_pred')
    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 sampler=None,
                 weight=None,
                 data_info=None,
                 data_info_t_0=None,
                 log_cfgs=None,
                 reduction='mean',
                 loss_name='loss_ddpm_vlb'):
        super().__init__(rescale_mode, rescale_cfg, log_cfgs, weight, sampler,
                         reduction, loss_name)

        self.data_info = self._default_data_info \
            if data_info is None else data_info
        self.data_info_t_0 = self._default_data_info_t_0 \
            if data_info_t_0 is None else data_info_t_0

        self.loss_list = [
            DiscretizedGaussianLogLikelihoodLoss(
                reduction='flatmean',
                data_info=self.data_info_t_0,
                base='2',
                loss_weight=-1,
                only_update_var=True),
            GaussianKLDLoss(
                reduction='flatmean',
                data_info=self.data_info,
                base='2',
                only_update_var=True)
        ]
        self.loss_select_fn_list = [lambda t: t == 0, lambda t: t != 0]
    @torch.no_grad()
    def name_log_collect(self, loss, timesteps, prefix_name, reduction='mean'):
        """Collect loss logs by name (GaussianKLD and
        DiscGaussianLogLikelihood).

        Args:
            loss (torch.Tensor): Loss value of each input. Each loss tensor
                should be in the shape of [bz, ].
            timesteps (torch.Tensor): Timesteps corresponding to each loss.
                Each tensor should be in the shape of [bz, ].
            prefix_name (str): Prefix to show in the logs.
            reduction (str, optional): Specifies the reduction to apply to the
                output losses. Defaults to `mean`.

        Returns:
            dict: Collected log variables.
        """
        log_vars = dict()
        for select_fn, loss_fn in zip(self.loss_select_fn_list,
                                      self.loss_list):
            mask = select_fn(timesteps)
            if not mask.any():
                loss_reduced = torch.zeros((1, ))
            else:
                loss_reduced = reduce_loss(loss[mask], reduction)
            # remove the original prefix in loss names
            loss_term_name = loss_fn.loss_name().replace('loss_', '')
            log_vars[f'{prefix_name}_{loss_term_name}'] = loss_reduced.item()

        return log_vars
    def _forward_loss(self, outputs_dict):
        """Forward function for loss calculation.

        Args:
            outputs_dict (dict): Outputs of the model used to calculate losses.

        Returns:
            torch.Tensor: Calculated loss.
        """
        # use `zeros_like` with `.float()` to avoid getting an int tensor
        timesteps = outputs_dict['timesteps']
        loss = torch.zeros_like(timesteps).float()
        # loss = torch.zeros(*timesteps.shape).to(timesteps.device)
        for select_fn, loss_fn in zip(self.loss_select_fn_list,
                                      self.loss_list):
            mask = select_fn(timesteps)
            outputs_dict_ = {}
            for k, v in outputs_dict.items():
                if v is None or not isinstance(v, (torch.Tensor, list)):
                    outputs_dict_[k] = v
                elif isinstance(v, list):
                    outputs_dict_[k] = [
                        v[idx] for idx, m in enumerate(mask) if m
                    ]
                else:
                    outputs_dict_[k] = v[mask]
            loss[mask] = loss_fn(outputs_dict_)
        return loss
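
# A small sketch (dummy timesteps, not part of the original module) of how
# ``loss_select_fn_list`` above splits a batch: timestep 0 is routed to the
# discretized Gaussian log-likelihood term, every other timestep to the
# Gaussian KLD term.
if __name__ == '__main__':
    _t = torch.tensor([0, 3, 0, 7])
    print(_t == 0)  # mask for DiscretizedGaussianLogLikelihoodLoss
    print(_t != 0)  # mask for GaussianKLDLoss
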
@MODULES.register_module()
class DDPMMSELoss(DDPMLoss):
    """Mean square loss for DDPM-based models.

    Args:
        rescale_mode (str, optional): Mode of the loss rescale method.
            Defaults to None.
        rescale_cfg (dict, optional): Config of the loss rescale method.
        sampler (object): Weight sampler. Defaults to None.
        weight (torch.Tensor, optional): Weight used to rescale losses.
            Defaults to None.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. Defaults to None.
        log_cfgs (list[dict] | dict | None, optional): Configs to collect
            logs. Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        loss_name (str, optional): Name of the loss item. Defaults to
            'loss_ddpm_mse'.
    """
    _default_data_info = dict(pred='eps_t_pred', target='noise')
    def __init__(self,
                 rescale_mode=None,
                 rescale_cfg=None,
                 sampler=None,
                 weight=None,
                 log_cfgs=None,
                 reduction='mean',
                 data_info=None,
                 loss_name='loss_ddpm_mse'):
        super().__init__(rescale_mode, rescale_cfg, log_cfgs, weight, sampler,
                         reduction, loss_name)

        self.data_info = self._default_data_info \
            if data_info is None else data_info

        self.loss_fn = partial(mse_loss, reduction='flatmean')

    def _forward_loss(self, outputs_dict):
        """Forward function for loss calculation.

        Args:
            outputs_dict (dict): Outputs of the model used to calculate losses.

        Returns:
            torch.Tensor: Calculated loss.
        """
        loss_input_dict = {
            k: outputs_dict[v]
            for k, v in self.data_info.items()
        }
        loss = self.loss_fn(**loss_input_dict)
        return loss
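
# Usage sketch (dummy data, not part of the original module): the default
# ``data_info`` of ``DDPMMSELoss`` maps the loss inputs ``pred``/``target``
# to the ``eps_t_pred``/``noise`` entries of the outputs dict; ``timesteps``
# is always required by the base class.
if __name__ == '__main__':
    _loss_fn = DDPMMSELoss()
    _outputs = dict(
        eps_t_pred=torch.randn(4, 3, 8, 8),
        noise=torch.randn(4, 3, 8, 8),
        timesteps=torch.randint(0, 1000, (4, )))
    print(_loss_fn(_outputs))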
build/lib/mmgen/models/losses/disc_auxiliary_loss.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.autograd as autograd
import torch.nn as nn

from mmgen.models.builder import MODULES
from .utils import weighted_loss


@weighted_loss
def disc_shift_loss(pred):
    """Disc shift loss.

    This loss is proposed in PGGAN as an auxiliary loss for the discriminator.

    Args:
        pred (Tensor): Input tensor.

    Returns:
        torch.Tensor: Loss tensor.
    """
    return pred**2
@MODULES.register_module()
class DiscShiftLoss(nn.Module):
    """Disc Shift Loss.

    This loss is proposed in PGGAN as an auxiliary loss for the discriminator.

    **Note for the design of ``data_info``:**

    In ``MMGeneration``, almost all loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of a GAN model, we will collect all of
    the important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we will need to provide ``pred`` as input. Thus, an
    example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='disc_pred_fake')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    In addition, in general, ``disc_shift_loss`` will be applied over real
    and fake data. In this case, users just need to add this loss module
    twice, but with different ``data_info``. Our model will automatically add
    these two items.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_disc_shift'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 data_info=None,
                 loss_name='loss_disc_shift'):
        super().__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self._loss_name = loss_name
    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function,
        ``disc_shift_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parse your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(dict(weight=self.loss_weight))
            return disc_shift_loss(**kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return disc_shift_loss(*args, weight=self.loss_weight, **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
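
# Usage sketch (dummy predictions, not part of the original module): with
# ``data_info`` set, the module pulls ``pred`` out of the shared outputs dict
# shown in the docstring above; without it, tensors are passed straight to
# ``disc_shift_loss``.
if __name__ == '__main__':
    _loss_fn = DiscShiftLoss(
        loss_weight=0.001, data_info=dict(pred='disc_pred_real'))
    print(_loss_fn(dict(disc_pred_real=torch.randn(4, 1))))
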
@weighted_loss
def gradient_penalty_loss(discriminator,
                          real_data,
                          fake_data,
                          mask=None,
                          norm_mode='pixel'):
    """Calculate gradient penalty for wgan-gp.

    In the detailed implementation, there are two streams: one uses the
    pixel-wise gradient norm, while the other adopts the normalization along
    instance (HWC) dimensions. Thus, ``norm_mode`` is offered to define which
    mode you want.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        mask (Tensor): Masks for inpainting. Default: None.
        norm_mode (str): This argument decides along which dimension the norm
            of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)

    alpha = torch.rand(batch_size, 1, 1, 1).to(real_data)

    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    if mask is not None:
        gradients = gradients * mask

    if norm_mode == 'pixel':
        gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    elif norm_mode == 'HWC':
        gradients_penalty = ((gradients.reshape(batch_size, -1).norm(
            2, dim=1) - 1)**2).mean()
    else:
        raise NotImplementedError(
            'Currently, we only support ["pixel", "HWC"] '
            f'norm mode but got {norm_mode}.')
    if mask is not None:
        gradients_penalty /= torch.mean(mask)

    return gradients_penalty
@MODULES.register_module()
class GradientPenaltyLoss(nn.Module):
    """Gradient Penalty for WGAN-GP.

    In the detailed implementation, there are two streams: one uses the
    pixel-wise gradient norm, while the other adopts the normalization along
    instance (HWC) dimensions. Thus, ``norm_mode`` is offered to define which
    mode you want.

    **Note for the design of ``data_info``:**

    In ``MMGeneration``, almost all loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of a GAN model, we will collect all of
    the important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we will need to provide ``discriminator``, ``real_data``,
    and ``fake_data`` as input. Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            discriminator='disc',
            real_data='real_imgs',
            fake_data='fake_imgs')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        norm_mode (str): This argument decides along which dimension the norm
            of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_gp'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 norm_mode='pixel',
                 data_info=None,
                 loss_name='loss_gp'):
        super().__init__()
        self.loss_weight = loss_weight
        self.norm_mode = norm_mode
        self.data_info = data_info
        self._loss_name = loss_name
    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function,
        ``gradient_penalty_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parse your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(
                dict(weight=self.loss_weight, norm_mode=self.norm_mode))
            return gradient_penalty_loss(**kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return gradient_penalty_loss(
                *args, weight=self.loss_weight, **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
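
# A minimal sketch (toy discriminator, not part of the original module) of
# the functional interface wrapped by the module above; the single conv layer
# is a stand-in for a real discriminator network.
if __name__ == '__main__':
    _disc = nn.Conv2d(3, 1, 3, padding=1)
    _real = torch.randn(2, 3, 8, 8)
    _fake = torch.randn(2, 3, 8, 8)
    print(gradient_penalty_loss(_disc, _real, _fake, norm_mode='pixel'))
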
@weighted_loss
def r1_gradient_penalty_loss(discriminator,
                             real_data,
                             mask=None,
                             norm_mode='pixel',
                             loss_scaler=None,
                             use_apex_amp=False):
    """Calculate R1 gradient penalty for WGAN-GP.

    R1 regularizer comes from:
    "Which Training Methods for GANs do actually Converge?" ICML'2018

    Different from the original gradient penalty, this regularizer only
    penalizes the gradient w.r.t. real data.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        mask (Tensor): Masks for inpainting. Default: None.
        norm_mode (str): This argument decides along which dimension the norm
            of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.shape[0]

    real_data = real_data.clone().requires_grad_()

    disc_pred = discriminator(real_data)
    if loss_scaler:
        disc_pred = loss_scaler.scale(disc_pred)
    elif use_apex_amp:
        from apex.amp._amp_state import _amp_state
        _loss_scaler = _amp_state.loss_scalers[0]
        disc_pred = _loss_scaler.loss_scale() * disc_pred.float()

    gradients = autograd.grad(
        outputs=disc_pred,
        inputs=real_data,
        grad_outputs=torch.ones_like(disc_pred),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    if loss_scaler:
        # unscale the gradient
        inv_scale = 1. / loss_scaler.get_scale()
        gradients = gradients * inv_scale
    elif use_apex_amp:
        inv_scale = 1. / _loss_scaler.loss_scale()
        gradients = gradients * inv_scale

    if mask is not None:
        gradients = gradients * mask

    if norm_mode == 'pixel':
        gradients_penalty = ((gradients.norm(2, dim=1))**2).mean()
    elif norm_mode == 'HWC':
        gradients_penalty = gradients.pow(2).reshape(batch_size,
                                                     -1).sum(1).mean()
    else:
        raise NotImplementedError(
            'Currently, we only support ["pixel", "HWC"] '
            f'norm mode but got {norm_mode}.')
    if mask is not None:
        gradients_penalty /= torch.mean(mask)

    return gradients_penalty
@MODULES.register_module()
class R1GradientPenalty(nn.Module):
    """R1 gradient penalty for WGAN-GP.

    R1 regularizer comes from:
    "Which Training Methods for GANs do actually Converge?" ICML'2018

    Different from the original gradient penalty, this regularizer only
    penalizes the gradient w.r.t. real data.

    **Note for the design of ``data_info``:**

    In ``MMGeneration``, almost all loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of a GAN model, we will collect all of
    the important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we will need to provide ``discriminator`` and
    ``real_data`` as input. Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            discriminator='disc',
            real_data='real_imgs')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        norm_mode (str): This argument decides along which dimension the norm
            of the gradients will be calculated. Currently, we support
            ["pixel", "HWC"]. Defaults to "pixel".
        interval (int, optional): The interval of calculating this loss.
            Defaults to 1.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_r1_gp'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 norm_mode='pixel',
                 interval=1,
                 data_info=None,
                 use_apex_amp=False,
                 loss_name='loss_r1_gp'):
        super().__init__()
        self.loss_weight = loss_weight
        self.norm_mode = norm_mode
        self.interval = interval
        self.data_info = data_info
        self.use_apex_amp = use_apex_amp
        self._loss_name = loss_name
    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword argument,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function,
        ``r1_gradient_penalty_loss``.
        """
        if self.interval > 1:
            assert self.data_info is not None
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyworded arguments, no'
                    ' further non-keyworded arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parse your arguments passed to this loss module.'
                    ' Please check the usage of this module')
            if (self.interval > 1
                    and outputs_dict['iteration'] % self.interval != 0):
                return None
            # link the outputs with loss input args according to self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(
                dict(
                    weight=self.loss_weight,
                    norm_mode=self.norm_mode,
                    use_apex_amp=self.use_apex_amp))
            return r1_gradient_penalty_loss(**kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return r1_gradient_penalty_loss(
                *args,
                weight=self.loss_weight,
                norm_mode=self.norm_mode,
                **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
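
# A minimal sketch (toy discriminator, not part of the original module) of
# the R1 penalty: only real data is needed, since the gradient is penalized
# w.r.t. real samples only.
if __name__ == '__main__':
    _disc = nn.Conv2d(3, 1, 3, padding=1)
    _real = torch.randn(2, 3, 8, 8)
    print(r1_gradient_penalty_loss(_disc, _real, norm_mode='HWC'))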
build/lib/mmgen/models/losses/gan_loss.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from ..builder import MODULES


@MODULES.register_module()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge',
            'wgan-logistic-ns'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; it is always 1.0
            for discriminators.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan-logistic-ns':
            self.loss = self._wgan_logistic_ns_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def _wgan_logistic_ns_loss(self, input, target):
        """WGAN loss in logistically non-saturating mode.

        This loss is widely used in StyleGANv2.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return F.softplus(-input).mean() if target else F.softplus(
            input).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, \
                otherwise, return Tensor.
        """
        if self.gan_type in ['wgan', 'wgan-logistic-ns']:
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is computed for discriminators
                or not. Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:
                # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:
                # for generators in hinge-gan
                loss = -input.mean()
        else:
            # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
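
# Usage sketch (dummy logits, not part of the original module): the same
# module computes both the discriminator and the generator terms, switched
# by ``is_disc``.
if __name__ == '__main__':
    import torch

    _gan_loss = GANLoss('vanilla')
    _pred_fake = torch.randn(4, 1)
    print(_gan_loss(_pred_fake, target_is_real=False, is_disc=True))
    print(_gan_loss(_pred_fake, target_is_real=True, is_disc=False))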
build/lib/mmgen/models/losses/gen_auxiliary_loss.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.autograd as autograd
import torch.distributed as dist
import torch.nn as nn
import torchvision.models.vgg as vgg
from mmcv.runner import load_checkpoint

from mmgen.models.builder import MODULES, build_module
from mmgen.utils import get_root_logger
from .pixelwise_loss import l1_loss, mse_loss
def gen_path_regularizer(generator,
                         num_batches,
                         mean_path_length,
                         pl_batch_shrink=1,
                         decay=0.01,
                         weight=1.,
                         pl_batch_size=None,
                         sync_mean_buffer=False,
                         loss_scaler=None,
                         use_apex_amp=False):
    """Generator Path Regularization.

    Path regularization is proposed in StyleGAN2, which can help improve the
    continuity of the latent space. More details can be found in:
    Analyzing and Improving the Image Quality of StyleGAN, CVPR2020.

    Args:
        generator (nn.Module): The generator module. Note that this loss
            requires that the generator contains a ``return_latents``
            interface, with which we can get the latent code of the current
            sample.
        num_batches (int): The number of samples used in calculating this
            loss.
        mean_path_length (Tensor): The mean path length, calculated by moving
            average.
        pl_batch_shrink (int, optional): The factor of shrinking the batch
            size for saving GPU memory. Defaults to 1.
        decay (float, optional): Decay for moving average of mean path length.
            Defaults to 0.01.
        weight (float, optional): Weight of this loss item. Defaults to ``1.``.
        pl_batch_size (int | None, optional): The batch size in calculating
            generator path. Once this argument is set, ``num_batches`` will
            be overridden with this argument and won't be affected by
            ``pl_batch_shrink``. Defaults to None.
        sync_mean_buffer (bool, optional): Whether to sync mean path length
            across all GPUs. Defaults to False.

    Returns:
        tuple[Tensor]: The penalty loss, detached mean path tensor, and \
            current path length.
    """
# reduce batch size for conserving GPU memory
if
pl_batch_shrink
>
1
:
num_batches
=
max
(
1
,
num_batches
//
pl_batch_shrink
)
# reset the batch size if pl_batch_size is not None
if
pl_batch_size
is
not
None
:
num_batches
=
pl_batch_size
# get output from different generators
output_dict
=
generator
(
None
,
num_batches
=
num_batches
,
return_latents
=
True
)
fake_img
,
latents
=
output_dict
[
'fake_img'
],
output_dict
[
'latent'
]
noise
=
torch
.
randn_like
(
fake_img
)
/
np
.
sqrt
(
fake_img
.
shape
[
2
]
*
fake_img
.
shape
[
3
])
if
loss_scaler
:
loss
=
loss_scaler
.
scale
((
fake_img
*
noise
).
sum
())[
0
]
grad
=
autograd
.
grad
(
outputs
=
loss
,
inputs
=
latents
,
grad_outputs
=
torch
.
ones
(()).
to
(
loss
),
create_graph
=
True
,
retain_graph
=
True
,
only_inputs
=
True
)[
0
]
# unsacle the grad
inv_scale
=
1.
/
loss_scaler
.
get_scale
()
grad
=
grad
*
inv_scale
elif
use_apex_amp
:
from
apex.amp._amp_state
import
_amp_state
# by default, we use loss_scalers[0] for discriminator and
# loss_scalers[1] for generator
_loss_scaler
=
_amp_state
.
loss_scalers
[
1
]
loss
=
_loss_scaler
.
loss_scale
()
*
((
fake_img
*
noise
).
sum
()).
float
()
grad
=
autograd
.
grad
(
outputs
=
loss
,
inputs
=
latents
,
grad_outputs
=
torch
.
ones
(()).
to
(
loss
),
create_graph
=
True
,
retain_graph
=
True
,
only_inputs
=
True
)[
0
]
# unsacle the grad
inv_scale
=
1.
/
_loss_scaler
.
loss_scale
()
grad
=
grad
*
inv_scale
else
:
grad
=
autograd
.
grad
(
outputs
=
(
fake_img
*
noise
).
sum
(),
inputs
=
latents
,
grad_outputs
=
torch
.
ones
(()).
to
(
fake_img
),
create_graph
=
True
,
retain_graph
=
True
,
only_inputs
=
True
)[
0
]
path_lengths
=
torch
.
sqrt
(
grad
.
pow
(
2
).
sum
(
2
).
mean
(
1
))
# update mean path
path_mean
=
mean_path_length
+
decay
*
(
path_lengths
.
mean
()
-
mean_path_length
)
if
sync_mean_buffer
and
dist
.
is_initialized
():
dist
.
all_reduce
(
path_mean
)
path_mean
=
path_mean
/
float
(
dist
.
get_world_size
())
path_penalty
=
(
path_lengths
-
path_mean
).
pow
(
2
).
mean
()
*
weight
return
path_penalty
,
path_mean
.
detach
(),
path_lengths
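As a hedged illustration of the interface assumed above (the generator must accept ``generator(None, num_batches=N, return_latents=True)`` and return a dict with ``'fake_img'`` and ``'latent'``; ``MyStyleGAN2Generator`` is a hypothetical stand-in):

import torch

generator = MyStyleGAN2Generator()  # hypothetical; see interface note above
mean_path_length = torch.tensor(0.)

path_penalty, mean_path_length, path_lengths = gen_path_regularizer(
    generator,
    num_batches=8,
    mean_path_length=mean_path_length,
    pl_batch_shrink=2)
path_penalty.backward()  # feed the updated moving average into the next call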
@MODULES.register_module()
class GeneratorPathRegularizer(nn.Module):
    """Generator Path Regularizer.

    Path regularization is proposed in StyleGAN2, which can help improve the
    continuity of the latent space. More details can be found in:
    Analyzing and Improving the Image Quality of StyleGAN, CVPR2020.

    Users can achieve lazy regularization by setting the ``interval``
    argument here.

    **Note for the design of ``data_info``:**

    In ``MMGeneration``, almost all of the loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of a GAN model, we will collect all
    of the important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            fake_imgs=fake_imgs,
            disc_pred_fake_g=disc_pred_fake_g,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we will need to provide ``generator`` and
    ``num_batches`` as input. Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            generator='gen',
            num_batches='batch_size')

    Then, the module will automatically construct this mapping from the input
    data dictionary. A usage sketch follows after the class definition.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        pl_batch_shrink (int, optional): The factor of shrinking the batch
            size for saving GPU memory. Defaults to 1.
        decay (float, optional): Decay for moving average of mean path
            length. Defaults to 0.01.
        pl_batch_size (int | None, optional): The batch size in calculating
            generator path. Once this argument is set, the ``num_batches``
            will be overridden with this argument and won't be affected by
            ``pl_batch_shrink``. Defaults to None.
        sync_mean_buffer (bool, optional): Whether to sync mean path length
            across all GPUs. Defaults to False.
        interval (int, optional): The interval of calculating this loss. This
            argument is used to support lazy regularization. Defaults to 1.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        use_apex_amp (bool, optional): Whether apex AMP is adopted in
            training. Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_path_regular'.
    """

    def __init__(self,
                 loss_weight=1.,
                 pl_batch_shrink=1,
                 decay=0.01,
                 pl_batch_size=None,
                 sync_mean_buffer=False,
                 interval=1,
                 data_info=None,
                 use_apex_amp=False,
                 loss_name='loss_path_regular'):
        super().__init__()
        self.loss_weight = loss_weight
        self.pl_batch_shrink = pl_batch_shrink
        self.decay = decay
        self.pl_batch_size = pl_batch_size
        self.sync_mean_buffer = sync_mean_buffer
        self.interval = interval
        self.data_info = data_info
        self.use_apex_amp = use_apex_amp
        self._loss_name = loss_name

        self.register_buffer('mean_path_length', torch.tensor(0.))

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword arguments,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function,
        ``gen_path_regularizer``.
        """
        if self.interval > 1:
            assert self.data_info is not None
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyword arguments, no '
                    'further non-keyword arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parse your arguments passed to this loss module. '
                    'Please check the usage of this module.')
            if self.interval > 1 and outputs_dict[
                    'iteration'] % self.interval != 0:
                return None
            # link the outputs with loss input args according to
            # self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(
                dict(
                    weight=self.loss_weight,
                    mean_path_length=self.mean_path_length,
                    pl_batch_shrink=self.pl_batch_shrink,
                    decay=self.decay,
                    use_apex_amp=self.use_apex_amp,
                    pl_batch_size=self.pl_batch_size,
                    sync_mean_buffer=self.sync_mean_buffer))
            path_penalty, self.mean_path_length, _ = gen_path_regularizer(
                **kwargs)

            return path_penalty
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return gen_path_regularizer(
                *args, weight=self.loss_weight, **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
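A hedged sketch of the ``data_info`` forwarding pattern this class (and the loss modules below) relies on; ``generator`` and ``curr_iter`` are hypothetical training-loop variables:

pl_loss = GeneratorPathRegularizer(
    interval=4,  # lazy regularization: compute only every 4th iteration
    data_info=dict(generator='gen', num_batches='batch_size'))

outputs_dict = dict(gen=generator, batch_size=8, iteration=curr_iter)
loss = pl_loss(outputs_dict)  # returns None on skipped iterations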
def third_party_net_loss(net, weight=1.0, **kwargs):
    """Wrap a third-party network as a loss function and apply the weight."""
    return net(**kwargs) * weight
@MODULES.register_module()
class FaceIdLoss(nn.Module):
    """Face similarity loss. Generally this loss is used to keep the id
    consistency of the input face image and the output face image.

    In this loss, we may need to provide ``gt``, ``pred`` and ``x``. Thus, an
    example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            gt='real_imgs',
            pred='fake_imgs')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        facenet (dict, optional): Config dict for facenet. Defaults to
            dict(type='ArcFace', ir_se50_weights=None, device='cuda').
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_id'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 data_info=None,
                 facenet=dict(
                     type='ArcFace', ir_se50_weights=None, device='cuda'),
                 loss_name='loss_id'):
        super(FaceIdLoss, self).__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self.net = build_module(facenet)
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword arguments,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function,
        ``third_party_net_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyword arguments, no '
                    'further non-keyword arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parse your arguments passed to this loss module. '
                    'Please check the usage of this module.')
            # link the outputs with loss input args according to
            # self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(dict(weight=self.loss_weight))
            return third_party_net_loss(self.net, *args, **kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return third_party_net_loss(
                self.net, *args, weight=self.loss_weight, **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
class CLIPLossModel(torch.nn.Module):
    """Wrapped clip model to calculate clip loss.

    Ref: https://github.com/orpatashnik/StyleCLIP/blob/main/criteria/clip_loss.py # noqa

    Args:
        in_size (int, optional): Input image size. Defaults to 1024.
        scale_factor (int, optional): Upsampling factor. Defaults to 7.
        pool_size (int, optional): Pooling output size. Defaults to 224.
        clip_type (str, optional): A model name listed by
            `clip.available_models()`, or the path to a model checkpoint
            containing the state_dict. For more details, you can refer to
            https://github.com/openai/CLIP/blob/573315e83f07b53a61ff5098757e8fc885f1703e/clip/clip.py#L91 # noqa
            Defaults to 'ViT-B/32'.
        device (str, optional): Model device. Defaults to 'cuda'.
    """

    def __init__(self,
                 in_size=1024,
                 scale_factor=7,
                 pool_size=224,
                 clip_type='ViT-B/32',
                 device='cuda'):
        super(CLIPLossModel, self).__init__()
        try:
            import clip
        except ImportError:
            raise ImportError(
                'To use clip loss, openai clip needs to be installed first.')

        self.model, self.preprocess = clip.load(clip_type, device=device)
        self.upsample = torch.nn.Upsample(scale_factor=scale_factor)
        self.avg_pool = torch.nn.AvgPool2d(
            kernel_size=(scale_factor * in_size // pool_size))

    def forward(self, image=None, text=None):
        """Forward function."""
        assert image is not None
        assert text is not None
        image = self.avg_pool(self.upsample(image))
        loss = 1 - self.model(image, text)[0] / 100
        return loss
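A minimal sketch of driving the wrapper above directly; it assumes the openai ``clip`` package is installed and a CUDA device is available:

import clip
import torch

loss_model = CLIPLossModel(in_size=1024, clip_type='ViT-B/32', device='cuda')
image = torch.randn(1, 3, 1024, 1024, device='cuda')  # must match in_size
text = clip.tokenize(['a smiling face']).to('cuda')
loss = loss_model(image=image, text=text)  # smaller when image matches text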
@MODULES.register_module()
class CLIPLoss(nn.Module):
    """Clip loss. In styleclip, this loss is used to optimize the latent code
    to generate images that match the text.

    In this loss, we may need to provide ``image`` and ``text``. Thus, an
    example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            image='fake_imgs',
            text='descriptions')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        clip_model (dict, optional): Kwargs for the clip loss model. Defaults
            to dict().
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_clip'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 data_info=None,
                 clip_model=dict(),
                 loss_name='loss_clip'):
        super(CLIPLoss, self).__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self.net = CLIPLossModel(**clip_model)
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword arguments,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function,
        ``third_party_net_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyword arguments, no '
                    'further non-keyword arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parse your arguments passed to this loss module. '
                    'Please check the usage of this module.')
            # link the outputs with loss input args according to
            # self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(dict(weight=self.loss_weight))
            return third_party_net_loss(self.net, *args, **kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return third_party_net_loss(
                self.net, *args, weight=self.loss_weight, **kwargs)

    @staticmethod
    def loss_name():
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return 'clip_loss'
class PerceptualVGG(nn.Module):
    """VGG network used in calculating perceptual loss.

    In this implementation, we allow users to choose whether to use
    normalization on the input feature and the type of vgg network. Note
    that the pretrained path must fit the vgg type.

    Args:
        layer_name_list (list[str]): According to the name in this list,
            forward function will return the corresponding features. This
            list contains the name of each layer in `vgg.feature`. An example
            of this list is ['4', '10'].
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image.
            Importantly, the input feature must be in the range [0, 1].
            Default: True.
        pretrained (str): Path for pretrained weights. Default:
            'torchvision://vgg19'
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 pretrained='torchvision://vgg19'):
        super().__init__()
        if pretrained.startswith('torchvision://'):
            assert vgg_type in pretrained
        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm

        # get vgg model and load pretrained vgg weight
        # remove _vgg from attributes to avoid `find_unused_parameters` bug
        _vgg = getattr(vgg, vgg_type)()
        self.init_weights(_vgg, pretrained)
        num_layers = max(map(int, layer_name_list)) + 1
        assert len(_vgg.features) >= num_layers
        # only borrow layers that will be used from _vgg to avoid unused
        # params
        self.vgg_layers = _vgg.features[:num_layers]

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer(
                'mean',
                torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer(
                'std',
                torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        for v in self.vgg_layers.parameters():
            v.requires_grad = False

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = {}

        for name, module in self.vgg_layers.named_children():
            x = module(x)
            if name in self.layer_name_list:
                output[name] = x.clone()
        return output

    def init_weights(self, model, pretrained):
        """Init weights.

        Args:
            model (nn.Module): Models to be inited.
            pretrained (str): Path for pretrained weights.
        """
        logger = get_root_logger()
        load_checkpoint(model, pretrained, logger=logger)
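For example, extracting three relu-stage feature maps from a VGG19 trunk (a sketch; the indices refer to positions in torchvision's ``vgg19().features``, and the pretrained weights are downloaded on first use):

import torch

extractor = PerceptualVGG(
    layer_name_list=['4', '9', '18'],
    vgg_type='vgg19',
    pretrained='torchvision://vgg19')
img = torch.rand(1, 3, 224, 224)  # [0, 1] range, as use_input_norm expects
feats = extractor(img)            # dict with keys '4', '9' and '18'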
@MODULES.register_module()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    **Note for the design of ``data_info``:**

    In ``MMGeneration``, almost all of the loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of a GAN model, we will collect all
    of the important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we may need to provide ``pred`` and ``target`` as
    input. Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='fake_imgs',
            target='real_imgs',
            layer_weights={
                '4': 1.,
                '9': 1.,
                '18': 1.},
        )

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_perceptual'.
        layer_weights (dict): The weight for each layer of vgg feature for
            perceptual loss. Here is an example: {'4': 1., '9': 1., '18': 1.},
            which means the 5th, 10th and 19th feature layers will be
            extracted with weight 1.0 in calculating losses. Defaults to
            '{'4': 1., '9': 1., '18': 1.}'.
        layer_weights_style (dict): The weight for each layer of vgg feature
            for style loss. If set to 'None', the weights are set equal to
            the weights for perceptual loss. Default: None.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will be multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will be multiplied by the weight.
            Default: 1.0.
        norm_img (bool): If True, the image will be normed to [0, 1]. Note
            that this is different from `use_input_norm`, which norms the
            input in the forward function of vgg according to the statistics
            of the dataset. Importantly, the input image must be in the range
            [-1, 1].
        pretrained (str): Path for pretrained weights. Default:
            'torchvision://vgg19'.
        criterion (str): Criterion type. Options are 'l1' and 'mse'.
            Default: 'l1'.
        split_style_loss (bool): Whether to return a separate style loss
            item. Options are True and False. Default: False
    """

    def __init__(self,
                 data_info=None,
                 loss_name='loss_perceptual',
                 layer_weights={
                     '4': 1.,
                     '9': 1.,
                     '18': 1.
                 },
                 layer_weights_style=None,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 perceptual_weight=1.0,
                 style_weight=1.0,
                 norm_img=True,
                 pretrained='torchvision://vgg19',
                 criterion='l1',
                 split_style_loss=False):
        super().__init__()
        self.data_info = data_info
        self._loss_name = loss_name
        self.norm_img = norm_img
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.layer_weights_style = layer_weights_style
        self.split_style_loss = split_style_loss

        self.vgg = PerceptualVGG(
            layer_name_list=list(self.layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            pretrained=pretrained)

        if self.layer_weights_style is not None and \
                self.layer_weights_style != self.layer_weights:
            self.vgg_style = PerceptualVGG(
                layer_name_list=list(self.layer_weights_style.keys()),
                vgg_type=vgg_type,
                use_input_norm=use_input_norm,
                pretrained=pretrained)
        else:
            self.layer_weights_style = self.layer_weights
            self.vgg_style = None

        criterion = criterion.lower()
        if criterion == 'l1':
            self.criterion = l1_loss
        elif criterion == 'mse':
            self.criterion = mse_loss
        else:
            raise NotImplementedError(
                f'{criterion} criterion has not been supported in'
                ' this version.')

    def forward(self, *args, **kwargs):
        """Forward function. If ``self.data_info`` is not ``None``, a
        dictionary containing all of the data and necessary modules should be
        passed into this function. If this dictionary is given as a
        non-keyword argument, it should be offered as the first argument. If
        you are using keyword arguments, please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function,
        ``perceptual_loss``.

        Args:
            pred (Tensor): Input tensor with shape (n, c, h, w).
            target (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyword arguments, no '
                    'further non-keyword arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parse your arguments passed to this loss module. '
                    'Please check the usage of this module.')
            # link the outputs with loss input args according to
            # self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            return self.perceptual_loss(**kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return self.perceptual_loss(*args, **kwargs)

    def perceptual_loss(self, pred, target):
        if self.norm_img:
            pred = (pred + 1.) * 0.5
            target = (target + 1.) * 0.5
        # extract vgg features
        pred_features = self.vgg(pred)
        target_features = self.vgg(target.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in pred_features.keys():
                percep_loss += self.criterion(
                    pred_features[k],
                    target_features[k],
                    weight=self.layer_weights[k])
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = 0.

        # calculate style loss
        if self.style_weight > 0:
            if self.vgg_style is not None:
                pred_features = self.vgg_style(pred)
                target_features = self.vgg_style(target.detach())

            style_loss = 0
            for k in pred_features.keys():
                style_loss += self.criterion(
                    self._gram_mat(pred_features[k]),
                    self._gram_mat(
                        target_features[k])) * self.layer_weights_style[k]
            style_loss *= self.style_weight
        else:
            style_loss = 0.

        if self.split_style_loss:
            return percep_loss, style_loss
        else:
            return percep_loss + style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        (n, c, h, w) = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
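A hedged usage sketch of the class above (inputs in [-1, 1] because ``norm_img=True`` by default; VGG19 weights are fetched on first use):

import torch

percep = PerceptualLoss(
    layer_weights={'4': 1., '9': 1., '18': 1.},
    perceptual_weight=1.0,
    style_weight=0.)  # disable the style term for this example
fake = torch.rand(1, 3, 256, 256) * 2 - 1
real = torch.rand(1, 3, 256, 256) * 2 - 1
loss = percep(fake, real)  # data_info is None, so args pass straight through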
build/lib/mmgen/models/losses/pixelwise_loss.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmgen.models.builder import MODULES
from .utils import weighted_loss

_reduction_modes = ['none', 'mean', 'sum', 'batchmean', 'flatmean']
@weighted_loss
def l1_loss(pred, target):
    """L1 loss.

    Args:
        pred (Tensor): Prediction Tensor with shape (n, c, h, w).
        target (Tensor): Target Tensor with shape (n, c, h, w).

    Returns:
        Tensor: Calculated L1 loss.
    """
    return F.l1_loss(pred, target, reduction='none')


@weighted_loss
def mse_loss(pred, target):
    """MSE loss.

    Args:
        pred (Tensor): Prediction Tensor with shape (n, c, h, w).
        target (Tensor): Target Tensor with shape (n, c, h, w).

    Returns:
        Tensor: Calculated MSE loss.
    """
    return F.mse_loss(pred, target, reduction='none')
@weighted_loss
def gaussian_kld(mean_target, mean_pred, logvar_target, logvar_pred,
                 base='e'):
    r"""Calculate KLD (Kullback-Leibler divergence) of two gaussian
    distributions.

    To be noted that in this function, KLD is calculated in base `e`.

    .. math::
        :nowrap:

        \begin{align}
        KLD(p||q) &= -\int{p(x)\log{q(x)} dx} + \int{p(x)\log{p(x)} dx} \\
            &= \frac{1}{2}\log{(2\pi \sigma_2^2)} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} -
            \frac{1}{2}(1 + \log{2\pi \sigma_1^2}) \\
            &= \log{\frac{\sigma_2}{\sigma_1}} +
            \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
        \end{align}

    Args:
        mean_target (torch.Tensor): Mean of the target (or the first)
            distribution.
        mean_pred (torch.Tensor): Mean of the predicted (or the second)
            distribution.
        logvar_target (torch.Tensor): Log variance of the target (or the
            first) distribution.
        logvar_pred (torch.Tensor): Log variance of the predicted (or the
            second) distribution.
        base (str, optional): The log base of calculated KLD. We support
            ``'e'`` (for ln) and ``'2'`` (for log_2). Defaults to ``'e'``.

    Returns:
        torch.Tensor: KLD between two given distributions.
    """
    if base not in ['e', '2']:
        raise ValueError('Only support 2 and e for log base, but receive '
                         f'{base}')
    kld = 0.5 * (-1.0 + logvar_pred - logvar_target +
                 torch.exp(logvar_target - logvar_pred) +
                 ((mean_target - mean_pred)**2) * torch.exp(-logvar_pred))
    if base == '2':
        return kld / np.log(2.0)
    return kld
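A quick sanity check of the closed form above: the KLD of a distribution with itself is zero. Note that the ``@weighted_loss`` wrapper adds ``weight``/``reduction`` arguments and reduces with ``'mean'`` by default.

import torch

mean = torch.zeros(2, 3)
logvar = torch.zeros(2, 3)
kld = gaussian_kld(mean, mean, logvar, logvar)
assert torch.allclose(kld, torch.tensor(0.))  # identical distributions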
def approx_gaussian_cdf(x):
    r"""Approximate the cumulative distribution function of the gaussian
    distribution.

    Refers to:
    Approximations to the Cumulative Normal Function and its Inverse for Use on a Pocket Calculator # noqa

    https://www.jstor.org/stable/2346872?origin=crossref

    .. math::
        :nowrap:

        \begin{eqnarray}
        \Phi(x) &\approx \frac{1}{2} \left ( 1 + \tanh(y) \right ) \\
        y &= \sqrt{\frac{2}{\pi}}(x+0.044715 x^3)
        \end{eqnarray}

    Args:
        x (torch.Tensor): Input data.

    Returns:
        torch.Tensor: Calculated cumulative distribution.
    """
    factor = np.sqrt(2.0 / np.pi)
    y = factor * (x + 0.044715 * torch.pow(x, 3))
    phi = 0.5 * (1 + torch.tanh(y))
    return phi
@weighted_loss
def discretized_gaussian_log_likelihood(x, mean, logvar, base='e'):
    r"""Calculate gaussian log-likelihood for a discretized input. We assume
    that the input `x` is ranged in [-1, 1], and the likelihood term can be
    calculated with the following equation:

    .. math::
        :nowrap:

        \begin{eqnarray}
        p_{\theta}(\mathbf{x}_0 | \mathbf{x}_1) =
            \prod_{i=1}^{D} \int_{\delta_{-}(x_0^i)}^{\delta_{+}(x_0^i)}
            {\mathcal{N}(x; \mu_{\theta}^i(\mathbf{x}_1, 1),
            \sigma_{1}^2)}dx\\
        \delta_{+}(x)= \begin{cases}
            \infty & \text{if } x = 1 \\
            x + \frac{1}{255} & \text{if } x < 1
        \end{cases}
        \quad
        \delta_{-}(x)= \begin{cases}
            -\infty & \text{if } x = -1 \\
            x - \frac{1}{255} & \text{if } x > -1
        \end{cases}
        \end{eqnarray}

    When calculating this loss term, we first normalize `x` to the normal
    distribution and calculate the above integral by the cumulative
    distribution function of the normal distribution. Then we rescale the
    results to the target ones.

    Args:
        x (torch.Tensor): Target `x_0` to be modeled. Range in [-1, 1].
        mean (torch.Tensor): Predicted mean of `x_0`.
        logvar (torch.Tensor): Predicted log variance of `x_0`.
        base (str, optional): The log base of the calculated log-likelihood.
            Support ``'e'`` and ``'2'``. Defaults to ``'e'``.

    Returns:
        torch.Tensor: Calculated log likelihood.
    """
    if base not in ['e', '2']:
        raise ValueError('Only support 2 and e for log base, but receive '
                         f'{base}')
    inv_std = torch.exp(-logvar * 0.5)
    x_centered = x - mean
    lower_bound = (x_centered - 1.0 / 255.0) * inv_std
    upper_bound = (x_centered + 1.0 / 255.0) * inv_std

    cdf_to_lower = approx_gaussian_cdf(lower_bound)
    cdf_to_upper = approx_gaussian_cdf(upper_bound)

    log_cdf_upper = torch.log(cdf_to_upper.clamp(min=1e-12))
    log_one_minus_cdf_lower = torch.log(
        (1.0 - cdf_to_lower).clamp(min=1e-12))
    log_cdf_delta = torch.log(
        (cdf_to_upper - cdf_to_lower).clamp(min=1e-12))

    log_probs = torch.where(
        x < -0.999, log_cdf_upper,
        torch.where(x > 0.999, log_one_minus_cdf_lower, log_cdf_delta))

    if base == '2':
        return log_probs / np.log(2.0)
    return log_probs
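A hedged sketch of evaluating the likelihood term above on dummy data (the function is wrapped by ``@weighted_loss``, so ``reduction`` is available):

import torch

x = torch.empty(1, 3, 8, 8).uniform_(-1, 1)  # discretized pixels in [-1, 1]
mean = torch.zeros_like(x)
logvar = torch.zeros_like(x)
# per-pixel log-likelihood; higher (less negative) is better
log_probs = discretized_gaussian_log_likelihood(
    x, mean, logvar, reduction='none')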
@MODULES.register_module()
class MSELoss(nn.Module):
    """MSE loss.

    **Note for the design of ``data_info``:**

    In ``MMGeneration``, almost all of the loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of a GAN model, we will collect all
    of the important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we may need to provide ``pred`` and ``target`` as
    input. Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='fake_imgs',
            target='real_imgs')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_mse'.
    """

    def __init__(self, loss_weight=1.0, data_info=None, loss_name='loss_mse'):
        super().__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword arguments,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function, ``mse_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyword arguments, no '
                    'further non-keyword arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parse your arguments passed to this loss module. '
                    'Please check the usage of this module.')
            # link the outputs with loss input args according to
            # self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(dict(weight=self.loss_weight))
            return mse_loss(**kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return mse_loss(*args, weight=self.loss_weight, **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@MODULES.register_module()
class L1Loss(nn.Module):
    """L1 loss.

    **Note for the design of ``data_info``:**

    In ``MMGeneration``, almost all of the loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of a GAN model, we will collect all
    of the important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from StaticUnconditionalGAN, train_step
        :linenos:

        data_dict_ = dict(
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            iteration=curr_iter,
            batch_size=batch_size)

    But in this loss, we may need to provide ``pred`` and ``target`` as
    input. Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            pred='fake_imgs',
            target='real_imgs')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (float | None, optional): Average factor when computing
            the mean of losses. Defaults to ``None``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If ``None``, this module
            will directly pass the input data to the loss function.
            Defaults to None.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_l1'.
    """

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 avg_factor=None,
                 data_info=None,
                 loss_name='loss_l1'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        self.data_info = data_info
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword arguments,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function, ``l1_loss``.
        """
        # use data_info to build computational path
        if self.data_info is not None:
            # parse the args and kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'You should offer a dictionary containing network outputs '
                    'for building up computational graph of this loss module.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'If the outputs dict is given in keyword arguments, no '
                    'further non-keyword arguments should be offered.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Cannot parse your arguments passed to this loss module. '
                    'Please check the usage of this module.')
            # link the outputs with loss input args according to
            # self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(
                dict(weight=self.loss_weight, reduction=self.reduction))
            return l1_loss(**kwargs)
        else:
            # if you have not defined how to build the computational graph,
            # this module will just directly return the loss as usual.
            return l1_loss(
                *args,
                weight=self.loss_weight,
                reduction=self.reduction,
                avg_factor=self.avg_factor,
                **kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
@MODULES.register_module()
class GaussianKLDLoss(nn.Module):
    """GaussianKLD loss.

    **Note for the design of ``data_info``:**

    In ``MMGeneration``, almost all of the loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of a GAN model, we will collect all
    of the important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from BaseDiffusion, train_step
        :linenos:

        data_dict_ = dict(
            denoising=denoising,
            real_imgs=torch.Tensor([N, C, H, W]),
            mean_pred=torch.Tensor([N, C, H, W]),
            mean_target=torch.Tensor([N, C, H, W]),
            logvar_pred=torch.Tensor([N, C, H, W]),
            logvar_target=torch.Tensor([N, C, H, W]),
            timesteps=torch.Tensor([N,]),
            iteration=curr_iter,
            batch_size=batch_size)

    In this loss, we may need to provide ``mean_pred``, ``mean_target``,
    ``logvar_pred`` and ``logvar_target`` as input. Thus, an example of the
    ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            mean_pred='mean_pred',
            mean_target='mean_target',
            logvar_pred='logvar_pred',
            logvar_target='logvar_target')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        reduction (str, optional): Same as built-in losses of PyTorch. Note
            that the 'batchmean' mode gives the correct KL divergence, where
            losses are averaged over the batch dimension only.
            Defaults to 'mean'.
        avg_factor (float | None, optional): Average factor when computing
            the mean of losses. Defaults to ``None``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If not passed,
            ``_default_data_info`` would be used. Defaults to None.
        base (str, optional): The log base of calculated KLD. Support
            ``'e'`` and ``'2'``. Defaults to ``'e'``.
        only_update_var (bool, optional): If true, only `logvar_pred` will be
            updated and the variable in output_dict corresponding to
            `mean_pred` will be detached. Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_GaussianKLD'.
    """
    _default_data_info = dict(
        mean_pred='mean_pred',
        mean_target='mean_target',
        logvar_pred='logvar_pred',
        logvar_target='logvar_target')

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 avg_factor=None,
                 data_info=None,
                 base='e',
                 only_update_var=False,
                 loss_name='loss_GaussianKLD'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        self.data_info = self._default_data_info if data_info is None \
            else data_info
        self.base = base
        self.only_update_var = only_update_var
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword arguments,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function,
        ``gaussian_kld``.
        """
        # parse the args and kwargs
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyword arguments, no '
                'further non-keyword arguments should be offered.')
            outputs_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parse your arguments passed to this loss module. '
                'Please check the usage of this module.')
        # link the outputs with loss input args according to self.data_info
        loss_input_dict = dict()
        for k, v in self.data_info.items():
            if 'mean_pred' == k and self.only_update_var:
                loss_input_dict[k] = outputs_dict[v].detach()
            else:
                loss_input_dict[k] = outputs_dict[v]

        kwargs.update(loss_input_dict)
        kwargs.update(
            dict(
                weight=self.loss_weight,
                reduction=self.reduction,
                base=self.base))
        return gaussian_kld(**kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
# TODO: this name is toooooo long.
@MODULES.register_module()
class DiscretizedGaussianLogLikelihoodLoss(nn.Module):
    r"""Discretized-Gaussian-Log-Likelihood Loss.

    **Note for the design of ``data_info``:**

    In ``MMGeneration``, almost all of the loss modules contain the argument
    ``data_info``, which can be used for constructing the link between the
    input items (needed in loss calculation) and the data from the generative
    model. For example, in the training of a GAN model, we will collect all
    of the important data/modules into a dictionary:

    .. code-block:: python
        :caption: Code from BaseDiffusion, train_step
        :linenos:

        data_dict_ = dict(
            denoising=denoising,
            real_imgs=torch.Tensor([N, C, H, W]),
            mean_pred=torch.Tensor([N, C, H, W]),
            mean_target=torch.Tensor([N, C, H, W]),
            logvar_pred=torch.Tensor([N, C, H, W]),
            logvar_target=torch.Tensor([N, C, H, W]),
            timesteps=torch.Tensor([N,]),
            iteration=curr_iter,
            batch_size=batch_size)

    In this loss, we may need to provide ``mean``, ``logvar`` and ``x``.
    Thus, an example of the ``data_info`` is:

    .. code-block:: python
        :linenos:

        data_info = dict(
            x='real_imgs',
            mean='mean_pred',
            logvar='logvar_pred')

    Then, the module will automatically construct this mapping from the input
    data dictionary.

    Args:
        loss_weight (float, optional): Weight of this loss item.
            Defaults to ``1.``.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (float | None, optional): Average factor when computing
            the mean of losses. Defaults to ``None``.
        data_info (dict, optional): Dictionary containing the mapping between
            loss input args and the data dictionary. If not passed,
            ``_default_data_info`` would be used. Defaults to None.
        base (str, optional): The log base of the calculated log-likelihood.
            Support ``'e'`` and ``'2'``. Defaults to ``'e'``.
        only_update_var (bool, optional): If true, only `logvar_pred` will be
            updated and the variable in output_dict corresponding to ``mean``
            will be detached. Defaults to False.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to
            'loss_DiscGaussianLogLikelihood'.
    """
    _default_data_info = dict(x='real_imgs', mean='mean_pred',
                              logvar='logvar_pred')

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 avg_factor=None,
                 data_info=None,
                 base='e',
                 only_update_var=False,
                 loss_name='loss_DiscGaussianLogLikelihood'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.avg_factor = avg_factor
        self.data_info = self._default_data_info if data_info is None \
            else data_info
        self.base = base
        self.only_update_var = only_update_var
        self._loss_name = loss_name

    def forward(self, *args, **kwargs):
        """Forward function.

        If ``self.data_info`` is not ``None``, a dictionary containing all of
        the data and necessary modules should be passed into this function.
        If this dictionary is given as a non-keyword argument, it should be
        offered as the first argument. If you are using keyword arguments,
        please name it as `outputs_dict`.

        If ``self.data_info`` is ``None``, the input argument or keyword
        argument will be directly passed to the loss function,
        ``discretized_gaussian_log_likelihood``.
        """
        # parse the args and kwargs
        if len(args) == 1:
            assert isinstance(args[0], dict), (
                'You should offer a dictionary containing network outputs '
                'for building up computational graph of this loss module.')
            outputs_dict = args[0]
        elif 'outputs_dict' in kwargs:
            assert len(args) == 0, (
                'If the outputs dict is given in keyword arguments, no '
                'further non-keyword arguments should be offered.')
            outputs_dict = kwargs.pop('outputs_dict')
        else:
            raise NotImplementedError(
                'Cannot parse your arguments passed to this loss module. '
                'Please check the usage of this module.')
        # link the outputs with loss input args according to self.data_info
        loss_input_dict = dict()
        for k, v in self.data_info.items():
            if k == 'mean' and self.only_update_var:
                loss_input_dict[k] = outputs_dict[v].detach()
            else:
                loss_input_dict[k] = outputs_dict[v]

        kwargs.update(loss_input_dict)
        kwargs.update(
            dict(
                weight=self.loss_weight,
                reduction=self.reduction,
                base=self.base))
        return discretized_gaussian_log_likelihood(**kwargs)

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to
        be included into the backward graph, `loss_` must be the prefix of
        the name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
build/lib/mmgen/models/losses/utils.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import functools

import torch.nn.functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean", "sum", "batchmean" and
            "flatmean". 'none': no reduction will be applied. 'mean': the
            output will be divided by the number of elements in the output.
            'sum': the output will be summed. 'batchmean': the sum of the
            output will be divided by the batch size. 'flatmean': each sample
            will be averaged over its own elements, and the output will have
            shape [bz, ].

    Return:
        Tensor: Reduced loss tensor.
    """
    if reduction == 'batchmean':
        return loss.sum() / loss.shape[0]
    if reduction == 'flatmean':
        return loss.mean(dim=list(range(1, loss.ndim)))
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    if reduction_enum == 1:
        return loss.mean()
    if reduction_enum == 2:
        return loss.sum()

    raise ValueError(f'reduction type {reduction} not supported')
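The two custom modes at a glance: for a loss of shape (2, 4), 'batchmean' sums everything and divides by the batch size, while 'flatmean' averages each sample over its own elements:

import torch

loss = torch.arange(8, dtype=torch.float32).reshape(2, 4)
print(reduce_loss(loss, 'batchmean'))  # tensor(14.)
print(reduce_loss(loss, 'flatmean'))   # tensor([1.5000, 5.5000])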
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights.
        reduction (str): Same as built-in losses of PyTorch.
        avg_factor (float): Average factor when computing the mean of losses.

    Returns:
        Tensor: Processed loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            loss = loss.sum() / avg_factor
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError(
                'avg_factor can not be used with reduction="sum"')
    return loss
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None,
    reduction='mean', avg_factor=None, **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(*args, weight=None, reduction='mean', avg_factor=None,
                **kwargs):
        # get element-wise loss
        loss = loss_func(*args, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper
build/lib/mmgen/models/misc.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torchvision.utils import make_grid


def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to (min, max), image values will be normalized to [0, 1].

    For different tensor shapes, this function will have different behaviors:

        1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
            Use `make_grid` to stitch images in the batch dimension, and then
            convert it to a numpy array.
        2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W):
            Directly change to a numpy array.

    Note that the image channel in input tensors should be RGB order. This
    function will convert it to cv2 convention, i.e., (H x W x C) with BGR
    order.

    Args:
        tensor (Tensor | list[Tensor]): Input tensors.
        out_type (numpy type): Output types. If ``np.uint8``, transform
            outputs to uint8 type with range [0, 255]; otherwise, float type
            with range [0, 1]. Default: ``np.uint8``.
        min_max (tuple): min and max values for clamp.

    Returns:
        (ndarray | list[ndarray]): 3D ndarray of shape (H x W x C) or 2D
            ndarray of shape (H x W).
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list)
             and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        # Squeeze two times so that:
        # 1. (1, 1, h, w) -> (h, w) or
        # 2. (1, 3, h, w) -> (3, h, w) or
        # 3. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w)
        _tensor = _tensor.squeeze(0).squeeze(0)
        _tensor = _tensor.float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(
                _tensor,
                nrow=int(np.sqrt(_tensor.size(0))),
                normalize=False).numpy()
            img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise ValueError('Only support 4D, 3D or 2D tensor. '
                             f'But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    result = result[0] if len(result) == 1 else result
    return result
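A short sketch: converting a batch of generator outputs in [-1, 1] into a single uint8 BGR grid that can be written with ``cv2.imwrite``:

import torch

fake_imgs = torch.rand(4, 3, 64, 64) * 2 - 1  # hypothetical samples
grid = tensor2img(fake_imgs, min_max=(-1, 1))  # ndarray, (H, W, 3), uint8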
build/lib/mmgen/models/translation_models/__init__.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from .base_translation_model import BaseTranslationModel
from .cyclegan import CycleGAN
from .pix2pix import Pix2Pix
from .static_translation_gan import StaticTranslationGAN

__all__ = [
    'Pix2Pix', 'CycleGAN', 'BaseTranslationModel', 'StaticTranslationGAN'
]
build/lib/mmgen/models/translation_models/base_translation_model.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from copy import deepcopy

import torch.nn as nn

from ..builder import MODELS


@MODELS.register_module()
class BaseTranslationModel(nn.Module, metaclass=ABCMeta):
    """Base Translation Model.

    Translation models transfer images from one domain to another. Domain
    information such as `default_domain` and `reachable_domains` is needed
    to initialize the class, and query functions like `is_domain_reachable`
    and `get_other_domains` are provided.

    You can get a specific generator based on the domain, and by specifying
    `target_domain` in the forward function, you can decide the domain of
    the generated images.

    Considering the differences among image translation models, we only
    provide the external interfaces mentioned above. When you implement
    image translation with a specific method, you can inherit both
    `BaseTranslationModel` and the method (e.g. BaseGAN) and implement the
    abstract methods.

    Args:
        default_domain (str): Default output domain.
        reachable_domains (list[str]): Domains that can be generated by
            the model.
        related_domains (list[str]): Domains involved in training and
            testing. `reachable_domains` must be contained in
            `related_domains`. However, `related_domains` may contain
            source domains that are used to retrieve source images from
            `data_batch` but are not in `reachable_domains`.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
    """

    def __init__(self,
                 default_domain,
                 reachable_domains,
                 related_domains,
                 train_cfg=None,
                 test_cfg=None):
        self._default_domain = default_domain
        self._reachable_domains = reachable_domains
        self._related_domains = related_domains
        assert self._default_domain in self._reachable_domains
        assert set(self._reachable_domains) <= set(self._related_domains)

        self.train_cfg = deepcopy(train_cfg) if train_cfg else None
        self.test_cfg = deepcopy(test_cfg) if test_cfg else None

        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()

    @abstractmethod
    def _parse_train_cfg(self):
        """Parse the train config and set some attributes for training."""

    @abstractmethod
    def _parse_test_cfg(self):
        """Parse the test config and set some attributes for testing."""

    def forward(self, img, test_mode=False, **kwargs):
        """Forward function.

        Args:
            img (tensor): Input image tensor.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments.
        """
        if not test_mode:
            return self.forward_train(img, **kwargs)

        return self.forward_test(img, **kwargs)

    def forward_train(self, img, target_domain, **kwargs):
        """Forward function for training.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of the output image.
            kwargs (dict): Other arguments.

        Returns:
            dict: Forward results.
        """
        target = self.translation(img, target_domain=target_domain, **kwargs)
        results = dict(source=img, target=target)
        return results

    def forward_test(self, img, target_domain, **kwargs):
        """Forward function for testing.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of the output image.
            kwargs (dict): Other arguments.

        Returns:
            dict: Forward results.
        """
        target = self.translation(img, target_domain=target_domain, **kwargs)
        results = dict(source=img.cpu(), target=target.cpu())
        return results

    def is_domain_reachable(self, domain):
        """Whether images of this domain can be generated."""
        return domain in self._reachable_domains

    def get_other_domains(self, domain):
        """Get all other related domains."""
        return list(set(self._related_domains) - set([domain]))

    @abstractmethod
    def _get_target_generator(self, domain):
        """Get the generator for the target domain."""

    def translation(self, image, target_domain=None, **kwargs):
        """Translate an image to the target style.

        Args:
            image (tensor): Image tensor with a shape of (N, C, H, W).
            target_domain (str, optional): Target domain of the output
                image. Defaults to None.

        Returns:
            Tensor: Image tensor of the target style.
        """
        if target_domain is None:
            target_domain = self._default_domain
        _model = self._get_target_generator(target_domain)
        outputs = _model(image, **kwargs)
        return outputs
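To make the domain bookkeeping concrete, here is a minimal sketch of a concrete subclass; the class name, the identity generators and the domain names are hypothetical, chosen only to exercise the interface above:

import torch
import torch.nn as nn

class PhotoSketchModel(BaseTranslationModel):
    """Toy two-domain model; nn.Identity stands in for real generators."""

    def __init__(self):
        nn.Module.__init__(self)
        BaseTranslationModel.__init__(
            self,
            default_domain='sketch',
            reachable_domains=['photo', 'sketch'],
            related_domains=['photo', 'sketch'])
        self.generators = nn.ModuleDict(
            {d: nn.Identity() for d in self._reachable_domains})

    def _parse_train_cfg(self):
        pass

    def _parse_test_cfg(self):
        pass

    def _get_target_generator(self, domain):
        return self.generators[domain]

model = PhotoSketchModel()
out = model.translation(torch.rand(1, 3, 64, 64), target_domain='photo')
print(model.get_other_domains('photo'))     # ['sketch']
print(model.is_domain_reachable('sketch'))  # True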
build/lib/mmgen/models/translation_models/cyclegan.py (new file)
# Copyright (c) OpenMMLab. All rights reserved.
from torch.nn.parallel.distributed import _find_tensors

from mmgen.models.builder import MODELS
from ..common import GANImageBuffer, set_requires_grad
from .static_translation_gan import StaticTranslationGAN


@MODELS.register_module()
class CycleGAN(StaticTranslationGAN):
    """CycleGAN model for unpaired image-to-image translation.

    Ref:
        Unpaired Image-to-Image Translation using Cycle-Consistent
        Adversarial Networks
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # GAN image buffers
        self.image_buffers = dict()
        self.buffer_size = (50 if self.train_cfg is None else
                            self.train_cfg.get('buffer_size', 50))
        for domain in self._reachable_domains:
            self.image_buffers[domain] = GANImageBuffer(self.buffer_size)

        self.use_ema = False
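CycleGAN trains its discriminators on a mix of the newest fakes and fakes stored from earlier iterations (the 50-image history trick from the paper), which damps oscillation. The sketch below illustrates the idea only; it is not mmgen's GANImageBuffer, whose exact sampling policy lives in ..common:

import random
import torch

class HistoryBuffer:
    """Keep up to `size` past fakes; with probability 0.5 swap a new fake
    for a stored one, so the discriminator also sees older generations."""

    def __init__(self, size=50):
        self.size = size
        self.images = []

    def query(self, batch):
        out = []
        for img in batch:
            img = img.unsqueeze(0)
            if len(self.images) < self.size:
                self.images.append(img)
                out.append(img)
            elif random.random() < 0.5:
                idx = random.randrange(self.size)
                out.append(self.images[idx])
                self.images[idx] = img   # replace the stored fake
            else:
                out.append(img)
        return torch.cat(out, dim=0)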
    def forward_test(self, img, target_domain, **kwargs):
        """Forward function for testing.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of output image.
            kwargs (dict): Other arguments.

        Returns:
            dict: Forward results.
        """
        # This is a trick for CycleGAN
        # ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e1bdf46198662b0f4d0b318e24568205ec4d7aee/test.py#L54 # noqa
        self.train()
        target = self.translation(img, target_domain=target_domain, **kwargs)
        results = dict(source=img.cpu(), target=target.cpu())
        return results
    def _get_disc_loss(self, outputs):
        """Compute the loss of the discriminators.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple: Discriminators' loss and loss dict.
        """
        discriminators = self.get_module(self.discriminators)

        log_vars_d = dict()
        loss_d = 0

        # GAN loss for each domain discriminator
        for domain in self._reachable_domains:
            losses = dict()
            fake_img = self.image_buffers[domain].query(
                outputs[f'fake_{domain}'])
            fake_pred = discriminators[domain](fake_img.detach())
            losses[f'loss_gan_d_{domain}_fake'] = self.gan_loss(
                fake_pred, target_is_real=False, is_disc=True)
            real_pred = discriminators[domain](outputs[f'real_{domain}'])
            losses[f'loss_gan_d_{domain}_real'] = self.gan_loss(
                real_pred, target_is_real=True, is_disc=True)

            _loss_d, _log_vars_d = self._parse_losses(losses)
            _loss_d *= 0.5
            loss_d += _loss_d
            log_vars_d[f'loss_gan_d_{domain}'] = _log_vars_d['loss'] * 0.5

        return loss_d, log_vars_d
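Each domain's discriminator loss is halved before accumulation, following the original CycleGAN recipe of slowing the discriminator relative to the generator. A worked number, assuming an LSGAN-style (least-squares) criterion; the configured gan_loss may use a different objective:

import torch
import torch.nn.functional as F

real_pred = torch.tensor([0.9, 0.8])   # discriminator scores on real images
fake_pred = torch.tensor([0.2, 0.4])   # discriminator scores on buffered fakes
loss_d = 0.5 * (F.mse_loss(real_pred, torch.ones_like(real_pred)) +
                F.mse_loss(fake_pred, torch.zeros_like(fake_pred)))
print(loss_d)   # 0.5 * (0.025 + 0.100) = 0.0625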
    def _get_gen_loss(self, outputs):
        """Compute the loss of the generators.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple: Generators' loss and loss dict.
        """
        generators = self.get_module(self.generators)
        discriminators = self.get_module(self.discriminators)

        losses = dict()
        for domain in self._reachable_domains:
            # Identity reconstruction for generators
            outputs[f'identity_{domain}'] = generators[domain](
                outputs[f'real_{domain}'])
            # GAN loss for generators
            fake_pred = discriminators[domain](outputs[f'fake_{domain}'])
            losses[f'loss_gan_g_{domain}'] = self.gan_loss(
                fake_pred, target_is_real=True, is_disc=False)

        # generator auxiliary losses (e.g. cycle-consistency and identity)
        if self.with_gen_auxiliary_loss:
            for loss_module in self.gen_auxiliary_losses:
                loss_ = loss_module(outputs)
                if loss_ is None:
                    continue
                # the `loss_name()` function returns a name like 'loss_xxx'
                if loss_module.loss_name() in losses:
                    losses[loss_module.loss_name()] = losses[
                        loss_module.loss_name()] + loss_
                else:
                    losses[loss_module.loss_name()] = loss_

        loss_g, log_vars_g = self._parse_losses(losses)

        return loss_g, log_vars_g

    def _get_opposite_domain(self, domain):
        """Get the other domain in a two-domain setup."""
        for item in self._reachable_domains:
            if item != domain:
                return item
        return None
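The cycle-consistency and identity terms enter through `gen_auxiliary_losses`: each module receives the whole `outputs` dict and reports its value under `loss_name()`, which `_get_gen_loss` uses to merge duplicates. A minimal sketch of such a module (the class and the weight are hypothetical; mmgen ships configurable L1-based loss modules for this):

import torch.nn as nn
import torch.nn.functional as F

class CycleConsistencyLoss(nn.Module):
    """L1 distance between a real image and its round-trip reconstruction,
    read from the `outputs` dict assembled in `train_step`."""

    def __init__(self, domain, loss_weight=10.0):
        super().__init__()
        self.domain = domain
        self.loss_weight = loss_weight

    def forward(self, outputs):
        return self.loss_weight * F.l1_loss(
            outputs[f'cycle_{self.domain}'], outputs[f'real_{self.domain}'])

    def loss_name(self):
        return f'loss_cycle_{self.domain}'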
    def train_step(self, data_batch, optimizer, ddp_reducer=None,
                   running_status=None):
        """Training step function.

        Args:
            data_batch (dict): Dict of the input data batch.
            optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
                the generators and discriminators.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults
                to None.

        Returns:
            dict: Dict of loss, information for logger, the number of
                samples and results for visualization.
        """
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for when no running status is provided
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # forward generators
        outputs = dict()
        for target_domain in self._reachable_domains:
            # fetch data by domain
            source_domain = self.get_other_domains(target_domain)[0]
            img = data_batch[f'img_{source_domain}']
            # translation process
            results = self(img, test_mode=False, target_domain=target_domain)
            outputs[f'real_{source_domain}'] = results['source']
            outputs[f'fake_{target_domain}'] = results['target']
            # cycle process
            results = self(
                results['target'],
                test_mode=False,
                target_domain=source_domain)
            outputs[f'cycle_{source_domain}'] = results['target']

        log_vars = dict()

        # discriminators
        set_requires_grad(self.discriminators, True)
        # optimize
        optimizer['discriminators'].zero_grad()
        loss_d, log_vars_d = self._get_disc_loss(outputs)
        log_vars.update(log_vars_d)
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
        loss_d.backward()
        optimizer['discriminators'].step()

        # generators; no updates to discriminator parameters
        if (curr_iter % self.disc_steps == 0
                and curr_iter >= self.disc_init_steps):
            set_requires_grad(self.discriminators, False)
            # optimize
            optimizer['generators'].zero_grad()
            loss_g, log_vars_g = self._get_gen_loss(outputs)
            log_vars.update(log_vars_g)
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
            loss_g.backward()
            optimizer['generators'].step()

        if hasattr(self, 'iteration'):
            self.iteration += 1

        image_results = dict()
        for domain in self._reachable_domains:
            image_results[f'real_{domain}'] = outputs[f'real_{domain}'].cpu()
            image_results[f'fake_{domain}'] = outputs[f'fake_{domain}'].cpu()

        results = dict(
            log_vars=log_vars,
            num_samples=len(outputs[f'real_{domain}']),
            results=image_results)

        return results
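Note the update schedule above: the discriminators step every iteration, while the generators step only every `disc_steps`-th iteration and only after `disc_init_steps` warm-up iterations. A tiny standalone illustration (the values are assumed, not defaults):

disc_steps, disc_init_steps = 2, 4
for curr_iter in range(10):
    update_gen = (curr_iter % disc_steps == 0
                  and curr_iter >= disc_init_steps)
    print(curr_iter, 'G+D' if update_gen else 'D only')
# -> the generators update at iterations 4, 6 and 8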
build/lib/mmgen/models/translation_models/pix2pix.py (new file)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.nn.parallel.distributed import _find_tensors

from mmgen.models.builder import MODELS
from ..common import set_requires_grad
from .static_translation_gan import StaticTranslationGAN


@MODELS.register_module()
class Pix2Pix(StaticTranslationGAN):
    """Pix2Pix model for paired image-to-image translation.

    Ref:
        Image-to-Image Translation with Conditional Adversarial Networks
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_ema = False

    def forward_test(self, img, target_domain, **kwargs):
        """Forward function for testing.

        Args:
            img (tensor): Input image tensor.
            target_domain (str): Target domain of output image.
            kwargs (dict): Other arguments.

        Returns:
            dict: Forward results.
        """
        # This is a trick for Pix2Pix
        # ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e1bdf46198662b0f4d0b318e24568205ec4d7aee/test.py#L54 # noqa
        self.train()
        target = self.translation(img, target_domain=target_domain, **kwargs)
        results = dict(source=img.cpu(), target=target.cpu())
        return results
    def _get_disc_loss(self, outputs):
        """Compute the GAN loss for the discriminator.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple: Discriminator's loss and loss dict.
        """
        losses = dict()
        discriminators = self.get_module(self.discriminators)

        target_domain = self._default_domain
        source_domain = self.get_other_domains(target_domain)[0]
        # conditional GAN: the discriminator scores (source, output) pairs
        fake_ab = torch.cat(
            (outputs[f'real_{source_domain}'],
             outputs[f'fake_{target_domain}']), 1)
        fake_pred = discriminators[target_domain](fake_ab.detach())
        losses['loss_gan_d_fake'] = self.gan_loss(
            fake_pred, target_is_real=False, is_disc=True)
        real_ab = torch.cat(
            (outputs[f'real_{source_domain}'],
             outputs[f'real_{target_domain}']), 1)
        real_pred = discriminators[target_domain](real_ab)
        losses['loss_gan_d_real'] = self.gan_loss(
            real_pred, target_is_real=True, is_disc=True)

        loss_d, log_vars_d = self._parse_losses(losses)
        loss_d *= 0.5

        return loss_d, log_vars_d
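The discriminator here is conditional: it scores the source image and a candidate target stacked along the channel axis, so two 3-channel images become one 6-channel input. A quick shape check (batch size and resolution assumed for illustration):

import torch

real_a = torch.rand(4, 3, 256, 256)       # source-domain batch
fake_b = torch.rand(4, 3, 256, 256)       # generator output
fake_ab = torch.cat((real_a, fake_b), 1)  # pair fed to the discriminator
print(fake_ab.shape)                      # torch.Size([4, 6, 256, 256])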
    def _get_gen_loss(self, outputs):
        """Compute the loss for the generator.

        Args:
            outputs (dict): Dict of forward results.

        Returns:
            tuple: Generator's loss and loss dict.
        """
        target_domain = self._default_domain
        source_domain = self.get_other_domains(target_domain)[0]
        losses = dict()

        discriminators = self.get_module(self.discriminators)

        # GAN loss for the generator
        fake_ab = torch.cat(
            (outputs[f'real_{source_domain}'],
             outputs[f'fake_{target_domain}']), 1)
        fake_pred = discriminators[target_domain](fake_ab)
        losses['loss_gan_g'] = self.gan_loss(
            fake_pred, target_is_real=True, is_disc=False)

        # generator auxiliary losses (e.g. the paired L1 pixel loss)
        if self.with_gen_auxiliary_loss:
            for loss_module in self.gen_auxiliary_losses:
                loss_ = loss_module(outputs)
                if loss_ is None:
                    continue
                # the `loss_name()` function returns a name like 'loss_xxx'
                if loss_module.loss_name() in losses:
                    losses[loss_module.loss_name()] = losses[
                        loss_module.loss_name()] + loss_
                else:
                    losses[loss_module.loss_name()] = loss_

        loss_g, log_vars_g = self._parse_losses(losses)

        return loss_g, log_vars_g
    def train_step(self, data_batch, optimizer, ddp_reducer=None,
                   running_status=None):
        """Training step function.

        Args:
            data_batch (dict): Dict of the input data batch.
            optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
                the generator and discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults
                to None.

        Returns:
            dict: Dict of loss, information for logger, the number of
                samples and results for visualization.
        """
        # data
        target_domain = self._default_domain
        source_domain = self.get_other_domains(self._default_domain)[0]
        source_image = data_batch[f'img_{source_domain}']
        target_image = data_batch[f'img_{target_domain}']

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround for when no running status is provided
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # forward generator
        outputs = dict()
        results = self(
            source_image, target_domain=self._default_domain, test_mode=False)
        outputs[f'real_{source_domain}'] = results['source']
        outputs[f'fake_{target_domain}'] = results['target']
        outputs[f'real_{target_domain}'] = target_image

        log_vars = dict()

        # discriminator
        set_requires_grad(self.discriminators, True)
        # optimize
        optimizer['discriminators'].zero_grad()
        loss_d, log_vars_d = self._get_disc_loss(outputs)
        log_vars.update(log_vars_d)
        # Prepare for backward in ddp. If you do not call this function
        # before back propagation, the ddp will not dynamically find the
        # used params in the current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
        loss_d.backward()
        optimizer['discriminators'].step()

        # generator; no updates to discriminator parameters
        if (curr_iter % self.disc_steps == 0
                and curr_iter >= self.disc_init_steps):
            set_requires_grad(self.discriminators, False)
            # optimize
            optimizer['generators'].zero_grad()
            loss_g, log_vars_g = self._get_gen_loss(outputs)
            log_vars.update(log_vars_g)
            # Prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find
            # the used params in the current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
            loss_g.backward()
            optimizer['generators'].step()

        if hasattr(self, 'iteration'):
            self.iteration += 1

        image_results = dict()
        image_results[f'real_{source_domain}'] = outputs[
            f'real_{source_domain}'].cpu()
        image_results[f'fake_{target_domain}'] = outputs[
            f'fake_{target_domain}'].cpu()
        image_results[f'real_{target_domain}'] = outputs[
            f'real_{target_domain}'].cpu()

        results = dict(
            log_vars=log_vars,
            num_samples=len(outputs[f'real_{source_domain}']),
            results=image_results)

        return results
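Both translation models expect `optimizer` to be a dict keyed 'generators' and 'discriminators'. A self-contained sketch of assembling such a dict; the toy conv modules stand in for the real networks, and the Adam hyperparameters mirror common pix2pix settings rather than any config in this repo:

import torch
import torch.nn as nn

# stand-ins for model.generators / model.discriminators
generators = nn.ModuleDict({'photo': nn.Conv2d(3, 3, 1),
                            'sketch': nn.Conv2d(3, 3, 1)})
discriminators = nn.ModuleDict({'photo': nn.Conv2d(6, 1, 1),
                                'sketch': nn.Conv2d(6, 1, 1)})

optimizer = dict(
    generators=torch.optim.Adam(
        generators.parameters(), lr=2e-4, betas=(0.5, 0.999)),
    discriminators=torch.optim.Adam(
        discriminators.parameters(), lr=2e-4, betas=(0.5, 0.999)))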
build/lib/mmgen/models/translation_models/static_translation_gan.py (new file)
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import torch.nn as nn
from mmcv.parallel import MMDistributedDataParallel

from ..builder import MODELS, build_module
from ..gans import BaseGAN
from .base_translation_model import BaseTranslationModel


@MODELS.register_module()
class StaticTranslationGAN(BaseTranslationModel, BaseGAN):
    """Basic translation model based on a static unconditional GAN.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator.
        gan_loss (dict): Config for the gan loss.
        pretrained (str, optional): Path for the pretrained model.
            Defaults to None.
        disc_auxiliary_loss (dict, optional): Config for the auxiliary loss
            of the discriminator. Defaults to None.
        gen_auxiliary_loss (dict, optional): Config for the auxiliary loss
            of the generator. Defaults to None.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 *args,
                 pretrained=None,
                 disc_auxiliary_loss=None,
                 gen_auxiliary_loss=None,
                 **kwargs):
        BaseGAN.__init__(self)
        BaseTranslationModel.__init__(self, *args, **kwargs)
        # Building generators and discriminators
        self._gen_cfg = deepcopy(generator)
        # build domain generators
        self.generators = nn.ModuleDict()
        for domain in self._reachable_domains:
            self.generators[domain] = build_module(generator)

        self._disc_cfg = deepcopy(discriminator)
        # build domain discriminators
        if discriminator is not None:
            self.discriminators = nn.ModuleDict()
            for domain in self._reachable_domains:
                self.discriminators[domain] = build_module(discriminator)
        # support no discriminator in testing
        else:
            self.discriminators = None

        # support no gan_loss in testing
        if gan_loss is not None:
            self.gan_loss = build_module(gan_loss)
        else:
            self.gan_loss = None

        if disc_auxiliary_loss:
            self.disc_auxiliary_losses = build_module(disc_auxiliary_loss)
            if not isinstance(self.disc_auxiliary_losses, nn.ModuleList):
                self.disc_auxiliary_losses = nn.ModuleList(
                    [self.disc_auxiliary_losses])
        else:
            self.disc_auxiliary_losses = None
        if gen_auxiliary_loss:
            self.gen_auxiliary_losses = build_module(gen_auxiliary_loss)
            if not isinstance(self.gen_auxiliary_losses, nn.ModuleList):
                self.gen_auxiliary_losses = nn.ModuleList(
                    [self.gen_auxiliary_losses])
        else:
            self.gen_auxiliary_losses = None

        self.init_weights(pretrained)
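`__init__` stamps out one generator (and, when training, one discriminator) per reachable domain from a single config, which is how CycleGAN gets its two mapping directions. The pattern in isolation, with a toy builder standing in for mmgen's `build_module` registry call:

import torch.nn as nn

def build_toy_module(cfg):
    # stand-in for build_module(cfg); real configs name a registered class
    return nn.Conv2d(cfg['in_channels'], cfg['out_channels'], 3, padding=1)

reachable_domains = ['photo', 'sketch']
gen_cfg = dict(in_channels=3, out_channels=3)
generators = nn.ModuleDict(
    {domain: build_toy_module(gen_cfg) for domain in reachable_domains})
print(sorted(generators.keys()))   # ['photo', 'sketch']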
    def init_weights(self, pretrained=None):
        """Initialize weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If
                given None, pretrained weights will not be loaded.
                Default: None.
        """
        for domain in self._reachable_domains:
            self.generators[domain].init_weights(pretrained=pretrained)
            # discriminators may be absent in testing
            if self.discriminators is not None:
                self.discriminators[domain].init_weights(
                    pretrained=pretrained)
    def _parse_train_cfg(self):
        """Parse the train config and set some attributes for training."""
        if self.train_cfg is None:
            self.train_cfg = dict()
        # control the workflow in the train step
        self.disc_steps = self.train_cfg.get('disc_steps', 1)
        self.disc_init_steps = self.train_cfg.get('disc_init_steps', 0)
        self.real_img_key = self.train_cfg.get('real_img_key', 'real_img')

    def _parse_test_cfg(self):
        """Parse the test config and set some attributes for testing."""
        if self.test_cfg is None:
            self.test_cfg = dict()

        # basic testing information
        self.batch_size = self.test_cfg.get('batch_size', 1)
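A `train_cfg` supplying all three keys read above might look like this (the values shown are the defaults, for illustration only):

train_cfg = dict(
    disc_steps=1,           # update generators every N discriminator steps
    disc_init_steps=0,      # discriminator-only warm-up iterations
    real_img_key='real_img')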
    def get_module(self, module):
        """Unwrap `nn.ModuleDict` to fit the `MMDistributedDataParallel`
        interface.

        Args:
            module (MMDistributedDataParallel | nn.ModuleDict): The input
                module that needs processing.

        Returns:
            nn.ModuleDict: The ModuleDict of multiple networks.
        """
        if isinstance(module, MMDistributedDataParallel):
            return module.module

        return module

    def _get_target_generator(self, domain):
        """Get the generator for the target domain."""
        assert self.is_domain_reachable(
            domain
        ), f'{domain} domain is not reachable, available domain list is ' \
           f'{self._reachable_domains}'
        return self.get_module(self.generators)[domain]

    def _get_target_discriminator(self, domain):
        """Get the discriminator for the target domain."""
        assert self.is_domain_reachable(
            domain
        ), f'{domain} domain is not reachable, available domain list is ' \
           f'{self._reachable_domains}'
        return self.get_module(self.discriminators)[domain]
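`get_module` exists because under `MMDistributedDataParallel` the ModuleDicts live behind a `.module` attribute; unwrapping once keeps the domain-indexing code identical on one GPU or many. The same idiom, generalized to any DDP-style wrapper (a sketch, not mmgen API):

import torch.nn as nn

def unwrap(module):
    # DDP-style wrappers expose the wrapped network as `.module`
    return module.module if hasattr(module, 'module') else module

net = nn.Linear(2, 2)
assert unwrap(net) is net   # plain modules pass through unchanged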