Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmgeneration
Commits
b7536f78
Commit
b7536f78
authored
Jun 16, 2025
by
limm
Browse files
add a to another part of mmgeneration code
parent
57e0e891
Pipeline
#2777
canceled with stages
Changes
185
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
860 additions
and
0 deletions
+860
-0
tools/slurm_train.sh
tools/slurm_train.sh
+24
-0
tools/train.py
tools/train.py
+228
-0
tools/utils/inception_stat.py
tools/utils/inception_stat.py
+168
-0
tools/utils/singan_inference.py
tools/utils/singan_inference.py
+111
-0
tools/utils/translation_eval.py
tools/utils/translation_eval.py
+329
-0
No files found.
tools/slurm_train.sh
0 → 100644
View file @
b7536f78
#!/usr/bin/env bash
# Launch a distributed mmgeneration training job on a Slurm cluster.
#
# Usage:
#   [GPUS=8] [GPUS_PER_NODE=8] [CPUS_PER_TASK=5] [SRUN_ARGS=...] \
#       tools/slurm_train.sh PARTITION JOB_NAME CONFIG WORK_DIR [py args...]

set -x

# Positional arguments.
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
# Tunables, overridable through environment variables (defaults in braces).
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
# Everything after the 4th positional argument is forwarded to train.py.
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}

# Prepend the repository root to PYTHONPATH so the in-tree package is
# importable, then start one srun task per GPU; train.py reads the Slurm
# environment via --launcher="slurm".
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
tools/train.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
copy
import
multiprocessing
as
mp
import
os
import
os.path
as
osp
import
platform
import
time
import
warnings
import
cv2
import
mmcv
import
torch
from
mmcv
import
Config
,
DictAction
from
mmcv.runner
import
get_dist_info
,
init_dist
from
mmcv.utils
import
get_git_hash
from
mmgen
import
__version__
from
mmgen.apis
import
set_random_seed
,
train_model
from
mmgen.datasets
import
build_dataset
from
mmgen.models
import
build_model
from
mmgen.utils
import
collect_env
,
get_root_logger
# Disable OpenCV's internal multithreading at import time so it does not
# compete with dataloader workers for CPU cores (may be re-enabled later by
# `setup_multi_processes` via the `opencv_num_threads` config key).
cv2.setNumThreads(0)
def parse_args():
    """Parse command-line arguments for GAN training."""
    arg_parser = argparse.ArgumentParser(description='Train a GAN model')
    arg_parser.add_argument('config', help='train config file path')
    arg_parser.add_argument(
        '--work-dir', help='the dir to save logs and models')
    arg_parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    arg_parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    # `--gpus`, `--gpu-ids` and `--gpu-id` are mutually exclusive; only the
    # last one is non-deprecated.
    gpu_group = arg_parser.add_mutually_exclusive_group()
    gpu_group.add_argument(
        '--gpus',
        type=int,
        help=('(Deprecated, please use --gpu-id) number of gpus to use '
              '(only applicable to non-distributed training)'))
    gpu_group.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help=('(Deprecated, please use --gpu-id) ids of gpus to use '
              '(only applicable to non-distributed training)'))
    gpu_group.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help=('id of gpu to use '
              '(only applicable to non-distributed training)'))
    arg_parser.add_argument(
        '--seed', type=int, default=2021, help='random seed')
    arg_parser.add_argument(
        '--diff_seed',
        action='store_true',
        help='Whether or not set different seeds for different ranks')
    arg_parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    arg_parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help=('override some settings in the used config, the key-value pair '
              'in xxx=yyy format will be merged into config file.'))
    arg_parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    arg_parser.add_argument('--local_rank', type=int, default=0)
    parsed = arg_parser.parse_args()
    # Some launchers pass --local_rank on the command line while the rest of
    # the stack reads the LOCAL_RANK environment variable; keep them in sync.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(parsed.local_rank)
    return parsed
