ModelZoo / stylegan2_mmcv · Commits

Commit 1401de15, authored Jun 28, 2024 by dongchy920

    stylegan2_mmcv

Pipeline #1274 canceled with stages · 463 changes · 1 pipeline
Showing 20 changed files with 2492 additions and 0 deletions (+2492, −0):
build/lib/mmgen/.mim/tools/dist_train.sh                 +20   −0
build/lib/mmgen/.mim/tools/eval.sh                       +10   −0
build/lib/mmgen/.mim/tools/evaluation.py                 +249  −0
build/lib/mmgen/.mim/tools/misc/print_config.py          +39   −0
build/lib/mmgen/.mim/tools/publish_model.py              +39   −0
build/lib/mmgen/.mim/tools/slurm_eval.sh                 +24   −0
build/lib/mmgen/.mim/tools/slurm_eval_multi_gpu.sh       +24   −0
build/lib/mmgen/.mim/tools/slurm_train.sh                +24   −0
build/lib/mmgen/.mim/tools/train.py                      +230  −0
build/lib/mmgen/.mim/tools/utils/inception_stat.py       +168  −0
build/lib/mmgen/.mim/tools/utils/singan_inference.py     +111  −0
build/lib/mmgen/.mim/tools/utils/translation_eval.py     +329  −0
build/lib/mmgen/__init__.py                              +29   −0
build/lib/mmgen/apis/__init__.py                         +11   −0
build/lib/mmgen/apis/inference.py                        +297  −0
build/lib/mmgen/apis/train.py                            +207  −0
build/lib/mmgen/core/__init__.py                         +7    −0
build/lib/mmgen/core/ddp_wrapper.py                      +136  −0
build/lib/mmgen/core/evaluation/__init__.py              +14   −0
build/lib/mmgen/core/evaluation/eval_hooks.py            +524  −0
build/lib/mmgen/.mim/tools/dist_train.sh (new file, mode 100644)

#!/usr/bin/env bash

CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/train.py \
    $CONFIG \
    --seed 0 \
    --launcher pytorch ${@:3}
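For reference, the expected invocation is `bash tools/dist_train.sh CONFIG GPUS [train.py args]` (CONFIG and GPUS are placeholders); NNODES, NODE_RANK, PORT, and MASTER_ADDR can be overridden through environment variables, as the parameter-expansion defaults above show, and everything from the third argument onward is forwarded to train.py.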
build/lib/mmgen/.mim/tools/eval.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

CONFIG=$1
CKPT=$2
PY_ARGS=${@:3}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="none" ${PY_ARGS}
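Invocation follows the same pattern: `bash tools/eval.sh CONFIG CKPT [evaluation.py args]` (placeholders). Note the script hard-codes `--launcher="none"`, so it always evaluates in a single process.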
build/lib/mmgen/.mim/tools/evaluation.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint

from mmgen.apis import set_random_seed
from mmgen.core import build_metric, offline_evaluation, online_evaluation
from mmgen.datasets import build_dataloader, build_dataset
from mmgen.models import build_model
from mmgen.utils import get_root_logger

_distributed_metrics = ['FID', 'IS']


def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate a Generation model')
    parser.add_argument('config', help='evaluation config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--batch-size', type=int, default=10, help='batch size of dataloader')
    parser.add_argument(
        '--samples-path',
        type=str,
        default=None,
        help='path to store images. If not given, remove it after evaluation \
        finished')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema',
        choices=['ema', 'orig'],
        help='use which mode (ema/orig) in sampling')
    parser.add_argument(
        '--eval',
        nargs='*',
        type=str,
        default=None,
        help='select the metrics you want to access')
    parser.add_argument(
        '--online',
        action='store_true',
        help='whether to use online mode for evaluation')
    parser.add_argument(
        '--num-samples',
        type=int,
        default=-1,
        help='The number of images to be sampled for evaluation.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file.')
    parser.add_argument(
        '--sample-cfg',
        nargs='+',
        action=DictAction,
        help='Other customized kwargs for sampling function')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        rank = 0
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        rank, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
        assert args.online or world_size == 1, (
            'We only support online mode for distributed evaluation.')

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # check metrics for dist evaluation
    if distributed and metrics:
        for metric in metrics:
            assert metric.name in _distributed_metrics, (
                f'We only support {_distributed_metrics} for multi gpu '
                f'evaluation, but receive {args.eval}.')

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model)

    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        # build the dataloader
        if cfg.data.get('test', None) and cfg.data.test.get('imgs_root', None):
            dataset = build_dataset(cfg.data.test)
        elif cfg.data.get('val', None) and cfg.data.val.get('imgs_root', None):
            dataset = build_dataset(cfg.data.val)
        elif cfg.data.get('train', None):
            # we assume that the train part should work well
            dataset = build_dataset(cfg.data.train)
        else:
            raise RuntimeError('There is no valid dataset config to run, '
                               'please check your dataset configs.')

        # The default loader config
        loader_cfg = dict(
            samples_per_gpu=args.batch_size,
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            shuffle=True)
        # The overall dataloader settings
        loader_cfg.update({
            k: v
            for k, v in cfg.data.items() if k not in [
                'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
                'test_dataloader'
            ]
        })
        # specific config for test loader
        test_loader_cfg = {**loader_cfg, **cfg.data.get('test_dataloader', {})}

        data_loader = build_dataloader(dataset, **test_loader_cfg)

    if args.sample_cfg is None:
        args.sample_cfg = dict()

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    # online mode will not save samples
    if args.online and len(metrics) > 0:
        online_evaluation(model, data_loader, metrics, logger,
                          basic_table_info, args.batch_size, **args.sample_cfg)
    else:
        offline_evaluation(model, data_loader, metrics, logger,
                           basic_table_info, args.batch_size,
                           args.samples_path, **args.sample_cfg)


if __name__ == '__main__':
    main()
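As a hypothetical invocation (assuming the config defines a metric entry named `fid50k` under its `metrics` dict): `python tools/evaluation.py CONFIG CKPT --eval fid50k --online`. Passing `--eval none` instead skips all metrics and only samples `--num-samples` images.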
build/lib/mmgen/.mim/tools/misc/print_config.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse

from mmcv import Config, DictAction


def parse_args():
    parser = argparse.ArgumentParser(description='Print the whole config')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    print(f'Config:\n{cfg.pretty_text}')


if __name__ == '__main__':
    main()
build/lib/mmgen/.mim/tools/publish_model.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import subprocess
from datetime import datetime

import torch


def parse_args():
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('in_file', help='input checkpoint filename')
    parser.add_argument('out_file', help='output checkpoint filename')
    args = parser.parse_args()
    return args


def process_checkpoint(in_file, out_file):
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    now = datetime.now()
    time = now.strftime('%Y%m%d_%H%M%S')
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    final_file = out_file.rstrip('.pth') + f'_{time}-{sha[:8]}.pth'
    subprocess.Popen(['mv', out_file, final_file])


def main():
    args = parse_args()
    process_checkpoint(args.in_file, args.out_file)


if __name__ == '__main__':
    main()
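One caveat worth flagging in process_checkpoint: str.rstrip('.pth') strips a trailing character set, not the literal suffix, so an out_file whose stem ends in 'p', 't', or 'h' loses extra characters. A minimal illustration:

    # str.rstrip('.pth') removes trailing characters drawn from the set
    # {'.', 'p', 't', 'h'}, not the literal '.pth' suffix:
    assert 'model.pth'.rstrip('.pth') == 'model'  # happens to survive intact
    assert 'depth.pth'.rstrip('.pth') == 'de'     # stem gets mangled
    # A suffix-safe alternative: out_file[:-len('.pth')], or
    # str.removesuffix('.pth') on Python >= 3.9.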
build/lib/mmgen/.mim/tools/slurm_eval.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CKPT=$4
GPUS=${GPUS:-1}
GPUS_PER_NODE=${GPUS_PER_NODE:-1}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="none" ${PY_ARGS}
build/lib/mmgen/.mim/tools/slurm_eval_multi_gpu.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CKPT=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="slurm" ${PY_ARGS}
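The two Slurm evaluation scripts differ only in their defaults and launcher: slurm_eval.sh defaults to a single GPU and runs tools/evaluation.py with --launcher="none", while slurm_eval_multi_gpu.sh defaults to 8 GPUs per node and passes --launcher="slurm" for distributed evaluation.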
build/lib/mmgen/.mim/tools/slurm_train.sh (new file, mode 100644)

#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
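A hypothetical submission: `GPUS=16 GPUS_PER_NODE=8 bash tools/slurm_train.sh PARTITION JOB_NAME CONFIG WORK_DIR` (placeholders) would spread 16 tasks over two 8-GPU nodes; any trailing arguments are forwarded to tools/train.py.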
build/lib/mmgen/.mim/tools/train.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import multiprocessing as mp
import os
import os.path as osp
import platform
import time
import warnings
import pdb

import cv2
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash

from mmgen import __version__
from mmgen.apis import set_random_seed, train_model
from mmgen.datasets import build_dataset
from mmgen.models import build_model
from mmgen.utils import collect_env, get_root_logger

cv2.setNumThreads(0)


def parse_args():
    parser = argparse.ArgumentParser(description='Train a GAN model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='(Deprecated, please use --gpu-id) number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--diff_seed',
        action='store_true',
        help='Whether or not to set different seeds for different ranks')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args


def setup_multi_processes(cfg):
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        mp_start_method = cfg.get('mp_start_method', 'fork')
        mp.set_start_method(mp_start_method)

    # disable opencv multithreading to avoid system being overloaded
    opencv_num_threads = cfg.get('opencv_num_threads', 0)
    cv2.setNumThreads(opencv_num_threads)

    # setup OMP threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
    if ('OMP_NUM_THREADS' not in os.environ
            and cfg.data.workers_per_gpu > 1):
        omp_num_threads = 1
        warnings.warn(
            f'Setting OMP_NUM_THREADS environment variable for each process '
            f'to be {omp_num_threads} in default, to avoid your system being '
            f'overloaded, please further tune the variable for optimal '
            f'performance in your application as needed.')
        os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)

    # setup MKL threads
    if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
        mkl_num_threads = 1
        warnings.warn(
            f'Setting MKL_NUM_THREADS environment variable for each process '
            f'to be {mkl_num_threads} in default, to avoid your system being '
            f'overloaded, please further tune the variable for optimal '
            f'performance in your application as needed.')
        os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    setup_multi_processes(cfg)

    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpus is not None:
        cfg.gpu_ids = range(1)
        warnings.warn('`--gpus` is deprecated because we only support '
                      'single GPU mode in non-distributed training. '
                      'Use `gpus=1` now.')
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed training. Use the first GPU '
                      'in `gpu_ids` now.')
    if args.gpus is None and args.gpu_ids is None:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    # pdb.set_trace()
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}, '
                    f'use_rank_shift: {args.diff_seed}')
        set_random_seed(
            args.seed,
            deterministic=args.deterministic,
            use_rank_shift=args.diff_seed)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.val.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmgen version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmgen_version=__version__ + get_git_hash()[:7])

    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
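For a hypothetical single-GPU run: `python tools/train.py CONFIG --work-dir work_dirs/my_exp` (placeholders). Multi-GPU training is normally driven through dist_train.sh or slurm_train.sh above, which invoke this script with --launcher pytorch or --launcher="slurm" respectively.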
build/lib/mmgen/.mim/tools/utils/inception_stat.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import pickle
import sys

import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv import Config, print_log

# yapf: disable
sys.path.append(osp.abspath(osp.join(__file__, '../../..')))  # isort:skip  # noqa

from mmgen.core.evaluation.metric_utils import \
    extract_inception_features  # isort:skip  # noqa
from mmgen.datasets import (UnconditionalImageDataset,  # isort:skip  # noqa
                            build_dataloader, build_dataset)
from mmgen.models.architectures import InceptionV3  # isort:skip  # noqa
# yapf: enable

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Pre-calculate inception data and save it in pkl file')
    parser.add_argument(
        '--imgsdir', type=str, default=None, help='the dir containing images.')
    parser.add_argument(
        '--data-cfg',
        type=str,
        default=None,
        help='the config file for test data pipeline')
    parser.add_argument(
        '--pklname', type=str, help='the name of inception pkl')
    parser.add_argument(
        '--pkl-dir',
        type=str,
        default='work_dirs/inception_pkl',
        help='path to save pkl file')
    parser.add_argument(
        '--pipeline-cfg',
        type=str,
        default=None,
        help=('config file containing dataset pipeline. If None, the default'
              ' pipeline will be adopted'))
    parser.add_argument(
        '--flip', action='store_true', help='whether to flip real images')
    parser.add_argument(
        '--size',
        type=int,
        nargs='+',
        default=(299, 299),
        help='image size in the data pipeline')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=25,
        help='batch size used in extracted features')
    parser.add_argument(
        '--num-samples',
        type=int,
        default=50000,
        help=('the number of total samples, if input -1, '
              'automatically use all samples in the subset'))
    parser.add_argument(
        '--no-shuffle',
        action='store_true',
        help='not use shuffle in data loader')
    parser.add_argument(
        '--subset',
        default='test',
        help='which subset and corresponding pipeline to use')
    parser.add_argument(
        '--inception-style',
        choices=['stylegan', 'pytorch'],
        default='pytorch',
        help='which inception network to use')
    parser.add_argument(
        '--inception-pth',
        type=str,
        default='work_dirs/cache/inception-2015-12-05.pt')
    args = parser.parse_args()

    # dataset pipeline (only be used when args.imgsdir is not None)
    if args.pipeline_cfg is not None:
        pipeline = Config.fromfile(args.pipeline_cfg)['inception_pipeline']
    elif args.imgsdir is not None:
        if isinstance(args.size, list) and len(args.size) == 2:
            size = args.size
        elif isinstance(args.size, list) and len(args.size) == 1:
            size = (args.size[0], args.size[0])
        elif isinstance(args.size, int):
            size = (args.size, args.size)
        else:
            raise TypeError(f'args.size must be int or tuple but got '
                            f'{args.size}')

        pipeline = [
            dict(type='LoadImageFromFile', key='real_img'),
            dict(
                type='Resize',
                keys=['real_img'],
                scale=size,
                keep_ratio=False),
            dict(
                type='Normalize',
                keys=['real_img'],
                mean=[127.5] * 3,
                std=[127.5] * 3,
                to_rgb=True),  # default to RGB images
            dict(type='Collect', keys=['real_img'], meta_keys=[]),
            dict(type='ImageToTensor', keys=['real_img'])
        ]
        # insert flip aug
        if args.flip:
            pipeline.insert(
                1, dict(type='Flip', keys=['real_img'],
                        direction='horizontal'))

    # build dataloader
    if args.imgsdir is not None:
        dataset = UnconditionalImageDataset(args.imgsdir, pipeline)
    elif args.data_cfg is not None:
        # Please make sure the dataset will sample images in `RGB` order.
        data_config = Config.fromfile(args.data_cfg)
        subset_config = data_config.data.get(args.subset, None)
        print_log(subset_config, 'mmgen')
        dataset = build_dataset(subset_config)
    else:
        raise RuntimeError('Please provide imgsdir or data_cfg')
    data_loader = build_dataloader(
        dataset, args.batch_size, 4, dist=False, shuffle=(not args.no_shuffle))
    mmcv.mkdir_or_exist(args.pkl_dir)

    # build inception network
    if args.inception_style == 'stylegan':
        inception = torch.jit.load(args.inception_pth).eval().cuda()
        inception = nn.DataParallel(inception)
        print_log('Adopt Inception network in StyleGAN', 'mmgen')
    else:
        inception = nn.DataParallel(
            InceptionV3([3], resize_input=True, normalize_input=False).cuda())
    inception.eval()

    if args.num_samples == -1:
        print_log('Use all samples in subset', 'mmgen')
        num_samples = len(dataset)
    else:
        num_samples = args.num_samples

    features = extract_inception_features(data_loader, inception, num_samples,
                                          args.inception_style).numpy()

    # sanity check for the number of features
    assert features.shape[0] == num_samples, \
        'the number of features != num_samples'
    print_log(f'Extract {num_samples} features', 'mmgen')

    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    with open(osp.join(args.pkl_dir, args.pklname), 'wb') as f:
        pickle.dump(
            {
                'mean': mean,
                'cov': cov,
                'size': num_samples,
                'name': args.pklname
            }, f)
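The resulting pickle stores exactly the four keys written above. A minimal sketch of reading the statistics back (the file name is hypothetical):

    import pickle

    # The keys mirror the pickle.dump call in inception_stat.py above.
    with open('work_dirs/inception_pkl/my_dataset.pkl', 'rb') as f:
        stat = pickle.load(f)
    print(stat['name'], stat['size'])             # pkl name and sample count
    print(stat['mean'].shape, stat['cov'].shape)  # feature mean and covariance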
build/lib/mmgen/.mim/tools/utils/singan_inference.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import sys

import mmcv
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint

# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../../..')))  # isort:skip  # noqa

from mmgen.apis import set_random_seed  # isort:skip  # noqa
from mmgen.models import build_model  # isort:skip  # noqa
# yapf: enable


def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate a GAN model')
    parser.add_argument('config', help='evaluation config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--samples-path',
        type=str,
        default=None,
        help='path to store images. If not given, remove it after evaluation \
        finished')
    parser.add_argument(
        '--save-prev-res',
        action='store_true',
        help='whether to store the results from previous stages')
    parser.add_argument(
        '--num-samples',
        type=int,
        default=10,
        help='the number of synthesized samples')
    args = parser.parse_args()
    return args


def _tensor2img(img):
    img = img[0].permute(1, 2, 0)
    img = ((img + 1) / 2 * 255).to(torch.uint8)
    return img.cpu().numpy()


@torch.no_grad()
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    model.eval()

    # load ckpt
    mmcv.print_log(f'Loading ckpt from {args.checkpoint}', 'mmgen')
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    # add dp wrapper
    model = MMDataParallel(model, device_ids=[0])

    pbar = mmcv.ProgressBar(args.num_samples)
    for sample_iter in range(args.num_samples):
        outputs = model(None, num_batches=1, get_prev_res=args.save_prev_res)
        # store results from previous stages
        if args.save_prev_res:
            fake_img = outputs['fake_img']
            prev_res_list = outputs['prev_res_list']
            prev_res_list.append(fake_img)
            for i, img in enumerate(prev_res_list):
                img = _tensor2img(img)
                mmcv.imwrite(
                    img,
                    os.path.join(args.samples_path, f'stage{i}',
                                 f'rand_sample_{sample_iter}.png'))
        # just store the final result
        else:
            img = _tensor2img(outputs)
            mmcv.imwrite(
                img,
                os.path.join(args.samples_path,
                             f'rand_sample_{sample_iter}.png'))

        pbar.update()

    # change the line after pbar
    sys.stdout.write('\n')


if __name__ == '__main__':
    main()
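Note that, despite the help text, --samples-path is effectively required here: every generated sample is written under args.samples_path, and os.path.join(None, ...) would raise a TypeError if the flag were omitted.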
build/lib/mmgen/.mim/tools/utils/translation_eval.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import shutil
import sys

import mmcv
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from torchvision.utils import save_image

from mmgen.apis import set_random_seed
from mmgen.core import build_metric
from mmgen.core.evaluation import make_metrics_table, make_vanilla_dataloader
from mmgen.datasets import build_dataloader, build_dataset
from mmgen.models import build_model
from mmgen.models.translation_models import BaseTranslationModel
from mmgen.utils import get_root_logger


def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate a GAN model')
    parser.add_argument('config', help='evaluation config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--target-domain', type=str, default=None, help='Desired image domain')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--batch-size', type=int, default=1, help='batch size of dataloader')
    parser.add_argument(
        '--samples-path',
        type=str,
        default=None,
        help='path to store images. If not given, remove it after evaluation \
        finished')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema',
        help='use which mode (ema/orig) in sampling')
    parser.add_argument(
        '--eval',
        nargs='*',
        type=str,
        default=None,
        help='select the metrics you want to access')
    parser.add_argument(
        '--online',
        action='store_true',
        help='whether to use online mode for evaluation')
    args = parser.parse_args()
    return args


@torch.no_grad()
def single_gpu_evaluation(model,
                          data_loader,
                          metrics,
                          logger,
                          basic_table_info,
                          batch_size,
                          samples_path=None,
                          **kwargs):
    """Evaluate model with a single gpu.

    This method evaluates the model with a single gpu and displays an eval
    progress bar.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): logger used to record results of evaluation.
        basic_table_info (dict): Dictionary containing the basic information
            of the metric table, including training configuration and ckpt.
        batch_size (int): Batch size of images fed into metrics.
        samples_path (str): Used to save generated images. If it's none, we'll
            give it a default directory and delete it after finishing the
            evaluation. Default to None.
        kwargs (dict): Other arguments.
    """
    # decide samples path
    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images
    num_exist = len(
        list(
            mmcv.scandir(
                samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    if num_needed > 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        pbar = mmcv.ProgressBar(num_needed)

    # select key to fetch fake images
    target_domain = basic_table_info['target_domain']
    source_domain = basic_table_info['source_domain']
    # if no images, `num_needed` should be zero
    data_loader_iter = iter(data_loader)
    for begin in range(0, num_needed, batch_size):
        end = min(begin + batch_size, max_num_images)
        # for translation model, we feed them images from dataloader
        data_batch = next(data_loader_iter)
        output_dict = model(
            data_batch[f'img_{source_domain}'],
            test_mode=True,
            target_domain=target_domain)
        fakes = output_dict['target']
        pbar.update(end - begin)
        for i in range(end - begin):
            images = fakes[i:i + 1]
            images = ((images + 1) / 2)
            images = images[:, [2, 1, 0], ...]
            images = images.clamp_(0, 1)
            image_name = str(begin + i) + '.png'
            save_image(images, os.path.join(samples_path, image_name))

    if num_needed > 0:
        sys.stdout.write('\n')

    # return if only save sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(samples_path, batch_size)
    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        # feed in real images
        for data in data_loader:
            reals = data[f'img_{target_domain}']
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
        # feed in fake images
        for data in fake_dataloader:
            fakes = data['real_img']
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
        metric.summary()
    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
    if delete_samples_path:
        shutil.rmtree(samples_path)


@torch.no_grad()
def single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                 basic_table_info, batch_size, **kwargs):
    """Evaluate model with a single gpu in online mode.

    This method evaluates the model with a single gpu and displays an eval
    progress bar. Different from `single_gpu_evaluation`, this function will
    not save the images or read images from disks. Namely, there are no IO
    operations in this function. Thus, in general, `online` mode will achieve
    a faster evaluation. However, this mode will have a much higher memory
    cost. Therefore this evaluation function is recommended for evaluating
    your model with a single metric.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): logger used to record results of evaluation.
        basic_table_info (dict): Dictionary containing the basic information
            of the metric table, including training configuration and ckpt.
        batch_size (int): Batch size of images fed into metrics.
        kwargs (dict): Other arguments.
    """
    # sample images
    max_num_images = 0 if len(metrics) == 0 else max(
        metric.num_fake_need for metric in metrics)
    pbar = mmcv.ProgressBar(max_num_images)

    # select key to fetch images
    target_domain = basic_table_info['target_domain']
    source_domain = basic_table_info['source_domain']

    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()

    # feed reals and fakes
    data_loader_iter = iter(data_loader)
    for begin in range(0, max_num_images, batch_size):
        end = min(begin + batch_size, max_num_images)
        # for translation model, we feed them images from dataloader
        data_batch = next(data_loader_iter)
        output_dict = model(
            data_batch[f'img_{source_domain}'],
            test_mode=True,
            target_domain=target_domain)
        fakes = output_dict['target']
        reals = data_batch[f'img_{target_domain}']
        pbar.update(end - begin)

        for metric in metrics:
            metric.feed(reals, 'reals')
            metric.feed(fakes, 'fakes')

    for metric in metrics:
        metric.summary()
    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    assert isinstance(model, BaseTranslationModel)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # get source domain and target domain
    target_domain = args.target_domain
    if target_domain is None:
        target_domain = model.module._default_domain
    source_domain = model.module.get_other_domains(target_domain)[0]

    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model,
        source_domain=source_domain,
        target_domain=target_domain)

    # build the dataloader
    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        if cfg.data.get('test', None):
            dataset = build_dataset(cfg.data.test)
        else:
            dataset = build_dataset(cfg.data.train)
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=args.batch_size,
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            dist=False,
            shuffle=True)

    if args.online:
        single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                     basic_table_info, args.batch_size)
    else:
        single_gpu_evaluation(model, data_loader, metrics, logger,
                              basic_table_info, args.batch_size,
                              args.samples_path)


if __name__ == '__main__':
    main()
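One apparent gap in this script: main() reads args.num_samples both when --eval none is passed and when no metrics are built, but parse_args() in this file defines no --num-samples flag, so those code paths would raise an AttributeError as written.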
build/lib/mmgen/__init__.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv

from .version import __version__, parse_version_info, version_info


def digit_version(version_str):
    digit_version = []
    for x in version_str.split('.'):
        if x.isdigit():
            digit_version.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            digit_version.append(int(patch_version[0]) - 1)
            digit_version.append(int(patch_version[1]))
    return digit_version


mmcv_minimum_version = '1.3.0'
mmcv_maximum_version = '1.8.0'
mmcv_version = digit_version(mmcv.__version__)


assert (mmcv_version >= digit_version(mmcv_minimum_version)
        and mmcv_version <= digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'

__all__ = ['__version__', 'version_info', 'parse_version_info']
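To make the version comparison concrete, a short sketch of how digit_version orders release candidates (importing the helper defined above):

    from mmgen import digit_version  # defined in the __init__.py above

    assert digit_version('1.7.0') == [1, 7, 0]
    # '0rc2' splits into ('0', '2'): int('0') - 1 = -1 is appended, then 2,
    # so a release candidate always sorts below the finished release.
    assert digit_version('1.7.0rc2') == [1, 7, -1, 2]
    assert digit_version('1.7.0rc2') < digit_version('1.7.0')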
build/lib/mmgen/apis/__init__.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from .inference import (init_model, sample_conditional_model,
                        sample_ddpm_model, sample_img2img_model,
                        sample_unconditional_model)
from .train import set_random_seed, train_model

__all__ = [
    'set_random_seed', 'train_model', 'init_model', 'sample_img2img_model',
    'sample_unconditional_model', 'sample_conditional_model',
    'sample_ddpm_model'
]
build/lib/mmgen/apis/inference.py (new file, mode 100644)

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcv.utils import is_list_of

from mmgen.datasets.pipelines import Compose
from mmgen.models import BaseTranslationModel, build_model


def init_model(config, checkpoint=None, device='cpu', cfg_options=None):
    """Initialize a model from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        cfg_options (dict): Options to override some settings in the used
            config.

    Returns:
        nn.Module: The constructed unconditional model.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    model = build_model(
        config.model, train_cfg=config.train_cfg, test_cfg=config.test_cfg)

    if checkpoint is not None:
        load_checkpoint(model, checkpoint, map_location='cpu')

    model._cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()

    return model


@torch.no_grad()
def sample_unconditional_model(model,
                               num_samples=16,
                               num_batches=4,
                               sample_model='ema',
                               **kwargs):
    """Sampling from unconditional models.

    Args:
        model (nn.Module): Unconditional models in MMGeneration.
        num_samples (int, optional): The total number of samples.
            Defaults to 16.
        num_batches (int, optional): The number of batch size for inference.
            Defaults to 4.
        sample_model (str, optional): Which model you want to use. ['ema',
            'orig']. Defaults to 'ema'.

    Returns:
        Tensor: Generated image tensor.
    """
    # set eval mode
    model.eval()
    # construct sampling list for batches
    n_repeat = num_samples // num_batches
    batches_list = [num_batches] * n_repeat

    if num_samples % num_batches > 0:
        batches_list.append(num_samples % num_batches)
    res_list = []

    # inference
    for batches in batches_list:
        res = model.sample_from_noise(
            None, num_batches=batches, sample_model=sample_model, **kwargs)
        res_list.append(res.cpu())

    results = torch.cat(res_list, dim=0)

    return results


@torch.no_grad()
def sample_conditional_model(model,
                             num_samples=16,
                             num_batches=4,
                             sample_model='ema',
                             label=None,
                             **kwargs):
    """Sampling from conditional models.

    Args:
        model (nn.Module): Conditional models in MMGeneration.
        num_samples (int, optional): The total number of samples.
            Defaults to 16.
        num_batches (int, optional): The number of batch size for inference.
            Defaults to 4.
        sample_model (str, optional): Which model you want to use. ['ema',
            'orig']. Defaults to 'ema'.
        label (int | torch.Tensor | list[int], optional): Labels used to
            generate images. Default to None.

    Returns:
        Tensor: Generated image tensor.
    """
    # set eval mode
    model.eval()
    # construct sampling list for batches
    n_repeat = num_samples // num_batches
    batches_list = [num_batches] * n_repeat

    # check and convert the input labels
    if isinstance(label, int):
        label = torch.LongTensor([label] * num_samples)
    elif isinstance(label, torch.Tensor):
        label = label.type(torch.int64)
        if label.numel() == 1:
            # repeat single tensor
            # call view(-1) to avoid nested tensor like [[[1]]]
            label = label.view(-1).repeat(num_samples)
        else:
            # flatten multi tensors
            label = label.view(-1)
    elif isinstance(label, list):
        if is_list_of(label, int):
            label = torch.LongTensor(label)
            # `nargs='+'` parse single integer as list
            if label.numel() == 1:
                # repeat single tensor
                label = label.repeat(num_samples)
        else:
            raise TypeError('Only support `int` for label list elements, '
                            f'but receive {type(label[0])}')
    elif label is None:
        pass
    else:
        raise TypeError('Only support `int`, `torch.Tensor`, `list[int]` or '
                        f'None as label, but receive {type(label)}.')

    # check the length of the (converted) label
    if label is not None and label.size(0) != num_samples:
        raise ValueError('Number of elements in the label list should be ONE '
                         'or the length of `num_samples`. Requires '
                         f'{num_samples}, but receive {label.size(0)}.')

    # make label list
    label_list = []
    for n in range(n_repeat):
        if label is None:
            label_list.append(None)
        else:
            label_list.append(label[n * num_batches:(n + 1) * num_batches])

    if num_samples % num_batches > 0:
        batches_list.append(num_samples % num_batches)
        if label is None:
            label_list.append(None)
        else:
            label_list.append(label[(n + 1) * num_batches:])

    res_list = []

    # inference
    for batches, labels in zip(batches_list, label_list):
        res = model.sample_from_noise(
            None,
            num_batches=batches,
            label=labels,
            sample_model=sample_model,
            **kwargs)
        res_list.append(res.cpu())

    results = torch.cat(res_list, dim=0)

    return results


def sample_img2img_model(model, image_path, target_domain=None, **kwargs):
    """Sampling from translation models.

    Args:
        model (nn.Module): The loaded model.
        image_path (str): File path of input image.
        target_domain (str): Target style of output image.

    Returns:
        Tensor: Translated image tensor.
    """
    assert isinstance(model, BaseTranslationModel)
    # get source domain and target domain
    if target_domain is None:
        target_domain = model._default_domain
    source_domain = model.get_other_domains(target_domain)[0]

    cfg = model._cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)

    # prepare data
    data = dict()
    # dirty code to deal with test data pipeline
    data['pair_path'] = image_path
    data[f'img_{source_domain}_path'] = image_path
    data[f'img_{target_domain}_path'] = image_path

    data = test_pipeline(data)
    if device.type == 'cpu':
        data = collate([data], samples_per_gpu=1)
        data['meta'] = []
    else:
        data = scatter(collate([data], samples_per_gpu=1), [device])[0]

    source_image = data[f'img_{source_domain}']
    # forward the model
    with torch.no_grad():
        results = model(
            source_image,
            test_mode=True,
            target_domain=target_domain,
            **kwargs)
    output = results['target']
    return output


@torch.no_grad()
def sample_ddpm_model(model,
                      num_samples=16,
                      num_batches=4,
                      sample_model='ema',
                      same_noise=False,
                      **kwargs):
    """Sampling from ddpm models.

    Args:
        model (nn.Module): DDPM models in MMGeneration.
        num_samples (int, optional): The total number of samples.
            Defaults to 16.
        num_batches (int, optional): The number of batch size for inference.
            Defaults to 4.
        sample_model (str, optional): Which model you want to use. ['ema',
            'orig']. Defaults to 'ema'.
        same_noise (bool, optional): Whether to reuse one noise batch as the
            denoising starting point for every batch. Defaults to False.

    Returns:
        list[Tensor | dict]: Generated image tensor.
    """
    model.eval()

    n_repeat = num_samples // num_batches
    batches_list = [num_batches] * n_repeat

    if num_samples % num_batches > 0:
        batches_list.append(num_samples % num_batches)

    noise_batch = torch.randn(model.image_shape) if same_noise else None

    res_list = []
    # inference
    for idx, batches in enumerate(batches_list):
        mmcv.print_log(
            f'Start to sample batch [{idx + 1} / '
            f'{len(batches_list)}]', 'mmgen')
        noise_batch_ = noise_batch[None, ...].expand(batches, -1, -1, -1) \
            if same_noise else None

        res = model.sample_from_noise(
            noise_batch_,
            num_batches=batches,
            sample_model=sample_model,
            show_pbar=True,
            **kwargs)
        if isinstance(res, dict):
            res = {k: v.cpu() for k, v in res.items()}
        elif isinstance(res, torch.Tensor):
            res = res.cpu()
        else:
            raise ValueError('Sample results should be \'dict\' or '
                             f'\'torch.Tensor\', but receive \'{type(res)}\'')
        res_list.append(res)

    # gather the res_list
    if isinstance(res_list[0], dict):
        res_dict = dict()
        for t in res_list[0].keys():
            # num_samples x 3 x H x W
            res_dict[t] = torch.cat([res[t] for res in res_list], dim=0)
        return res_dict
    else:
        return torch.cat(res_list, dim=0)
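A minimal usage sketch for these inference helpers; the config and checkpoint paths are hypothetical placeholders for any unconditional model:

    from mmgen.apis import init_model, sample_unconditional_model

    # Hypothetical paths; any unconditional GAN config/checkpoint would do.
    model = init_model(
        'configs/stylegan2/my_config.py',
        checkpoint='work_dirs/ckpt/my_model.pth',
        device='cuda:0')
    # Draw 8 samples in batches of 4, using the EMA weights by default.
    images = sample_unconditional_model(model, num_samples=8, num_batches=4)
    print(images.shape)  # (8, 3, H, W)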
build/lib/mmgen/apis/train.py
0 → 100644
View file @
1401de15
# Copyright (c) OpenMMLab. All rights reserved.
import
os
from
copy
import
deepcopy
import
mmcv
import
torch
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.runner
import
HOOKS
,
IterBasedRunner
,
OptimizerHook
,
build_runner
from
mmcv.runner
import
set_random_seed
as
set_random_seed_mmcv
from
mmcv.utils
import
build_from_cfg
from
mmgen.core.ddp_wrapper
import
DistributedDataParallelWrapper
from
mmgen.core.optimizer
import
build_optimizers
from
mmgen.core.runners.apex_amp_utils
import
apex_amp_initialize
from
mmgen.datasets
import
build_dataloader
,
build_dataset
from
mmgen.utils
import
get_root_logger
def
set_random_seed
(
seed
,
deterministic
=
False
,
use_rank_shift
=
True
):
"""Set random seed.
In this function, we just modify the default behavior of the similar
function defined in MMCV.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
rank_shift (bool): Whether to add rank number to the random seed to
have different random seed in different threads. Default: True.
"""
set_random_seed_mmcv
(
seed
,
deterministic
=
deterministic
,
use_rank_shift
=
use_rank_shift
)
def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    # default loader config
    loader_cfg = dict(
        samples_per_gpu=cfg.data.samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        # cfg.gpus will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        persistent_workers=cfg.data.get('persistent_workers', False),
        seed=cfg.seed)

    # The overall dataloader settings
    loader_cfg.update({
        k: v
        for k, v in cfg.data.items() if k not in [
            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
            'test_dataloader'
        ]
    })

    # The specific dataloader settings
    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    # dirty code for using apex amp
    # apex.amp requires that models be on a cuda device before
    # initialization.
    if cfg.get('apex_amp', None):
        assert distributed, (
            'Currently, apex.amp is only supported with DDP training.')
        model = model.cuda()

    # build optimizer
    if cfg.optimizer:
        optimizer = build_optimizers(model, cfg.optimizer)
    # In GANs, we allow building the optimizer in the GAN model.
    else:
        optimizer = None

    _use_apex_amp = False
    if cfg.get('apex_amp', None):
        model, optimizer = apex_amp_initialize(model, optimizer,
                                               **cfg.apex_amp)
        _use_apex_amp = True

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        use_ddp_wrapper = cfg.get('use_ddp_wrapper', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        if use_ddp_wrapper:
            mmcv.print_log('Use DDP Wrapper.', 'mmgen')
            model = DistributedDataParallelWrapper(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)

    # allow users to define the runner
    if cfg.get('runner', None):
        runner = build_runner(
            cfg.runner,
            dict(
                model=model,
                optimizer=optimizer,
                work_dir=cfg.work_dir,
                logger=logger,
                use_apex_amp=_use_apex_amp,
                meta=meta))
    else:
        runner = IterBasedRunner(
            model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta)
        # set whether to use dynamic ddp in training
        # is_dynamic_ddp=cfg.get('is_dynamic_ddp', False))

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)

    # In GANs, we can directly optimize parameters in the `train_step`
    # function.
    if cfg.get('optimizer_cfg', None) is None:
        optimizer_config = None
    elif fp16_cfg is not None:
        raise NotImplementedError('Fp16 has not been supported.')
        # optimizer_config = Fp16OptimizerHook(
        #     **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    # default to use OptimizerHook
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # update `out_dir` in ckpt hook
    if cfg.checkpoint_config is not None:
        cfg.checkpoint_config['out_dir'] = os.path.join(
            cfg.work_dir, cfg.checkpoint_config.get('out_dir', 'ckpt'))

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    # # DistSamplerSeedHook should be used with EpochBasedRunner
    # if distributed:
    #     runner.register_hook(DistSamplerSeedHook())

    # In general, we do NOT adopt a standard evaluation hook in GAN training.
    # Thus, if you want an eval hook, you need to further define the key of
    # 'evaluation' in the config.
    # register eval hooks
    if validate and cfg.get('evaluation', None) is not None:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        # Support batch_size > 1 in validation
        val_loader_cfg = {
            **loader_cfg,
            'shuffle': False,
            **cfg.data.get('val_data_loader', {})
        }
        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
        eval_cfg = deepcopy(cfg.get('evaluation'))
        priority = eval_cfg.pop('priority', 'LOW')
        eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader))
        eval_hook = build_from_cfg(eval_cfg, HOOKS)
        runner.register_hook(eval_hook, priority=priority)

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_iters)
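For orientation, a minimal sketch of how `train_model` is typically driven from a script. The config path, `work_dir`, `gpu_ids`, and seed below are illustrative assumptions; the `build_model`/`build_dataset` helpers and the exported `set_random_seed`/`train_model` follow the usual mmgen layout, so treat this as a sketch under those assumptions rather than the repository's official tools/train.py.

# Minimal single-GPU driver sketch; assumes a config that defines the fields
# train_model() reads above (data, gpu_ids, seed, total_iters, workflow, ...).
from mmcv import Config

from mmgen.apis import set_random_seed, train_model
from mmgen.datasets import build_dataset
from mmgen.models import build_model

# illustrative config path, not pinned by this commit
cfg = Config.fromfile('configs/styleganv2/stylegan2_c2_ffhq_256_b4x8_800k.py')
cfg.work_dir = 'work_dirs/demo'
cfg.gpu_ids = [0]
cfg.seed = 2021

# fix the seed in every process before building anything stochastic
set_random_seed(cfg.seed, deterministic=False)

model = build_model(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
dataset = build_dataset(cfg.data.train)

# non-distributed run; the eval hook is registered from cfg.evaluation
train_model(model, dataset, cfg, distributed=False, validate=True)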
build/lib/mmgen/core/__init__.py 0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import *  # noqa: F401, F403
from .hooks import *  # noqa: F401, F403
from .optimizer import *  # noqa: F401, F403
from .registry import *  # noqa: F401, F403
from .runners import *  # noqa: F401, F403
from .scheduler import *  # noqa: F401, F403
build/lib/mmgen/core/ddp_wrapper.py 0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from torch.cuda._utils import _get_device_index


@MODULE_WRAPPERS.register_module('mmgen.DDPWrapper')
class DistributedDataParallelWrapper(nn.Module):
    """A DistributedDataParallel wrapper for models in MMGeneration.

    In MMGeneration, there is a need to wrap different modules in the models
    with separate DistributedDataParallel. Otherwise, it will cause
    errors for GAN training.

    More specifically, a GAN model usually has two sub-modules:
    generator and discriminator. If we wrap both of them in one
    standard DistributedDataParallel, it will cause errors during training,
    because when we update the parameters of the generator (or discriminator),
    the parameters of the discriminator (or generator) are not updated, which
    is not allowed for DistributedDataParallel.
    So we design this wrapper to separately wrap DistributedDataParallel
    for the generator and the discriminator.

    In this wrapper, we perform two operations:

    1. Wrap the modules in the models with separate MMDistributedDataParallel.
       Note that only modules with parameters will be wrapped.
    2. Do scatter operation for 'forward', 'train_step' and 'val_step'.

    Note that the arguments of this wrapper are the same as those in
    `torch.nn.parallel.distributed.DistributedDataParallel`.

    Args:
        module (nn.Module): Module that needs to be wrapped.
        device_ids (list[int | `torch.device`]): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
        dim (int, optional): Same as that in the official scatter function in
            pytorch. Defaults to 0.
        broadcast_buffers (bool): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Defaults to False.
        find_unused_parameters (bool, optional): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Traverse the autograd graph of all tensors contained in the
            returned value of the wrapped module's forward function.
            Defaults to False.
        kwargs (dict): Other arguments used in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
    """

    def __init__(self,
                 module,
                 device_ids,
                 dim=0,
                 broadcast_buffers=False,
                 find_unused_parameters=False,
                 **kwargs):
        super().__init__()
        assert len(device_ids) == 1, (
            'Currently, DistributedDataParallelWrapper only supports one '
            'single CUDA device for each process. '
            f'The length of device_ids must be 1, but got {len(device_ids)}.')
        self.module = module
        self.dim = dim
        self.to_ddp(
            device_ids=device_ids,
            dim=dim,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            **kwargs)
        self.output_device = _get_device_index(device_ids[0], True)

    def to_ddp(self, device_ids, dim, broadcast_buffers,
               find_unused_parameters, **kwargs):
        """Wrap models with separate MMDistributedDataParallel.

        It only wraps the modules with parameters.
        """
        for name, module in self.module._modules.items():
            if next(module.parameters(), None) is None:
                module = module.cuda()
            elif all(not p.requires_grad for p in module.parameters()):
                module = module.cuda()
            else:
                module = MMDistributedDataParallel(
                    module.cuda(),
                    device_ids=device_ids,
                    dim=dim,
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
            self.module._modules[name] = module

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter function.

        Args:
            inputs (Tensor): Input Tensor.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
            device_ids (int): Device id.
        """
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """Forward function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])

    def train_step(self, *inputs, **kwargs):
        """Train step function.

        Args:
            inputs (Tensor): Input Tensor.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.train_step(*inputs[0], **kwargs[0])
        return output

    def val_step(self, *inputs, **kwargs):
        """Validation step function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for ``scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.val_step(*inputs[0], **kwargs[0])
        return output
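To see concretely why only parameterized sub-modules get their own DDP instance, a toy sketch follows. The `ToyGAN` module and its `train_step` are illustrative assumptions; like MMDistributedDataParallel itself, the wrapper requires an already-initialized default process group (e.g. via torchrun) and one CUDA device per process.

# Illustrative sketch: wrap a toy GAN whose generator and discriminator
# must be stepped independently. Assumes torch.distributed is initialized.
import torch
import torch.nn as nn

from mmgen.core.ddp_wrapper import DistributedDataParallelWrapper


class ToyGAN(nn.Module):

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(16, 32)
        self.discriminator = nn.Linear(32, 1)

    def train_step(self, data, optimizer=None):
        # A real model alternates generator/discriminator updates, which is
        # exactly why one shared DDP instance over both sub-modules fails.
        return {'loss': self.discriminator(self.generator(data)).mean()}


model = DistributedDataParallelWrapper(
    ToyGAN(),
    device_ids=[torch.cuda.current_device()],
    broadcast_buffers=False,
    find_unused_parameters=True)
# After to_ddp(), model.module.generator and model.module.discriminator are
# each a separate MMDistributedDataParallel instance; parameter-free children
# would simply be moved to CUDA without wrapping.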
build/lib/mmgen/core/evaluation/__init__.py 0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from .eval_hooks import GenerativeEvalHook, TranslationEvalHook
from .evaluation import (make_metrics_table, make_vanilla_dataloader,
                         offline_evaluation, online_evaluation)
from .metric_utils import slerp
from .metrics import (IS, MS_SSIM, PR, SWD, GaussianKLD, ms_ssim,
                      sliced_wasserstein)

__all__ = [
    'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'offline_evaluation',
    'online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook',
    'make_metrics_table', 'make_vanilla_dataloader', 'GaussianKLD',
    'TranslationEvalHook'
]
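Among these exports, `slerp` performs spherical linear interpolation between latent codes. Below is a self-contained sketch of the technique for reference; it is an independent reimplementation with illustrative names, not the exact signature of the `metric_utils` helper.

# Spherical linear interpolation: move along the great circle joining a and
# b instead of the straight line, keeping intermediate latents at a
# comparable norm (useful when interpolating GAN latent codes).
import torch


def slerp_demo(a: torch.Tensor, b: torch.Tensor, t: float) -> torch.Tensor:
    a_n = a / a.norm(dim=-1, keepdim=True)
    b_n = b / b.norm(dim=-1, keepdim=True)
    # angle between the two directions; clamp guards acos's domain
    omega = torch.acos((a_n * b_n).sum(-1, keepdim=True).clamp(-1, 1))
    so = torch.sin(omega).clamp(min=1e-8)  # avoid 0/0 for parallel inputs
    return (torch.sin((1 - t) * omega) / so) * a + \
        (torch.sin(t * omega) / so) * b


z0, z1 = torch.randn(1, 512), torch.randn(1, 512)
midpoint = slerp_demo(z0, z1, 0.5)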
build/lib/mmgen/core/evaluation/eval_hooks.py 0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import os.path as osp
import sys
import warnings
from bisect import bisect_right

import mmcv
import torch
from mmcv.runner import HOOKS, Hook, get_dist_info

from ..registry import build_metric


@HOOKS.register_module()
class GenerativeEvalHook(Hook):
    """Evaluation Hook for Generative Models.

    This evaluation hook can be used to evaluate unconditional and conditional
    models. Note that only the ``FID`` and ``IS`` metrics are supported for
    distributed training now. In the future, we will support more metrics for
    evaluation during the training procedure.

    In our config system, you only need to add `evaluation` with the detailed
    configurations. Below are several usage examples for different situations.
    What you need to do is to add these lines at the end of your config file.
    Then, you can use this evaluation hook in the training procedure.

    To be noted that, this evaluation hook supports evaluation with dynamic
    intervals, since FID or other metrics may fluctuate frequently at the end
    of the training process.

    # TODO: fix the online doc

    #. Only use FID for evaluation

    .. code-block:: python
        :linenos:

        evaluation = dict(
            type='GenerativeEvalHook',
            interval=10000,
            metrics=dict(
                type='FID',
                num_images=50000,
                inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
                bgr2rgb=True),
            sample_kwargs=dict(sample_model='ema'))

    #. Use FID and IS simultaneously and save the best checkpoints respectively

    .. code-block:: python
        :linenos:

        evaluation = dict(
            type='GenerativeEvalHook',
            interval=10000,
            metrics=[dict(
                type='FID',
                num_images=50000,
                inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
                bgr2rgb=True),
                dict(type='IS',
                     num_images=50000)],
            best_metric=['fid', 'is'],
            sample_kwargs=dict(sample_model='ema'))

    #. Use dynamic evaluation intervals

    .. code-block:: python
        :linenos:

        # interval = 10000, if iter < 500000,
        # interval = 4000,  if 500000 <= iter < 750000,
        # interval = 2000,  if iter >= 750000
        evaluation = dict(
            type='GenerativeEvalHook',
            interval=dict(milestones=[500000, 750000],
                          interval=[10000, 4000, 2000]),
            metrics=[dict(
                type='FID',
                num_images=50000,
                inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
                bgr2rgb=True),
                dict(type='IS',
                     num_images=50000)],
            best_metric=['fid', 'is'],
            sample_kwargs=dict(sample_model='ema'))

    Args:
        dataloader (DataLoader): A PyTorch dataloader.
        interval (int | dict): Evaluation interval. If an int is passed, the
            hook would run under the given interval. If a dict is passed, its
            keys and values would be interpreted as the 'milestones' and
            'interval' of the evaluation. Default: 1.
        dist (bool, optional): Whether to use distributed evaluation.
            Defaults to True.
        metrics (dict | list[dict], optional): Configs for metrics that will
            be used in the evaluation hook. Defaults to None.
        sample_kwargs (dict | None, optional): Additional keyword arguments
            for sampling images. Defaults to None.
        save_best_ckpt (bool, optional): Whether to save the best checkpoint
            according to ``best_metric``. Defaults to ``True``.
        best_metric (str | list, optional): Which metric to be used in saving
            the best checkpoint. Multiple metrics are supported by inputting
            a list of metric names, e.g., ``['fid', 'is']``.
            Defaults to ``'fid'``.
    """

    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    init_value_map = {'greater': -math.inf, 'less': math.inf}
    greater_keys = ['acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'is']
    less_keys = ['loss', 'fid']
    _supported_best_metrics = ['fid', 'is']

    def __init__(self,
                 dataloader,
                 interval=1,
                 dist=True,
                 metrics=None,
                 sample_kwargs=None,
                 save_best_ckpt=True,
                 best_metric='fid'):
        assert metrics is not None
        self.dataloader = dataloader
        self.dist = dist
        self.sample_kwargs = sample_kwargs if sample_kwargs else dict()
        self.save_best_ckpt = save_best_ckpt
        self.best_metric = best_metric

        if isinstance(interval, int):
            self.interval = interval
        elif isinstance(interval, dict):
            if 'milestones' not in interval or 'interval' not in interval:
                raise KeyError(
                    '`milestones` and `interval` must exist in interval dict '
                    'if you want to use the dynamic interval evaluation '
                    f'strategy. But received [{[k for k in interval.keys()]}] '
                    'in the interval dict.')
            self.milestones = interval['milestones']
            self.interval = interval['interval']
            # check if the length of interval matches the milestones
            if len(self.interval) != len(self.milestones) + 1:
                raise ValueError(
                    f'Length of `interval`(={len(self.interval)}) cannot '
                    f'match length of `milestones`(={len(self.milestones)}).')
            # check if milestones are in ascending order
            for idx in range(len(self.milestones) - 1):
                former, latter = self.milestones[idx], self.milestones[idx + 1]
                if former >= latter:
                    raise ValueError(
                        'Elements in `milestones` should be in ascending '
                        'order.')
        else:
            raise TypeError('`interval` only supports `int` or `dict`, '
                            f'but received {type(interval)} instead.')

        if isinstance(best_metric, str):
            self.best_metric = [self.best_metric]

        if self.save_best_ckpt:
            not_supported = set(self.best_metric) - set(
                self._supported_best_metrics)
            assert len(not_supported) == 0, (
                f'{not_supported} is not supported for saving best ckpt')

        self.metrics = build_metric(metrics)

        if isinstance(metrics, dict):
            self.metrics = [self.metrics]

        for metric in self.metrics:
            metric.prepare()

        # add support for saving the best ckpt
        if self.save_best_ckpt:
            self.rule = {}
            self.compare_func = {}
            self._curr_best_score = {}
            self._curr_best_ckpt_path = {}
            for name in self.best_metric:
                if name in self.greater_keys:
                    self.rule[name] = 'greater'
                else:
                    self.rule[name] = 'less'
                self.compare_func[name] = self.rule_map[self.rule[name]]
                self._curr_best_score[name] = self.init_value_map[
                    self.rule[name]]
                self._curr_best_ckpt_path[name] = None

    def get_current_interval(self, runner):
        """Get the current evaluation interval.

        Args:
            runner (``mmcv.runner.BaseRunner``): The runner.
        """
        if isinstance(self.interval, int):
            return self.interval
        else:
            curr_iter = runner.iter + 1
            index = bisect_right(self.milestones, curr_iter)
            return self.interval[index]

    def before_run(self, runner):
        """The behavior before running.

        Args:
            runner (``mmcv.runner.BaseRunner``): The runner.
        """
        if self.save_best_ckpt is not None:
            if runner.meta is None:
                warnings.warn('runner.meta is None. Creating an empty one.')
                runner.meta = dict()
            runner.meta.setdefault('hook_msgs', dict())

    def after_train_iter(self, runner):
        """The behavior after each train iteration.

        Args:
            runner (``mmcv.runner.BaseRunner``): The runner.
        """
        interval = self.get_current_interval(runner)
        if not self.every_n_iters(runner, interval):
            return

        runner.model.eval()

        batch_size = self.dataloader.batch_size
        rank, ws = get_dist_info()
        total_batch_size = batch_size * ws

        # sample real images
        max_real_num_images = max(metric.num_images - metric.num_real_feeded
                                  for metric in self.metrics)
        # define mmcv progress bar
        if rank == 0 and max_real_num_images > 0:
            mmcv.print_log(
                f'Sample {max_real_num_images} real images for evaluation',
                'mmgen')
            pbar = mmcv.ProgressBar(max_real_num_images)

        if max_real_num_images > 0:
            for data in self.dataloader:
                if 'real_img' in data:
                    reals = data['real_img']
                # key for conditional GAN
                elif 'img' in data:
                    reals = data['img']
                else:
                    raise KeyError('Cannot find the key for images in '
                                   'data_dict. Only `real_img` for '
                                   'unconditional datasets and `img` for '
                                   'conditional datasets are supported.')

                if reals.shape[1] not in [1, 3]:
                    raise RuntimeError('real images should have one or three '
                                       'channels in the first dimension, '
                                       'not %d' % reals.shape[1])
                if reals.shape[1] == 1:
                    reals = reals.repeat(1, 3, 1, 1)

                num_feed = 0
                for metric in self.metrics:
                    num_feed_ = metric.feed(reals, 'reals')
                    num_feed = max(num_feed_, num_feed)

                if num_feed <= 0:
                    break

                if rank == 0:
                    pbar.update(num_feed)

        max_num_images = max(metric.num_images for metric in self.metrics)
        if rank == 0:
            mmcv.print_log(
                f'Sample {max_num_images} fake images for evaluation',
                'mmgen')

        # define mmcv progress bar
        if rank == 0:
            pbar = mmcv.ProgressBar(max_num_images)

        # sample fake images and directly send them to metrics
        for _ in range(0, max_num_images, total_batch_size):
            with torch.no_grad():
                fakes = runner.model(
                    None,
                    num_batches=batch_size,
                    return_loss=False,
                    **self.sample_kwargs)
                for metric in self.metrics:
                    # feed in fake images
                    metric.feed(fakes, 'fakes')

            if rank == 0:
                pbar.update(total_batch_size)

        runner.log_buffer.clear()
        # a dirty workaround to change the line at the end of pbar
        if rank == 0:
            sys.stdout.write('\n')

        for metric in self.metrics:
            with torch.no_grad():
                metric.summary()
            for name, val in metric._result_dict.items():
                runner.log_buffer.output[name] = val
                # record the best metric and save the best ckpt
                if self.save_best_ckpt and name in self.best_metric:
                    self._save_best_ckpt(runner, val, name)

        runner.log_buffer.ready = True
        runner.model.train()

        # clear all current states for the next evaluation
        for metric in self.metrics:
            metric.clear()

    def _save_best_ckpt(self, runner, new_score, metric_name):
        """Save the checkpoint with the best metric score.

        Args:
            runner (``mmcv.runner.BaseRunner``): The runner.
            new_score (float): New metric score.
            metric_name (str): Name of the metric.
        """
        curr_iter = f'iter_{runner.iter + 1}'

        if self.compare_func[metric_name](
                new_score, self._curr_best_score[metric_name]):
            best_ckpt_name = f'best_{metric_name}_{curr_iter}.pth'
            runner.meta['hook_msgs'][f'best_score_{metric_name}'] = new_score

            if self._curr_best_ckpt_path[metric_name] and osp.isfile(
                    self._curr_best_ckpt_path[metric_name]):
                os.remove(self._curr_best_ckpt_path[metric_name])

            self._curr_best_ckpt_path[metric_name] = osp.join(
                runner.work_dir, best_ckpt_name)
            runner.save_checkpoint(
                runner.work_dir, best_ckpt_name, create_symlink=False)
            runner.meta['hook_msgs'][f'best_ckpt_{metric_name}'] = \
                self._curr_best_ckpt_path[metric_name]
            self._curr_best_score[metric_name] = new_score
            runner.logger.info(
                f'Now best checkpoint is saved as {best_ckpt_name}.')
            runner.logger.info(f'Best {metric_name} is {new_score:0.4f} '
                               f'at {curr_iter}.')
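The dynamic-interval logic in `get_current_interval` reduces to a `bisect_right` lookup: `milestones` splits the iteration axis into `len(milestones) + 1` segments, and `interval[i]` applies to the i-th segment. A standalone illustration follows, with values taken from the docstring example above:

# Standalone illustration of the interval lookup used by the hook.
from bisect import bisect_right

milestones = [500000, 750000]
intervals = [10000, 4000, 2000]

for curr_iter in (100000, 500000, 600000, 750000, 900000):
    idx = bisect_right(milestones, curr_iter)  # number of milestones reached
    print(curr_iter, '->', intervals[idx])
# 100000 -> 10000
# 500000 -> 4000   (bisect_right counts a milestone once it is reached)
# 600000 -> 4000
# 750000 -> 2000
# 900000 -> 2000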
@HOOKS.register_module()
class TranslationEvalHook(GenerativeEvalHook):
    """Evaluation Hook for Translation Models.

    This evaluation hook can be used to evaluate translation models. Note
    that only the ``FID`` and ``IS`` metrics are supported for distributed
    training now. In the future, we will support more metrics for evaluation
    during the training procedure.

    In our config system, you only need to add `evaluation` with the detailed
    configurations. Below are several usage examples for different situations.
    What you need to do is to add these lines at the end of your config file.
    Then, you can use this evaluation hook in the training procedure.

    To be noted that, this evaluation hook supports evaluation with dynamic
    intervals, since FID or other metrics may fluctuate frequently at the end
    of the training process.

    # TODO: fix the online doc

    #. Only use FID for evaluation

    .. code-block:: python
        :linenos:

        evaluation = dict(
            type='TranslationEvalHook',
            target_domain='photo',
            interval=10000,
            metrics=dict(type='FID', num_images=106, bgr2rgb=True))

    #. Use FID and IS simultaneously and save the best checkpoints respectively

    .. code-block:: python
        :linenos:

        evaluation = dict(
            type='TranslationEvalHook',
            target_domain='photo',
            interval=10000,
            metrics=[
                dict(type='FID', num_images=106, bgr2rgb=True),
                dict(
                    type='IS',
                    num_images=106,
                    inception_args=dict(type='pytorch'))
            ],
            best_metric=['fid', 'is'])

    #. Use dynamic evaluation intervals

    .. code-block:: python
        :linenos:

        # interval = 10000, if iter < 100000,
        # interval = 4000,  if 100000 <= iter < 200000,
        # interval = 2000,  if iter >= 200000
        evaluation = dict(
            type='TranslationEvalHook',
            interval=dict(milestones=[100000, 200000],
                          interval=[10000, 4000, 2000]),
            target_domain='zebra',
            metrics=[
                dict(type='FID', num_images=140, bgr2rgb=True),
                dict(type='IS', num_images=140)
            ],
            best_metric=['fid', 'is'])

    Args:
        target_domain (str): Target domain of the output image.
    """

    def __init__(self, *args, target_domain, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_domain = target_domain

    def after_train_iter(self, runner):
        """The behavior after each train iteration.

        Args:
            runner (``mmcv.runner.BaseRunner``): The runner.
        """
        interval = self.get_current_interval(runner)
        if not self.every_n_iters(runner, interval):
            return

        runner.model.eval()
        source_domain = runner.model.module.get_other_domains(
            self.target_domain)[0]

        # feed real images
        max_num_images = max(metric.num_images for metric in self.metrics)
        for metric in self.metrics:
            if metric.num_real_feeded >= metric.num_real_need:
                continue
            mmcv.print_log(f'Feed reals to {metric.name} metric.', 'mmgen')
            # feed in real images
            for data in self.dataloader:
                # key for translation model
                if f'img_{self.target_domain}' in data:
                    reals = data[f'img_{self.target_domain}']
                # key for conditional GAN
                else:
                    raise KeyError(
                        'Cannot find the key for images in data_dict.')

                num_feed = metric.feed(reals, 'reals')
                if num_feed <= 0:
                    break

        mmcv.print_log(f'Sample {max_num_images} fake images for evaluation',
                       'mmgen')

        rank, ws = get_dist_info()
        # define mmcv progress bar
        if rank == 0:
            pbar = mmcv.ProgressBar(max_num_images)

        # feed in fake images
        for data in self.dataloader:
            # key for translation model
            if f'img_{source_domain}' in data:
                with torch.no_grad():
                    output_dict = runner.model(
                        data[f'img_{source_domain}'],
                        test_mode=True,
                        target_domain=self.target_domain,
                        **self.sample_kwargs)
                fakes = output_dict['target']
            # key error
            else:
                raise KeyError('Cannot find the key for images in data_dict.')

            # sample fake images and directly send them to metrics
            # pbar update number for one proc
            num_update = 0
            for metric in self.metrics:
                if metric.num_fake_feeded >= metric.num_fake_need:
                    continue
                num_feed = metric.feed(fakes, 'fakes')
                num_update = max(num_update, num_feed)
                if num_feed <= 0:
                    break

            if rank == 0:
                if num_update > 0:
                    pbar.update(num_update * ws)

        runner.log_buffer.clear()
        # a dirty workaround to change the line at the end of pbar
        if rank == 0:
            sys.stdout.write('\n')

        for metric in self.metrics:
            with torch.no_grad():
                metric.summary()
            for name, val in metric._result_dict.items():
                runner.log_buffer.output[name] = val
                # record the best metric and save the best ckpt
                if self.save_best_ckpt and name in self.best_metric:
                    self._save_best_ckpt(runner, val, name)

        runner.log_buffer.ready = True
        runner.model.train()

        # clear all current states for the next evaluation
        for metric in self.metrics:
            metric.clear()
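The best-checkpoint bookkeeping shared by both hooks reduces to the `rule_map`/`init_value_map` tables at the top of `GenerativeEvalHook`: 'is' is treated as greater-is-better, 'fid' as less-is-better, and each metric starts from the corresponding infinity so the first real score always wins. A compact illustration with made-up scores:

# Compact illustration of the comparison tables driving _save_best_ckpt.
import math

rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
init_value_map = {'greater': -math.inf, 'less': math.inf}

rule = {'is': 'greater', 'fid': 'less'}
best = {name: init_value_map[r] for name, r in rule.items()}

for name, score in [('fid', 12.3), ('is', 3.1), ('fid', 9.8), ('is', 2.9)]:
    if rule_map[rule[name]](score, best[name]):
        best[name] = score  # this is where the hook would save a checkpoint

print(best)  # {'is': 3.1, 'fid': 9.8}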