def setup_multi_processes(cfg):
    """Set up multi-processing-related environment for training.

    Configures the multiprocessing start method, restricts OpenCV threading
    and caps the OMP/MKL thread counts when several dataloader workers are
    used, to avoid oversubscribing the CPU.

    Args:
        cfg (mmcv.Config): Config object. Reads the optional top-level keys
            ``mp_start_method`` (default ``'fork'``) and
            ``opencv_num_threads`` (default ``0``), plus
            ``cfg.data.workers_per_gpu``.
    """
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        mp_start_method = cfg.get('mp_start_method', 'fork')
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be force set to `{mp_start_method}`.')
        # Fix: `set_start_method` raises RuntimeError when a start method was
        # already set (e.g. by another library or an earlier call); pass
        # `force=True` so repeated configuration cannot crash training.
        mp.set_start_method(mp_start_method, force=True)

    # disable opencv multithreading to avoid system being overloaded
    opencv_num_threads = cfg.get('opencv_num_threads', 0)
    cv2.setNumThreads(opencv_num_threads)

    # setup OMP threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
    if ('OMP_NUM_THREADS' not in os.environ
            and cfg.data.workers_per_gpu > 1):
        omp_num_threads = 1
        warnings.warn(
            f'Setting OMP_NUM_THREADS environment variable for each process '
            f'to be {omp_num_threads} in default, to avoid your system being '
            f'overloaded, please further tune the variable for optimal '
            f'performance in your application as needed.')
        os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)

    # setup MKL threads
    if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
        mkl_num_threads = 1
        warnings.warn(
            f'Setting MKL_NUM_THREADS environment variable for each process '
            f'to be {mkl_num_threads} in default, to avoid your system being '
            f'overloaded, please further tune the variable for optimal '
            f'performance in your application as needed.')
        os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
def main():
    """Entry point: parse CLI args, build model/datasets and start training."""
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # CLI --cfg-options overrides take precedence over the config file.
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    setup_multi_processes(cfg)

    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    # Deprecated GPU flags: both collapse to a single-GPU setup.
    if args.gpus is not None:
        cfg.gpu_ids = range(1)
        warnings.warn('`--gpus` is deprecated because we only support '
                      'single GPU mode in non-distributed training. '
                      'Use `gpus=1` now.')
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed training. Use the first GPU '
                      'in `gpu_ids` now.')
    if args.gpus is None and args.gpu_ids is None:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config into the work dir for reproducibility
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}, '
                    f'use_rank_shift: {args.diff_seed}')
        set_random_seed(
            args.seed,
            deterministic=args.deterministic,
            use_rank_shift=args.diff_seed)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    datasets = [build_dataset(cfg.data.train)]
    # A two-stage workflow means a validation dataset is also trained/looped
    # on; reuse the val pipeline for it.
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.val.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmgen version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmgen_version=__version__ + get_git_hash()[:7])
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
tools/utils/inception_stat.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os.path
as
osp
import
pickle
import
sys
import
mmcv
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv
import
Config
,
print_log
# yapf: disable
# Make the repository root importable when the script is run directly.
sys.path.append(osp.abspath(osp.join(__file__, '../../..')))  # isort:skip  # noqa
from mmgen.core.evaluation.metric_utils import extract_inception_features  # isort:skip  # noqa
from mmgen.datasets import (UnconditionalImageDataset, build_dataloader,  # isort:skip  # noqa
                            build_dataset)  # isort:skip  # noqa
from mmgen.models.architectures import InceptionV3  # isort:skip  # noqa
# yapf: enable

if __name__ == '__main__':
    # Pre-compute Inception feature statistics (mean/covariance) of a real
    # image set and store them as a pickle, for later FID-style evaluation.
    parser = argparse.ArgumentParser(
        description='Pre-calculate inception data and save it in pkl file')
    parser.add_argument(
        '--imgsdir', type=str, default=None, help='the dir containing images.')
    parser.add_argument(
        '--data-cfg',
        type=str,
        default=None,
        help='the config file for test data pipeline')
    parser.add_argument(
        '--pklname', type=str, help='the name of inception pkl')
    parser.add_argument(
        '--pkl-dir',
        type=str,
        default='work_dirs/inception_pkl',
        help='path to save pkl file')
    parser.add_argument(
        '--pipeline-cfg',
        type=str,
        default=None,
        help=('config file containing dataset pipeline. If None, the default'
              ' pipeline will be adopted'))
    parser.add_argument(
        '--flip', action='store_true', help='whether to flip real images')
    parser.add_argument(
        '--size',
        type=int,
        nargs='+',
        default=(299, 299),
        help='image size in the data pipeline')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=25,
        help='batch size used in extracted features')
    parser.add_argument(
        '--num-samples',
        type=int,
        default=50000,
        help=('the number of total samples, if input -1, '
              'automaticly use all samples in the subset'))
    parser.add_argument(
        '--no-shuffle',
        action='store_true',
        help='not use shuffle in data loader')
    parser.add_argument(
        '--subset',
        default='test',
        help='which subset and corresponding pipeline to use')
    parser.add_argument(
        '--inception-style',
        choices=['stylegan', 'pytorch'],
        default='pytorch',
        help='which inception network to use')
    parser.add_argument(
        '--inception-pth',
        type=str,
        default='work_dirs/cache/inception-2015-12-05.pt')
    args = parser.parse_args()

    # dataset pipeline (only be used when args.imgsdir is not None)
    if args.pipeline_cfg is not None:
        pipeline = Config.fromfile(args.pipeline_cfg)['inception_pipeline']
    elif args.imgsdir is not None:
        # Normalize --size into a (h, w) pair. Note nargs='+' always yields a
        # list, so the int branch is only reachable for the tuple default.
        if isinstance(args.size, list) and len(args.size) == 2:
            size = args.size
        elif isinstance(args.size, list) and len(args.size) == 1:
            size = (args.size[0], args.size[0])
        elif isinstance(args.size, int):
            size = (args.size, args.size)
        else:
            raise TypeError(
                f'args.size mush be int or tuple but got {args.size}')
        pipeline = [
            dict(type='LoadImageFromFile', key='real_img'),
            dict(type='Resize', keys=['real_img'], scale=size,
                 keep_ratio=False),
            dict(
                type='Normalize',
                keys=['real_img'],
                mean=[127.5] * 3,
                std=[127.5] * 3,
                to_rgb=True),  # default to RGB images
            dict(type='Collect', keys=['real_img'], meta_keys=[]),
            dict(type='ImageToTensor', keys=['real_img'])
        ]
        # insert flip aug
        if args.flip:
            pipeline.insert(
                1, dict(type='Flip', keys=['real_img'],
                        direction='horizontal'))

    # build dataloader
    if args.imgsdir is not None:
        dataset = UnconditionalImageDataset(args.imgsdir, pipeline)
    elif args.data_cfg is not None:
        # Please make sure the dataset will sample images in `RGB` order.
        data_config = Config.fromfile(args.data_cfg)
        subset_config = data_config.data.get(args.subset, None)
        print_log(subset_config, 'mmgen')
        dataset = build_dataset(subset_config)
    else:
        raise RuntimeError('Please provide imgsdir or data_cfg')

    data_loader = build_dataloader(
        dataset, args.batch_size, 4, dist=False,
        shuffle=(not args.no_shuffle))
    mmcv.mkdir_or_exist(args.pkl_dir)

    # build inception network
    if args.inception_style == 'stylegan':
        # TorchScript Inception used by the StyleGAN reference implementation.
        inception = torch.jit.load(args.inception_pth).eval().cuda()
        inception = nn.DataParallel(inception)
        print_log('Adopt Inception network in StyleGAN', 'mmgen')
    else:
        inception = nn.DataParallel(
            InceptionV3([3], resize_input=True,
                        normalize_input=False).cuda())
    inception.eval()

    if args.num_samples == -1:
        print_log('Use all samples in subset', 'mmgen')
        num_samples = len(dataset)
    else:
        num_samples = args.num_samples

    features = extract_inception_features(data_loader, inception, num_samples,
                                          args.inception_style).numpy()

    # sanity check for the number of features
    assert features.shape[0] == num_samples, (
        'the number of features != num_samples')
    print_log(f'Extract {num_samples} features', 'mmgen')

    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    with open(osp.join(args.pkl_dir, args.pklname), 'wb') as f:
        pickle.dump(
            {
                'mean': mean,
                'cov': cov,
                'size': num_samples,
                'name': args.pklname
            }, f)
tools/utils/singan_inference.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os
import
sys
import
mmcv
import
torch
from
mmcv
import
Config
from
mmcv.parallel
import
MMDataParallel
from
mmcv.runner
import
load_checkpoint
# yapf: disable
# Make the repository root importable when the script is run directly.
sys.path.append(os.path.abspath(os.path.join(__file__, '../../..')))  # isort:skip  # noqa
from mmgen.apis import set_random_seed  # isort:skip  # noqa
from mmgen.models import build_model  # isort:skip  # noqa
# yapf: enable
def parse_args():
    """Parse command-line arguments for SinGAN inference."""
    cli = argparse.ArgumentParser(description='Evaluate a GAN model')
    cli.add_argument('config', help='evaluation config file path')
    cli.add_argument('checkpoint', help='checkpoint file')
    cli.add_argument('--seed', type=int, default=2021, help='random seed')
    cli.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    cli.add_argument(
        '--samples-path',
        type=str,
        default=None,
        help='path to store images. If not given, remove it after evaluation \
finished')
    cli.add_argument(
        '--save-prev-res',
        action='store_true',
        help='whether to store the results from previous stages')
    cli.add_argument(
        '--num-samples',
        type=int,
        default=10,
        help='the number of synthesized samples')
    return cli.parse_args()
def
_tensor2img
(
img
):
img
=
img
[
0
].
permute
(
1
,
2
,
0
)
img
=
((
img
+
1
)
/
2
*
255
).
to
(
torch
.
uint8
)
return
img
.
cpu
().
numpy
()
@torch.no_grad()
def main():
    """Sample images from a trained SinGAN model and write them to disk."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    model.eval()

    # load ckpt
    mmcv.print_log(f'Loading ckpt from {args.checkpoint}', 'mmgen')
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    # add dp wrapper
    model = MMDataParallel(model, device_ids=[0])

    pbar = mmcv.ProgressBar(args.num_samples)
    for sample_iter in range(args.num_samples):
        # NOTE(review): args.samples_path defaults to None; os.path.join below
        # would fail in that case — presumably a path is always given. Confirm.
        outputs = model(
            None, num_batches=1, get_prev_res=args.save_prev_res)

        # store results from previous stages
        if args.save_prev_res:
            fake_img = outputs['fake_img']
            prev_res_list = outputs['prev_res_list']
            # Append the final image so each stage (including the last) gets
            # its own sub-directory.
            prev_res_list.append(fake_img)
            for i, img in enumerate(prev_res_list):
                img = _tensor2img(img)
                mmcv.imwrite(
                    img,
                    os.path.join(args.samples_path, f'stage{i}',
                                 f'rand_sample_{sample_iter}.png'))
        # just store the final result
        else:
            img = _tensor2img(outputs)
            mmcv.imwrite(
                img,
                os.path.join(args.samples_path,
                             f'rand_sample_{sample_iter}.png'))

        pbar.update()

    # change the line after pbar
    sys.stdout.write('\n')


if __name__ == '__main__':
    main()
tools/utils/translation_eval.py
0 → 100644
View file @
b7536f78
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
os
import
shutil
import
sys
import
mmcv
import
torch
from
mmcv
import
Config
from
mmcv.parallel
import
MMDataParallel
from
mmcv.runner
import
load_checkpoint
from
torchvision.utils
import
save_image
from
mmgen.apis
import
set_random_seed
from
mmgen.core
import
build_metric
from
mmgen.core.evaluation
import
make_metrics_table
,
make_vanilla_dataloader
from
mmgen.datasets
import
build_dataloader
,
build_dataset
from
mmgen.models
import
build_model
from
mmgen.models.translation_models
import
BaseTranslationModel
from
mmgen.utils
import
get_root_logger
def parse_args():
    """Parse command-line arguments for translation-model evaluation.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Evaluate a GAN model')
    parser.add_argument('config', help='evaluation config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--target-domain',
        type=str,
        default=None,
        help='Desired image domain')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=1,
        help='batch size of dataloader')
    parser.add_argument(
        '--samples-path',
        type=str,
        default=None,
        help='path to store images. If not given, remove it after evaluation \
finished')
    # Fix: `main()` reads `args.num_samples` (used when `--eval none` is
    # passed, and copied into the metrics table when no metric is built), but
    # no such option was defined, causing an AttributeError at runtime.
    parser.add_argument(
        '--num-samples',
        type=int,
        default=-1,
        help='the number of images to sample when only sampling (no metric); '
        'must be > 0 together with `--eval none`')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema',
        help='use which mode (ema/orig) in sampling')
    parser.add_argument(
        '--eval',
        nargs='*',
        type=str,
        default=None,
        help='select the metrics you want to access')
    parser.add_argument(
        '--online',
        action='store_true',
        help='whether to use online mode for evaluation')
    args = parser.parse_args()
    return args
@torch.no_grad()
def single_gpu_evaluation(model,
                          data_loader,
                          metrics,
                          logger,
                          basic_table_info,
                          batch_size,
                          samples_path=None,
                          **kwargs):
    """Evaluate model with a single gpu.

    This method evaluate model with a single gpu and displays eval progress
    bar.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): logger used to record results of evaluation.
        basic_table_info (dict): Dictionary containing the basic information \
            of the metric table include training configuration and ckpt.
        batch_size (int): Batch size of images fed into metrics.
        samples_path (str): Used to save generated images. If it's none, we'll
            give it a default directory and delete it after finishing the
            evaluation. Default to None.
        kwargs (dict): Other arguments.
    """
    # decide samples path
    delete_samples_path = False
    if samples_path:
        mmcv.mkdir_or_exist(samples_path)
    else:
        temp_path = './work_dirs/temp_samples'
        # if temp_path exists, add suffix
        suffix = 1
        samples_path = temp_path
        while os.path.exists(samples_path):
            samples_path = temp_path + '_' + str(suffix)
            suffix += 1
        os.makedirs(samples_path)
        delete_samples_path = True

    # sample images
    # Already-sampled images in samples_path are reused; only the shortfall
    # is generated below.
    num_exist = len(
        list(
            mmcv.scandir(
                samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
    if basic_table_info['num_samples'] > 0:
        max_num_images = basic_table_info['num_samples']
    else:
        max_num_images = max(metric.num_images for metric in metrics)
    num_needed = max(max_num_images - num_exist, 0)

    if num_needed > 0:
        mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
                       'mmgen')
        # define mmcv progress bar
        # NOTE: pbar only exists when num_needed > 0; the sampling loop below
        # runs zero iterations otherwise, so the reference is safe.
        pbar = mmcv.ProgressBar(num_needed)

    # select key to fetch fake images
    target_domain = basic_table_info['target_domain']
    source_domain = basic_table_info['source_domain']
    # if no images, `num_needed` should be zero
    data_loader_iter = iter(data_loader)
    for begin in range(0, num_needed, batch_size):
        end = min(begin + batch_size, max_num_images)
        # for translation model, we feed them images from dataloader
        data_batch = next(data_loader_iter)
        output_dict = model(
            data_batch[f'img_{source_domain}'],
            test_mode=True,
            target_domain=target_domain)
        fakes = output_dict['target']
        pbar.update(end - begin)
        for i in range(end - begin):
            images = fakes[i:i + 1]
            # Map [-1, 1] model output to [0, 1] for save_image.
            images = ((images + 1) / 2)
            # Reorder channels (BGR <-> RGB swap) before saving.
            images = images[:, [2, 1, 0], ...]
            images = images.clamp_(0, 1)
            image_name = str(begin + i) + '.png'
            save_image(images, os.path.join(samples_path, image_name))

    if num_needed > 0:
        sys.stdout.write('\n')

    # return if only save sampled images
    if len(metrics) == 0:
        return

    # empty cache to release GPU memory
    torch.cuda.empty_cache()
    fake_dataloader = make_vanilla_dataloader(samples_path, batch_size)
    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()
        # feed in real images
        for data in data_loader:
            reals = data[f'img_{target_domain}']
            # Each metric stops consuming once it has enough samples.
            num_left = metric.feed(reals, 'reals')
            if num_left <= 0:
                break
        # feed in fake images
        for data in fake_dataloader:
            fakes = data['real_img']
            num_left = metric.feed(fakes, 'fakes')
            if num_left <= 0:
                break
        metric.summary()
    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)

    if delete_samples_path:
        shutil.rmtree(samples_path)
@torch.no_grad()
def single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                 basic_table_info, batch_size, **kwargs):
    """Evaluate model with a single gpu in online mode.

    This method evaluate model with a single gpu and displays eval progress
    bar. Different form `single_gpu_evaluation`, this function will not save
    the images or read images from disks. Namely, there do not exist any IO
    operations in this function. Thus, in general, `online` mode will achieve a
    faster evaluation. However, this mode will take much more memory cost.
    Therefore this evaluation function is recommended to evaluate your model
    with a single metric.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): PyTorch data loader.
        metrics (list): List of metric objects.
        logger (Logger): logger used to record results of evaluation.
        basic_table_info (dict): Dictionary containing the basic information \
            of the metric table include training configuration and ckpt.
        batch_size (int): Batch size of images fed into metrics.
        kwargs (dict): Other arguments.
    """
    # sample images
    # The largest per-metric demand decides how many fakes to generate.
    max_num_images = 0 if len(metrics) == 0 else max(
        metric.num_fake_need for metric in metrics)
    pbar = mmcv.ProgressBar(max_num_images)

    # select key to fetch images
    target_domain = basic_table_info['target_domain']
    source_domain = basic_table_info['source_domain']

    for metric in metrics:
        mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
        metric.prepare()

    # feed reals and fakes
    data_loader_iter = iter(data_loader)
    for begin in range(0, max_num_images, batch_size):
        end = min(begin + batch_size, max_num_images)
        # for translation model, we feed them images from dataloader
        data_batch = next(data_loader_iter)
        output_dict = model(
            data_batch[f'img_{source_domain}'],
            test_mode=True,
            target_domain=target_domain)
        fakes = output_dict['target']
        reals = data_batch[f'img_{target_domain}']
        pbar.update(end - begin)
        # Unlike the offline path, every metric sees the same batch directly
        # (no images are written to disk).
        for metric in metrics:
            metric.feed(reals, 'reals')
            metric.feed(fakes, 'fakes')

    for metric in metrics:
        metric.summary()
    table_str = make_metrics_table(basic_table_info['train_cfg'],
                                   basic_table_info['ckpt'],
                                   basic_table_info['sample_model'], metrics)
    logger.info('\n' + table_str)
def main():
    """Entry point: build a translation model, then sample and/or evaluate."""
    args = parse_args()
    cfg = Config.fromfile(args.config)

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # Remote checkpoints (URLs) get no log file on disk.
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    assert isinstance(model, BaseTranslationModel)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'

    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')

    model.eval()

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=[0])

    # build metrics
    if args.eval:
        if args.eval[0] == 'none':
            # only sample images
            metrics = []
            # NOTE(review): relies on a `--num-samples` CLI option; parse_args
            # as written does not define one — confirm before using
            # `--eval none`.
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [
            build_metric(cfg.metrics[metric]) for metric in cfg.metrics
        ]

    # get source domain and target domain
    target_domain = args.target_domain
    if target_domain is None:
        target_domain = model.module._default_domain
    source_domain = model.module.get_other_domains(target_domain)[0]

    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model,
        source_domain=source_domain,
        target_domain=target_domain)

    # build the dataloader
    if len(metrics) == 0:
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        # -1 signals the evaluation functions to size sampling by the metrics.
        basic_table_info['num_samples'] = -1
        # Prefer the test split; fall back to train when no test is defined.
        if cfg.data.get('test', None):
            dataset = build_dataset(cfg.data.test)
        else:
            dataset = build_dataset(cfg.data.train)
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=args.batch_size,
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            dist=False,
            shuffle=True)

    if args.online:
        single_gpu_online_evaluation(model, data_loader, metrics, logger,
                                     basic_table_info, args.batch_size)
    else:
        single_gpu_evaluation(model, data_loader, metrics, logger,
                              basic_table_info, args.batch_size,
                              args.samples_path)


if __name__ == '__main__':
    main()
Prev
1
…
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment