Commit 1401de15 authored by dongchy920

stylegan2_mmcv
#!/usr/bin/env bash
CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
--nnodes=$NNODES \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--nproc_per_node=$GPUS \
--master_port=$PORT \
$(dirname "$0")/train.py \
$CONFIG \
--seed 0 \
--launcher pytorch ${@:3}
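# Usage sketch (config path and GPU count are illustrative):
#   bash ./dist_train.sh configs/your_config.py 8 --work-dir work_dirs/exp1
# Extra arguments after GPUS are forwarded to train.py via ${@:3}.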
#!/usr/bin/env bash
set -x
CONFIG=$1
CKPT=$2
PY_ARGS=${@:3}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="none" ${PY_ARGS}
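# Usage sketch (config/checkpoint paths and metric are illustrative):
#   bash ./eval.sh configs/your_config.py work_dirs/ckpt/latest.pth --eval fid
# Everything after CKPT is passed through to tools/evaluation.py as PY_ARGS.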
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmgen.apis import set_random_seed
from mmgen.core import build_metric, offline_evaluation, online_evaluation
from mmgen.datasets import build_dataloader, build_dataset
from mmgen.models import build_model
from mmgen.utils import get_root_logger
_distributed_metrics = ['FID', 'IS']
def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a Generation model')
parser.add_argument('config', help='evaluation config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
type=int,
help='number of gpus to use '
'(only applicable to non-distributed testing)')
parser.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed testing)')
parser.add_argument(
'--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed testing)')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--batch-size', type=int, default=10, help='batch size of dataloader')
parser.add_argument(
'--samples-path',
type=str,
default=None,
help='path to store images. If not given, a temporary path will be '
'used and removed after evaluation finishes')
parser.add_argument(
'--sample-model',
type=str,
default='ema',
choices=['ema', 'orig'],
help='use which mode (ema/orig) in sampling')
parser.add_argument(
'--eval',
nargs='*',
type=str,
default=None,
help='select the metrics you want to access')
parser.add_argument(
'--online',
action='store_true',
help='whether to use online mode for evaluation')
parser.add_argument(
'--num-samples',
type=int,
default=-1,
help='The number of images to be sampled for evaluation.')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file.')
parser.add_argument(
'--sample-cfg',
nargs='+',
action=DictAction,
help='Other customized kwargs for sampling function')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids[0:1]
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
'Because we only support single GPU mode in '
'non-distributed testing. Use the first GPU '
'in `gpu_ids` now.')
else:
cfg.gpu_ids = [args.gpu_id]
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
rank = 0
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
rank, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
assert args.online or world_size == 1, (
'We only support online mode for distributed evaluation.')
dirname = os.path.dirname(args.checkpoint)
ckpt = os.path.basename(args.checkpoint)
if 'http' in args.checkpoint:
log_path = None
else:
log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
log_path = os.path.join(dirname, log_name)
logger = get_root_logger(
log_file=log_path, log_level=cfg.log_level, file_mode='a')
logger.info('evaluation')
# set random seeds
if args.seed is not None:
if rank == 0:
mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
set_random_seed(args.seed, deterministic=args.deterministic)
# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
# sanity check for models without ema
if not model.use_ema:
args.sample_model = 'orig'
mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
model.eval()
if args.eval:
if args.eval[0] == 'none':
# only sample images
metrics = []
assert args.num_samples is not None and args.num_samples > 0
else:
metrics = [
build_metric(cfg.metrics[metric]) for metric in args.eval
]
else:
metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]
# check metrics for dist evaluation
if distributed and metrics:
for metric in metrics:
assert metric.name in _distributed_metrics, (
f'We only support {_distributed_metrics} for multi-gpu '
f'evaluation, but received {args.eval}.')
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
basic_table_info = dict(
train_cfg=os.path.basename(cfg._filename),
ckpt=ckpt,
sample_model=args.sample_model)
if len(metrics) == 0:
basic_table_info['num_samples'] = args.num_samples
data_loader = None
else:
basic_table_info['num_samples'] = -1
# build the dataloader
if cfg.data.get('test', None) and cfg.data.test.get('imgs_root', None):
dataset = build_dataset(cfg.data.test)
elif cfg.data.get('val', None) and cfg.data.val.get('imgs_root', None):
dataset = build_dataset(cfg.data.val)
elif cfg.data.get('train', None):
# we assume that the train part should work well
dataset = build_dataset(cfg.data.train)
else:
raise RuntimeError('There is no valid dataset config to run, '
'please check your dataset configs.')
# The default loader config
loader_cfg = dict(
samples_per_gpu=args.batch_size,
workers_per_gpu=cfg.data.get('val_workers_per_gpu',
cfg.data.workers_per_gpu),
num_gpus=len(cfg.gpu_ids),
dist=distributed,
shuffle=True)
# The overall dataloader settings
loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})
# specific config for test loader
test_loader_cfg = {**loader_cfg, **cfg.data.get('test_dataloader', {})}
data_loader = build_dataloader(dataset, **test_loader_cfg)
if args.sample_cfg is None:
args.sample_cfg = dict()
if not distributed:
model = MMDataParallel(model, device_ids=[0])
else:
find_unused_parameters = cfg.get('find_unused_parameters', False)
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
# online mode will not save samples
if args.online and len(metrics) > 0:
online_evaluation(model, data_loader, metrics, logger,
basic_table_info, args.batch_size, **args.sample_cfg)
else:
offline_evaluation(model, data_loader, metrics, logger,
basic_table_info, args.batch_size,
args.samples_path, **args.sample_cfg)
if __name__ == '__main__':
main()
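# Usage sketch (config/checkpoint paths and metric names are illustrative):
#   python tools/evaluation.py configs/your_config.py work_dirs/ckpt/latest.pth \
#       --batch-size 10 --online --eval fid is
# With `--eval none`, no metric is computed and `--num-samples` images are
# simply sampled to `--samples-path`.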
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmcv import Config, DictAction
def parse_args():
parser = argparse.ArgumentParser(description='Print the whole config')
parser.add_argument('config', help='config file path')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# import modules from string list.
if cfg.get('custom_imports', None):
from mmcv.utils import import_modules_from_strings
import_modules_from_strings(**cfg['custom_imports'])
print(f'Config:\n{cfg.pretty_text}')
if __name__ == '__main__':
main()
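# Usage sketch (script and config paths are illustrative):
#   python tools/misc/print_config.py configs/your_config.py \
#       --cfg-options data.samples_per_gpu=4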
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import subprocess
from datetime import datetime
import torch
def parse_args():
parser = argparse.ArgumentParser(
description='Process a checkpoint to be published')
parser.add_argument('in_file', help='input checkpoint filename')
parser.add_argument('out_file', help='output checkpoint filename')
args = parser.parse_args()
return args
def process_checkpoint(in_file, out_file):
checkpoint = torch.load(in_file, map_location='cpu')
# remove optimizer for smaller file size
if 'optimizer' in checkpoint:
del checkpoint['optimizer']
# if it is necessary to remove some sensitive data in checkpoint['meta'],
# add the code here.
torch.save(checkpoint, out_file)
now = datetime.now()
time = now.strftime('%Y%m%d_%H%M%S')
sha = subprocess.check_output(['sha256sum', out_file]).decode()
# note: str.rstrip('.pth') would strip characters, not the suffix
out_file_name = out_file[:-4] if out_file.endswith('.pth') else out_file
final_file = out_file_name + f'_{time}-{sha[:8]}.pth'
subprocess.Popen(['mv', out_file, final_file])
def main():
args = parse_args()
process_checkpoint(args.in_file, args.out_file)
if __name__ == '__main__':
main()
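# Usage sketch (script path and filenames are illustrative):
#   python tools/publish_model.py work_dirs/exp1/latest.pth stylegan2_ffhq.pth
# The optimizer state is stripped and the output is renamed to something like
# stylegan2_ffhq_20230101_120000-0a1b2c3d.pth (timestamp plus the first 8
# characters of the sha256 checksum).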
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CKPT=$4
GPUS=${GPUS:-1}
GPUS_PER_NODE=${GPUS_PER_NODE:-1}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="none" ${PY_ARGS}
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CKPT=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="slurm" ${PY_ARGS}
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
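# Usage sketch (partition, job name and paths are illustrative):
#   GPUS=8 GPUS_PER_NODE=8 bash tools/slurm_train.sh my_partition stylegan2 \
#       configs/your_config.py work_dirs/exp1
# GPUS, GPUS_PER_NODE, CPUS_PER_TASK and SRUN_ARGS can be overridden via
# environment variables.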
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import multiprocessing as mp
import os
import os.path as osp
import platform
import time
import warnings
import cv2
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
from mmgen import __version__
from mmgen.apis import set_random_seed, train_model
from mmgen.datasets import build_dataset
from mmgen.models import build_model
from mmgen.utils import collect_env, get_root_logger
cv2.setNumThreads(0)
def parse_args():
parser = argparse.ArgumentParser(description='Train a GAN model')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
type=int,
help='(Deprecated, please use --gpu-id) number of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--diff_seed',
action='store_true',
help='Whether or not to set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def setup_multi_processes(cfg):
# set multi-process start method as `fork` to speed up the training
if platform.system() != 'Windows':
mp_start_method = cfg.get('mp_start_method', 'fork')
mp.set_start_method(mp_start_method)
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = cfg.get('opencv_num_threads', 0)
cv2.setNumThreads(opencv_num_threads)
# setup OMP threads
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
if ('OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1):
omp_num_threads = 1
warnings.warn(
f'Setting OMP_NUM_THREADS environment variable for each process '
f'to be {omp_num_threads} by default, to avoid your system being '
f'overloaded. Please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
# setup MKL threads
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
mkl_num_threads = 1
warnings.warn(
f'Setting MKL_NUM_THREADS environment variable for each process '
f'to be {mkl_num_threads} by default, to avoid your system being '
f'overloaded. Please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
setup_multi_processes(cfg)
# import modules from string list.
if cfg.get('custom_imports', None):
from mmcv.utils import import_modules_from_strings
import_modules_from_strings(**cfg['custom_imports'])
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
'single GPU mode in non-distributed training. '
'Use `gpus=1` now.')
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids[0:1]
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
'Because we only support single GPU mode in '
'non-distributed training. Use the first GPU '
'in `gpu_ids` now.')
if args.gpus is None and args.gpu_ids is None:
cfg.gpu_ids = [args.gpu_id]
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# re-set gpu_ids with distributed training mode
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}, '
f'use_rank_shift: {args.diff_seed}')
set_random_seed(
args.seed,
deterministic=args.deterministic,
use_rank_shift=args.diff_seed)
cfg.seed = args.seed
meta['seed'] = args.seed
meta['exp_name'] = osp.basename(args.config)
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
val_dataset.pipeline = cfg.data.val.pipeline
datasets.append(build_dataset(val_dataset))
if cfg.checkpoint_config is not None:
# save mmgen version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(mmgen_version=__version__ +
get_git_hash()[:7])
train_model(
model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
if __name__ == '__main__':
main()
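# Usage sketch for single-GPU, non-distributed training (config path is
# illustrative):
#   python tools/train.py configs/your_config.py --work-dir work_dirs/exp1 \
#       --gpu-id 0 --seed 2021
# For multi-GPU training, launch this script through the dist_train or
# slurm_train wrappers above so that `--launcher` is set accordingly.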
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import pickle
import sys
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv import Config, print_log
# yapf: disable
sys.path.append(osp.abspath(osp.join(__file__, '../../..'))) # isort:skip # noqa
from mmgen.core.evaluation.metric_utils import extract_inception_features # isort:skip # noqa
from mmgen.datasets import (UnconditionalImageDataset, build_dataloader, # isort:skip # noqa
build_dataset) # isort:skip # noqa
from mmgen.models.architectures import InceptionV3 # isort:skip # noqa
# yapf: enable
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Pre-calculate inception data and save it in pkl file')
parser.add_argument(
'--imgsdir', type=str, default=None, help='the dir containing images.')
parser.add_argument(
'--data-cfg',
type=str,
default=None,
help='the config file for test data pipeline')
parser.add_argument(
'--pklname', type=str, help='the name of inception pkl')
parser.add_argument(
'--pkl-dir',
type=str,
default='work_dirs/inception_pkl',
help='path to save pkl file')
parser.add_argument(
'--pipeline-cfg',
type=str,
default=None,
help=('config file containing dataset pipeline. If None, the default'
' pipeline will be adopted'))
parser.add_argument(
'--flip', action='store_true', help='whether to flip real images')
parser.add_argument(
'--size',
type=int,
nargs='+',
default=(299, 299),
help='image size in the data pipeline')
parser.add_argument(
'--batch-size',
type=int,
default=25,
help='batch size used in extracted features')
parser.add_argument(
'--num-samples',
type=int,
default=50000,
help=('the total number of samples; if -1 is given, '
'automatically use all samples in the subset'))
parser.add_argument(
'--no-shuffle',
action='store_true',
help='not use shuffle in data loader')
parser.add_argument(
'--subset',
default='test',
help='which subset and corresponding pipeline to use')
parser.add_argument(
'--inception-style',
choices=['stylegan', 'pytorch'],
default='pytorch',
help='which inception network to use')
parser.add_argument(
'--inception-pth',
type=str,
default='work_dirs/cache/inception-2015-12-05.pt')
args = parser.parse_args()
# dataset pipeline (only be used when args.imgsdir is not None)
if args.pipeline_cfg is not None:
pipeline = Config.fromfile(args.pipeline_cfg)['inception_pipeline']
elif args.imgsdir is not None:
if isinstance(args.size, list) and len(args.size) == 2:
size = args.size
elif isinstance(args.size, list) and len(args.size) == 1:
size = (args.size[0], args.size[0])
elif isinstance(args.size, int):
size = (args.size, args.size)
else:
raise TypeError(
f'args.size must be int or tuple but got {args.size}')
pipeline = [
dict(type='LoadImageFromFile', key='real_img'),
dict(
type='Resize', keys=['real_img'], scale=size,
keep_ratio=False),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=True), # default to RGB images
dict(type='Collect', keys=['real_img'], meta_keys=[]),
dict(type='ImageToTensor', keys=['real_img'])
]
# insert flip aug
if args.flip:
pipeline.insert(
1,
dict(type='Flip', keys=['real_img'], direction='horizontal'))
# build dataloader
if args.imgsdir is not None:
dataset = UnconditionalImageDataset(args.imgsdir, pipeline)
elif args.data_cfg is not None:
# Please make sure the dataset will sample images in `RGB` order.
data_config = Config.fromfile(args.data_cfg)
subset_config = data_config.data.get(args.subset, None)
print_log(subset_config, 'mmgen')
dataset = build_dataset(subset_config)
else:
raise RuntimeError('Please provide imgsdir or data_cfg')
data_loader = build_dataloader(
dataset, args.batch_size, 4, dist=False, shuffle=(not args.no_shuffle))
mmcv.mkdir_or_exist(args.pkl_dir)
# build inception network
if args.inception_style == 'stylegan':
inception = torch.jit.load(args.inception_pth).eval().cuda()
inception = nn.DataParallel(inception)
print_log('Adopt Inception network in StyleGAN', 'mmgen')
else:
inception = nn.DataParallel(
InceptionV3([3], resize_input=True, normalize_input=False).cuda())
inception.eval()
if args.num_samples == -1:
print_log('Use all samples in subset', 'mmgen')
num_samples = len(dataset)
else:
num_samples = args.num_samples
features = extract_inception_features(data_loader, inception, num_samples,
args.inception_style).numpy()
# sanity check for the number of features
assert features.shape[
0] == num_samples, 'the number of features != num_samples'
print_log(f'Extract {num_samples} features', 'mmgen')
mean = np.mean(features, 0)
cov = np.cov(features, rowvar=False)
with open(osp.join(args.pkl_dir, args.pklname), 'wb') as f:
pickle.dump(
{
'mean': mean,
'cov': cov,
'size': num_samples,
'name': args.pklname
}, f)
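# Usage sketch (script path, image dir and pkl name are illustrative):
#   python tools/utils/inception_stat.py --imgsdir data/ffhq/images \
#       --pklname ffhq-256-50k-rgb.pkl --size 256 --num-samples 50000
# The resulting pkl stores the feature mean/covariance consumed by the FID
# metric through its `inception_pkl` option.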
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import sys
import mmcv
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../../..'))) # isort:skip # noqa
from mmgen.apis import set_random_seed # isort:skip # noqa
from mmgen.models import build_model # isort:skip # noqa
# yapf: enable
def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a GAN model')
parser.add_argument('config', help='evaluation config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--samples-path',
type=str,
default=None,
help='path to store images. If not given, a temporary path will be '
'used and removed after evaluation finishes')
parser.add_argument(
'--save-prev-res',
action='store_true',
help='whether to store the results from previous stages')
parser.add_argument(
'--num-samples',
type=int,
default=10,
help='the number of synthesized samples')
args = parser.parse_args()
return args
def _tensor2img(img):
img = img[0].permute(1, 2, 0)
img = ((img + 1) / 2 * 255).to(torch.uint8)
return img.cpu().numpy()
@torch.no_grad()
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# set random seeds
if args.seed is not None:
set_random_seed(args.seed, deterministic=args.deterministic)
# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
model.eval()
# load ckpt
mmcv.print_log(f'Loading ckpt from {args.checkpoint}', 'mmgen')
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
# add dp wrapper
model = MMDataParallel(model, device_ids=[0])
pbar = mmcv.ProgressBar(args.num_samples)
for sample_iter in range(args.num_samples):
outputs = model(None, num_batches=1, get_prev_res=args.save_prev_res)
# store results from previous stages
if args.save_prev_res:
fake_img = outputs['fake_img']
prev_res_list = outputs['prev_res_list']
prev_res_list.append(fake_img)
for i, img in enumerate(prev_res_list):
img = _tensor2img(img)
mmcv.imwrite(
img,
os.path.join(args.samples_path, f'stage{i}',
f'rand_sample_{sample_iter}.png'))
# just store the final result
else:
img = _tensor2img(outputs)
mmcv.imwrite(
img,
os.path.join(args.samples_path,
f'rand_sample_{sample_iter}.png'))
pbar.update()
# change the line after pbar
sys.stdout.write('\n')
if __name__ == '__main__':
main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import shutil
import sys
import mmcv
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from torchvision.utils import save_image
from mmgen.apis import set_random_seed
from mmgen.core import build_metric
from mmgen.core.evaluation import make_metrics_table, make_vanilla_dataloader
from mmgen.datasets import build_dataloader, build_dataset
from mmgen.models import build_model
from mmgen.models.translation_models import BaseTranslationModel
from mmgen.utils import get_root_logger
def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a GAN model')
parser.add_argument('config', help='evaluation config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--target-domain', type=str, default=None, help='Desired image domain')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--batch-size', type=int, default=1, help='batch size of dataloader')
parser.add_argument(
'--samples-path',
type=str,
default=None,
help='path to store images. If not given, a temporary path will be '
'used and removed after evaluation finishes')
parser.add_argument(
'--sample-model',
type=str,
default='ema',
help='use which mode (ema/orig) in sampling')
parser.add_argument(
'--eval',
nargs='*',
type=str,
default=None,
help='select the metrics you want to access')
parser.add_argument(
'--online',
action='store_true',
help='whether to use online mode for evaluation')
parser.add_argument(
'--num-samples',
type=int,
default=-1,
help='The number of images to be sampled when only sampling images '
'without evaluation (i.e. `--eval none`).')
args = parser.parse_args()
return args
@torch.no_grad()
def single_gpu_evaluation(model,
data_loader,
metrics,
logger,
basic_table_info,
batch_size,
samples_path=None,
**kwargs):
"""Evaluate model with a single gpu.
This method evaluate model with a single gpu and displays eval progress
bar.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): PyTorch data loader.
metrics (list): List of metric objects.
logger (Logger): logger used to record results of evaluation.
basic_table_info (dict): Dictionary containing the basic information
of the metric table, including the training configuration and ckpt.
batch_size (int): Batch size of images fed into metrics.
samples_path (str): Used to save generated images. If it's None, a
default directory will be used and deleted after the evaluation
finishes. Defaults to None.
kwargs (dict): Other arguments.
"""
# decide samples path
delete_samples_path = False
if samples_path:
mmcv.mkdir_or_exist(samples_path)
else:
temp_path = './work_dirs/temp_samples'
# if temp_path exists, add suffix
suffix = 1
samples_path = temp_path
while os.path.exists(samples_path):
samples_path = temp_path + '_' + str(suffix)
suffix += 1
os.makedirs(samples_path)
delete_samples_path = True
# sample images
num_exist = len(
list(
mmcv.scandir(
samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
if basic_table_info['num_samples'] > 0:
max_num_images = basic_table_info['num_samples']
else:
max_num_images = max(metric.num_images for metric in metrics)
num_needed = max(max_num_images - num_exist, 0)
if num_needed > 0:
mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
'mmgen')
# define mmcv progress bar
pbar = mmcv.ProgressBar(num_needed)
# select key to fetch fake images
target_domain = basic_table_info['target_domain']
source_domain = basic_table_info['source_domain']
# if no images, `num_needed` should be zero
data_loader_iter = iter(data_loader)
for begin in range(0, num_needed, batch_size):
end = min(begin + batch_size, max_num_images)
# for translation model, we feed them images from dataloader
data_batch = next(data_loader_iter)
output_dict = model(
data_batch[f'img_{source_domain}'],
test_mode=True,
target_domain=target_domain)
fakes = output_dict['target']
pbar.update(end - begin)
for i in range(end - begin):
images = fakes[i:i + 1]
images = ((images + 1) / 2)
images = images[:, [2, 1, 0], ...]
images = images.clamp_(0, 1)
image_name = str(begin + i) + '.png'
save_image(images, os.path.join(samples_path, image_name))
if num_needed > 0:
sys.stdout.write('\n')
# return if only save sampled images
if len(metrics) == 0:
return
# empty cache to release GPU memory
torch.cuda.empty_cache()
fake_dataloader = make_vanilla_dataloader(samples_path, batch_size)
for metric in metrics:
mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
metric.prepare()
# feed in real images
for data in data_loader:
reals = data[f'img_{target_domain}']
num_left = metric.feed(reals, 'reals')
if num_left <= 0:
break
# feed in fake images
for data in fake_dataloader:
fakes = data['real_img']
num_left = metric.feed(fakes, 'fakes')
if num_left <= 0:
break
metric.summary()
table_str = make_metrics_table(basic_table_info['train_cfg'],
basic_table_info['ckpt'],
basic_table_info['sample_model'], metrics)
logger.info('\n' + table_str)
if delete_samples_path:
shutil.rmtree(samples_path)
@torch.no_grad()
def single_gpu_online_evaluation(model, data_loader, metrics, logger,
basic_table_info, batch_size, **kwargs):
"""Evaluate model with a single gpu in online mode.
This method evaluate model with a single gpu and displays eval progress
bar. Different form `single_gpu_evaluation`, this function will not save
the images or read images from disks. Namely, there do not exist any IO
operations in this function. Thus, in general, `online` mode will achieve a
faster evaluation. However, this mode will take much more memory cost.
Therefore this evaluation function is recommended to evaluate your model
with a single metric.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): PyTorch data loader.
metrics (list): List of metric objects.
logger (Logger): logger used to record results of evaluation.
basic_table_info (dict): Dictionary containing the basic information
of the metric table, including the training configuration and ckpt.
batch_size (int): Batch size of images fed into metrics.
kwargs (dict): Other arguments.
"""
# sample images
max_num_images = 0 if len(metrics) == 0 else max(metric.num_fake_need
for metric in metrics)
pbar = mmcv.ProgressBar(max_num_images)
# select key to fetch images
target_domain = basic_table_info['target_domain']
source_domain = basic_table_info['source_domain']
for metric in metrics:
mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
metric.prepare()
# feed reals and fakes
data_loader_iter = iter(data_loader)
for begin in range(0, max_num_images, batch_size):
end = min(begin + batch_size, max_num_images)
# for translation model, we feed them images from dataloader
data_batch = next(data_loader_iter)
output_dict = model(
data_batch[f'img_{source_domain}'],
test_mode=True,
target_domain=target_domain)
fakes = output_dict['target']
reals = data_batch[f'img_{target_domain}']
pbar.update(end - begin)
for metric in metrics:
metric.feed(reals, 'reals')
metric.feed(fakes, 'fakes')
for metric in metrics:
metric.summary()
table_str = make_metrics_table(basic_table_info['train_cfg'],
basic_table_info['ckpt'],
basic_table_info['sample_model'], metrics)
logger.info('\n' + table_str)
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
dirname = os.path.dirname(args.checkpoint)
ckpt = os.path.basename(args.checkpoint)
if 'http' in args.checkpoint:
log_path = None
else:
log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
log_path = os.path.join(dirname, log_name)
logger = get_root_logger(
log_file=log_path, log_level=cfg.log_level, file_mode='a')
logger.info('evaluation')
# set random seeds
if args.seed is not None:
set_random_seed(args.seed, deterministic=args.deterministic)
# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
assert isinstance(model, BaseTranslationModel)
# sanity check for models without ema
if not model.use_ema:
args.sample_model = 'orig'
mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
model.eval()
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
model = MMDataParallel(model, device_ids=[0])
# build metrics
if args.eval:
if args.eval[0] == 'none':
# only sample images
metrics = []
assert args.num_samples is not None and args.num_samples > 0
else:
metrics = [
build_metric(cfg.metrics[metric]) for metric in args.eval
]
else:
metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]
# get source domain and target domain
target_domain = args.target_domain
if target_domain is None:
target_domain = model.module._default_domain
source_domain = model.module.get_other_domains(target_domain)[0]
basic_table_info = dict(
train_cfg=os.path.basename(cfg._filename),
ckpt=ckpt,
sample_model=args.sample_model,
source_domain=source_domain,
target_domain=target_domain)
# build the dataloader
if len(metrics) == 0:
basic_table_info['num_samples'] = args.num_samples
data_loader = None
else:
basic_table_info['num_samples'] = -1
if cfg.data.get('test', None):
dataset = build_dataset(cfg.data.test)
else:
dataset = build_dataset(cfg.data.train)
data_loader = build_dataloader(
dataset,
samples_per_gpu=args.batch_size,
workers_per_gpu=cfg.data.get('val_workers_per_gpu',
cfg.data.workers_per_gpu),
dist=False,
shuffle=True)
if args.online:
single_gpu_online_evaluation(model, data_loader, metrics, logger,
basic_table_info, args.batch_size)
else:
single_gpu_evaluation(model, data_loader, metrics, logger,
basic_table_info, args.batch_size,
args.samples_path)
if __name__ == '__main__':
main()
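# Usage sketch (script path, config/checkpoint and domain name are illustrative):
#   python tools/evaluation_translation.py configs/your_pix2pix_config.py \
#       work_dirs/ckpt/latest.pth --target-domain photo --eval fid --online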
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
from .version import __version__, parse_version_info, version_info
def digit_version(version_str):
digit_version = []
for x in version_str.split('.'):
if x.isdigit():
digit_version.append(int(x))
elif x.find('rc') != -1:
patch_version = x.split('rc')
digit_version.append(int(patch_version[0]) - 1)
digit_version.append(int(patch_version[1]))
return digit_version
mmcv_minimum_version = '1.3.0'
mmcv_maximum_version = '1.8.0'
mmcv_version = digit_version(mmcv.__version__)
assert (mmcv_version >= digit_version(mmcv_minimum_version)
and mmcv_version <= digit_version(mmcv_maximum_version)), \
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
__all__ = ['__version__', 'version_info', 'parse_version_info']
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import (init_model, sample_conditional_model,
sample_ddpm_model, sample_img2img_model,
sample_unconditional_model)
from .train import set_random_seed, train_model
__all__ = [
'set_random_seed', 'train_model', 'init_model', 'sample_img2img_model',
'sample_unconditional_model', 'sample_conditional_model',
'sample_ddpm_model'
]
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcv.utils import is_list_of
from mmgen.datasets.pipelines import Compose
from mmgen.models import BaseTranslationModel, build_model
def init_model(config, checkpoint=None, device='cpu', cfg_options=None):
"""Initialize a detector from config file.
Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
device (str, optional): The device where the model is placed.
Defaults to 'cpu'.
cfg_options (dict, optional): Options to override some settings in
the used config.
Returns:
nn.Module: The constructed unconditional model.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
if cfg_options is not None:
config.merge_from_dict(cfg_options)
model = build_model(
config.model, train_cfg=config.train_cfg, test_cfg=config.test_cfg)
if checkpoint is not None:
load_checkpoint(model, checkpoint, map_location='cpu')
model._cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model
@torch.no_grad()
def sample_unconditional_model(model,
num_samples=16,
num_batches=4,
sample_model='ema',
**kwargs):
"""Sampling from unconditional models.
Args:
model (nn.Module): Unconditional models in MMGeneration.
num_samples (int, optional): The total number of samples.
Defaults to 16.
num_batches (int, optional): The batch size used for inference.
Defaults to 4.
sample_model (str, optional): Which model you want to use. ['ema',
'orig']. Defaults to 'ema'.
Returns:
Tensor: Generated image tensor.
"""
# set eval mode
model.eval()
# construct sampling list for batches
n_repeat = num_samples // num_batches
batches_list = [num_batches] * n_repeat
if num_samples % num_batches > 0:
batches_list.append(num_samples % num_batches)
res_list = []
# inference
for batches in batches_list:
res = model.sample_from_noise(
None, num_batches=batches, sample_model=sample_model, **kwargs)
res_list.append(res.cpu())
results = torch.cat(res_list, dim=0)
return results
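# A minimal usage sketch of the inference API above, kept as a comment
# (config/checkpoint paths are illustrative):
#   from mmgen.apis import init_model, sample_unconditional_model
#   model = init_model('configs/your_stylegan2_config.py',
#                      'work_dirs/ckpt/stylegan2.pth', device='cuda:0')
#   imgs = sample_unconditional_model(model, num_samples=16, num_batches=4,
#                                     sample_model='ema')  # (N, C, H, W) tensor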
@torch.no_grad()
def sample_conditional_model(model,
num_samples=16,
num_batches=4,
sample_model='ema',
label=None,
**kwargs):
"""Sampling from conditional models.
Args:
model (nn.Module): Conditional models in MMGeneration.
num_samples (int, optional): The total number of samples.
Defaults to 16.
num_batches (int, optional): The batch size used for inference.
Defaults to 4.
sample_model (str, optional): Which model you want to use. ['ema',
'orig']. Defaults to 'ema'.
label (int | torch.Tensor | list[int], optional): Labels used to
generate images. Default to None.,
Returns:
Tensor: Generated image tensor.
"""
# set eval mode
model.eval()
# construct sampling list for batches
n_repeat = num_samples // num_batches
batches_list = [num_batches] * n_repeat
# check and convert the input labels
if isinstance(label, int):
label = torch.LongTensor([label] * num_samples)
elif isinstance(label, torch.Tensor):
label = label.type(torch.int64)
if label.numel() == 1:
# repeat single tensor
# call view(-1) to avoid nested tensor like [[[1]]]
label = label.view(-1).repeat(num_samples)
else:
# flatten multi tensors
label = label.view(-1)
elif isinstance(label, list):
if is_list_of(label, int):
label = torch.LongTensor(label)
# `nargs='+'` parse single integer as list
if label.numel() == 1:
# repeat single tensor
label = label.repeat(num_samples)
else:
raise TypeError('Only support `int` for label list elements, '
f'but receive {type(label[0])}')
elif label is None:
pass
else:
raise TypeError('Only support `int`, `torch.Tensor`, `list[int]` or '
f'None as label, but receive {type(label)}.')
# check the length of the (converted) label
if label is not None and label.size(0) != num_samples:
raise ValueError('Number of elements in the label should be ONE or '
'equal to `num_samples`. Requires '
f'{num_samples}, but received {label.size(0)}.')
# make label list
label_list = []
for n in range(n_repeat):
if label is None:
label_list.append(None)
else:
label_list.append(label[n * num_batches:(n + 1) * num_batches])
if num_samples % num_batches > 0:
batches_list.append(num_samples % num_batches)
if label is None:
label_list.append(None)
else:
label_list.append(label[(n + 1) * num_batches:])
res_list = []
# inference
for batches, labels in zip(batches_list, label_list):
res = model.sample_from_noise(
None,
num_batches=batches,
label=labels,
sample_model=sample_model,
**kwargs)
res_list.append(res.cpu())
results = torch.cat(res_list, dim=0)
return results
def sample_img2img_model(model, image_path, target_domain=None, **kwargs):
"""Sampling from translation models.
Args:
model (nn.Module): The loaded model.
image_path (str): File path of input image.
target_domain (str, optional): Target domain of the output image. If
None, the model's default domain will be used. Defaults to None.
Returns:
Tensor: Translated image tensor.
"""
assert isinstance(model, BaseTranslationModel)
# get source domain and target domain
if target_domain is None:
target_domain = model._default_domain
source_domain = model.get_other_domains(target_domain)[0]
cfg = model._cfg
device = next(model.parameters()).device # model device
# build the data pipeline
test_pipeline = Compose(cfg.test_pipeline)
# prepare data
data = dict()
# dirty code to deal with test data pipeline
data['pair_path'] = image_path
data[f'img_{source_domain}_path'] = image_path
data[f'img_{target_domain}_path'] = image_path
data = test_pipeline(data)
if device.type == 'cpu':
data = collate([data], samples_per_gpu=1)
data['meta'] = []
else:
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
source_image = data[f'img_{source_domain}']
# forward the model
with torch.no_grad():
results = model(
source_image,
test_mode=True,
target_domain=target_domain,
**kwargs)
output = results['target']
return output
@torch.no_grad()
def sample_ddpm_model(model,
num_samples=16,
num_batches=4,
sample_model='ema',
same_noise=False,
**kwargs):
"""Sampling from ddpm models.
Args:
model (nn.Module): DDPM models in MMGeneration.
num_samples (int, optional): The total number of samples.
Defaults to 16.
num_batches (int, optional): The batch size used for inference.
Defaults to 4.
sample_model (str, optional): Which model you want to use. ['ema',
'orig']. Defaults to 'ema'.
same_noise (bool, optional): Whether all batches are denoised from the
same starting noise. Defaults to False.
Returns:
list[Tensor | dict]: Generated image tensor.
"""
model.eval()
n_repeat = num_samples // num_batches
batches_list = [num_batches] * n_repeat
if num_samples % num_batches > 0:
batches_list.append(num_samples % num_batches)
noise_batch = torch.randn(model.image_shape) if same_noise else None
res_list = []
# inference
for idx, batches in enumerate(batches_list):
mmcv.print_log(
f'Start to sample batch [{idx+1} / '
f'{len(batches_list)}]', 'mmgen')
noise_batch_ = noise_batch[None, ...].expand(batches, -1, -1, -1) \
if same_noise else None
res = model.sample_from_noise(
noise_batch_,
num_batches=batches,
sample_model=sample_model,
show_pbar=True,
**kwargs)
if isinstance(res, dict):
res = {k: v.cpu() for k, v in res.items()}
elif isinstance(res, torch.Tensor):
res = res.cpu()
else:
raise ValueError('Sample results should be \'dict\' or '
f'\'torch.Tensor\', but receive \'{type(res)}\'')
res_list.append(res)
# gather the res_list
if isinstance(res_list[0], dict):
res_dict = dict()
for t in res_list[0].keys():
# num_samples x 3 x H x W
res_dict[t] = torch.cat([res[t] for res in res_list], dim=0)
return res_dict
else:
return torch.cat(res_list, dim=0)
# Copyright (c) OpenMMLab. All rights reserved.
import os
from copy import deepcopy
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import HOOKS, IterBasedRunner, OptimizerHook, build_runner
from mmcv.runner import set_random_seed as set_random_seed_mmcv
from mmcv.utils import build_from_cfg
from mmgen.core.ddp_wrapper import DistributedDataParallelWrapper
from mmgen.core.optimizer import build_optimizers
from mmgen.core.runners.apex_amp_utils import apex_amp_initialize
from mmgen.datasets import build_dataloader, build_dataset
from mmgen.utils import get_root_logger
def set_random_seed(seed, deterministic=False, use_rank_shift=True):
"""Set random seed.
In this function, we just modify the default behavior of the similar
function defined in MMCV.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
use_rank_shift (bool): Whether to add the rank number to the random
seed to get different random seeds in different processes.
Default: True.
"""
set_random_seed_mmcv(
seed, deterministic=deterministic, use_rank_shift=use_rank_shift)
def train_model(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
logger = get_root_logger(cfg.log_level)
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
# default loader config
loader_cfg = dict(
samples_per_gpu=cfg.data.samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
# cfg.gpus will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
dist=distributed,
persistent_workers=cfg.data.get('persistent_workers', False),
seed=cfg.seed)
# The overall dataloader settings
loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})
# The specific dataloader settings
train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
# dirty code for use apex amp
# apex.amp request that models should be in cuda device before
# initialization.
if cfg.get('apex_amp', None):
assert distributed, (
'Currently, apex.amp is only supported with DDP training.')
model = model.cuda()
# build optimizer
if cfg.optimizer:
optimizer = build_optimizers(model, cfg.optimizer)
# In GANs, we allow building optimizer in GAN model.
else:
optimizer = None
_use_apex_amp = False
if cfg.get('apex_amp', None):
model, optimizer = apex_amp_initialize(model, optimizer,
**cfg.apex_amp)
_use_apex_amp = True
# put model on gpus
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
use_ddp_wrapper = cfg.get('use_ddp_wrapper', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
if use_ddp_wrapper:
mmcv.print_log('Use DDP Wrapper.', 'mmgen')
model = DistributedDataParallelWrapper(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
# allow users to define the runner
if cfg.get('runner', None):
runner = build_runner(
cfg.runner,
dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
use_apex_amp=_use_apex_amp,
meta=meta))
else:
runner = IterBasedRunner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta)
# set if use dynamic ddp in training
# is_dynamic_ddp=cfg.get('is_dynamic_ddp', False))
# an ugly workaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp
# fp16 setting
fp16_cfg = cfg.get('fp16', None)
# In GANs, we can directly optimize parameter in `train_step` function.
if cfg.get('optimizer_cfg', None) is None:
optimizer_config = None
elif fp16_cfg is not None:
raise NotImplementedError('Fp16 has not been supported.')
# optimizer_config = Fp16OptimizerHook(
# **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
# default to use OptimizerHook
elif distributed and 'type' not in cfg.optimizer_config:
optimizer_config = OptimizerHook(**cfg.optimizer_config)
else:
optimizer_config = cfg.optimizer_config
# update `out_dir` in ckpt hook
if cfg.checkpoint_config is not None:
cfg.checkpoint_config['out_dir'] = os.path.join(
cfg.work_dir, cfg.checkpoint_config.get('out_dir', 'ckpt'))
# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
# # DistSamplerSeedHook should be used with EpochBasedRunner
# if distributed:
# runner.register_hook(DistSamplerSeedHook())
# In general, we do NOT adopt standard evaluation hook in GAN training.
# Thus, if you want an eval hook, you need to further define the
# 'evaluation' key in the config.
# register eval hooks
if validate and cfg.get('evaluation', None) is not None:
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
# Support batch_size > 1 in validation
val_loader_cfg = {
**loader_cfg, 'shuffle': False,
**cfg.data.get('val_data_loader', {})
}
val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
eval_cfg = deepcopy(cfg.get('evaluation'))
priority = eval_cfg.pop('priority', 'LOW')
eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader))
eval_hook = build_from_cfg(eval_cfg, HOOKS)
runner.register_hook(eval_hook, priority=priority)
# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
assert isinstance(custom_hooks, list), \
f'custom_hooks expect list type, but got {type(custom_hooks)}'
for hook_cfg in cfg.custom_hooks:
assert isinstance(hook_cfg, dict), \
'Each item in custom_hooks expects dict type, but got ' \
f'{type(hook_cfg)}'
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_iters)
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import * # noqa: F401, F403
from .hooks import * # noqa: F401, F403
from .optimizer import * # noqa: F401, F403
from .registry import * # noqa: F401, F403
from .runners import * # noqa: F401, F403
from .scheduler import * # noqa: F401, F403
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from torch.cuda._utils import _get_device_index
@MODULE_WRAPPERS.register_module('mmgen.DDPWrapper')
class DistributedDataParallelWrapper(nn.Module):
"""A DistributedDataParallel wrapper for models in MMGeneration.
In MMGeneration, there is a need to wrap different modules in the models
with separate DistributedDataParallel. Otherwise, it will cause
errors for GAN training.
More specifically, a GAN model usually has two sub-modules:
generator and discriminator. If we wrap both of them in one
standard DistributedDataParallel, it will cause errors during training,
because when we update the parameters of the generator (or discriminator),
the parameters of the discriminator (or generator) are not updated, which is
not allowed for DistributedDataParallel.
So we design this wrapper to separately wrap the generator and the
discriminator with DistributedDataParallel.
In this wrapper, we perform two operations:
1. Wrap the modules in the models with separate MMDistributedDataParallel.
Note that only modules with parameters will be wrapped.
2. Do scatter operation for 'forward', 'train_step' and 'val_step'.
Note that the arguments of this wrapper is the same as those in
`torch.nn.parallel.distributed.DistributedDataParallel`.
Args:
module (nn.Module): Module that needs to be wrapped.
device_ids (list[int | `torch.device`]): Same as that in
`torch.nn.parallel.distributed.DistributedDataParallel`.
dim (int, optional): Same as that in the official scatter function in
pytorch. Defaults to 0.
broadcast_buffers (bool): Same as that in
`torch.nn.parallel.distributed.DistributedDataParallel`.
Defaults to False.
find_unused_parameters (bool, optional): Same as that in
`torch.nn.parallel.distributed.DistributedDataParallel`.
Traverse the autograd graph of all tensors contained in returned
value of the wrapped module’s forward function. Defaults to False.
kwargs (dict): Other arguments used in
`torch.nn.parallel.distributed.DistributedDataParallel`.
"""
def __init__(self,
module,
device_ids,
dim=0,
broadcast_buffers=False,
find_unused_parameters=False,
**kwargs):
super().__init__()
assert len(device_ids) == 1, (
'Currently, DistributedDataParallelWrapper only supports one'
'single CUDA device for each process.'
f'The length of device_ids must be 1, but got {len(device_ids)}.')
self.module = module
self.dim = dim
self.to_ddp(
device_ids=device_ids,
dim=dim,
broadcast_buffers=broadcast_buffers,
find_unused_parameters=find_unused_parameters,
**kwargs)
self.output_device = _get_device_index(device_ids[0], True)
def to_ddp(self, device_ids, dim, broadcast_buffers,
find_unused_parameters, **kwargs):
"""Wrap models with separate MMDistributedDataParallel.
It only wraps the modules with parameters.
"""
for name, module in self.module._modules.items():
if next(module.parameters(), None) is None:
module = module.cuda()
elif all(not p.requires_grad for p in module.parameters()):
module = module.cuda()
else:
module = MMDistributedDataParallel(
module.cuda(),
device_ids=device_ids,
dim=dim,
broadcast_buffers=broadcast_buffers,
find_unused_parameters=find_unused_parameters,
**kwargs)
self.module._modules[name] = module
def scatter(self, inputs, kwargs, device_ids):
"""Scatter function.
Args:
inputs (Tensor): Input Tensor.
kwargs (dict): Args for
``mmcv.parallel.scatter_gather.scatter_kwargs``.
device_ids (int): Device id.
"""
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def forward(self, *inputs, **kwargs):
"""Forward function.
Args:
inputs (tuple): Input data.
kwargs (dict): Args for
``mmcv.parallel.scatter_gather.scatter_kwargs``.
"""
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
return self.module(*inputs[0], **kwargs[0])
def train_step(self, *inputs, **kwargs):
"""Train step function.
Args:
inputs (Tensor): Input Tensor.
kwargs (dict): Args for
``mmcv.parallel.scatter_gather.scatter_kwargs``.
"""
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.train_step(*inputs[0], **kwargs[0])
return output
def val_step(self, *inputs, **kwargs):
"""Validation step function.
Args:
inputs (tuple): Input data.
kwargs (dict): Args for ``scatter_kwargs``.
"""
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.val_step(*inputs[0], **kwargs[0])
return output
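# Usage sketch: this wrapper is normally enabled from the training config rather
# than instantiated by hand; `train_model` in mmgen.apis switches to it when
# `use_ddp_wrapper` is set (see the distributed branch above). Illustrative
# config snippet:
#   use_ddp_wrapper = True
#   find_unused_parameters = False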
# Copyright (c) OpenMMLab. All rights reserved.
from .eval_hooks import GenerativeEvalHook, TranslationEvalHook
from .evaluation import (make_metrics_table, make_vanilla_dataloader,
offline_evaluation, online_evaluation)
from .metric_utils import slerp
from .metrics import (IS, MS_SSIM, PR, SWD, GaussianKLD, ms_ssim,
sliced_wasserstein)
__all__ = [
'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'offline_evaluation',
'online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook',
'make_metrics_table', 'make_vanilla_dataloader', 'GaussianKLD',
'TranslationEvalHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import os.path as osp
import sys
import warnings
from bisect import bisect_right
import mmcv
import torch
from mmcv.runner import HOOKS, Hook, get_dist_info
from ..registry import build_metric
@HOOKS.register_module()
class GenerativeEvalHook(Hook):
"""Evaluation Hook for Generative Models.
This evaluation hook can be used to evaluate unconditional and conditional
models. Note that only the ``FID`` and ``IS`` metrics are supported for
distributed training at the moment. More metrics will be supported for
evaluation during training in the future.
In our config system, you only need to add an `evaluation` field with the
detailed configurations. Below are several usage cases for different
situations. Add these lines at the end of your config file and the hook
will be used during training.
Note that this evaluation hook supports dynamic evaluation intervals,
because FID and other metrics may fluctuate frequently towards the end of
training.
# TODO: fix the online doc
#. Only use FID for evaluation
.. code-block:: python
:linenos:
evaluation = dict(
type='GenerativeEvalHook',
interval=10000,
metrics=dict(
type='FID',
num_images=50000,
inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
bgr2rgb=True),
sample_kwargs=dict(sample_model='ema'))
#. Use FID and IS simultaneously and save the best checkpoints respectively
.. code-block:: python
:linenos:
evaluation = dict(
type='GenerativeEvalHook',
interval=10000,
metrics=[dict(
type='FID',
num_images=50000,
inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
bgr2rgb=True),
dict(type='IS',
num_images=50000)],
best_metric=['fid', 'is'],
sample_kwargs=dict(sample_model='ema'))
#. Use dynamic evaluation intervals
.. code-block:: python
:linenos:
# interval = 10000, if iter < 500000,
# interval = 4000, if 500000 <= iter < 750000,
# interval = 2000, if iter >= 750000
evaluation = dict(
type='GenerativeEvalHook',
interval=dict(milestones=[500000, 750000],
interval=[10000, 4000, 2000]),
metrics=[dict(
type='FID',
num_images=50000,
inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
bgr2rgb=True),
dict(type='IS',
num_images=50000)],
best_metric=['fid', 'is'],
sample_kwargs=dict(sample_model='ema'))
Args:
dataloader (DataLoader): A PyTorch dataloader.
interval (int | dict): Evaluation interval. If an int is passed, the
hook runs every ``interval`` iterations. If a dict is passed, its
keys 'milestones' and 'interval' define a dynamic evaluation
interval (see the usage case above). Default: 1.
dist (bool, optional): Whether to use distributed evaluation.
Defaults to True.
metrics (dict | list[dict], optional): Configs for metrics that will be
used in evaluation hook. Defaults to None.
sample_kwargs (dict | None, optional): Additional keyword arguments for
sampling images. Defaults to None.
save_best_ckpt (bool, optional): Whether to save the best checkpoint
according to ``best_metric``. Defaults to ``True``.
best_metric (str | list, optional): Which metric(s) to use when saving
the best checkpoint. Multiple metrics can be given as a list of
metric names, e.g., ``['fid', 'is']``. Defaults to ``'fid'``.
"""
rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
init_value_map = {'greater': -math.inf, 'less': math.inf}
greater_keys = ['acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'is']
less_keys = ['loss', 'fid']
_supported_best_metrics = ['fid', 'is']
def __init__(self,
dataloader,
interval=1,
dist=True,
metrics=None,
sample_kwargs=None,
save_best_ckpt=True,
best_metric='fid'):
assert metrics is not None
self.dataloader = dataloader
self.dist = dist
self.sample_kwargs = sample_kwargs if sample_kwargs else dict()
self.save_best_ckpt = save_best_ckpt
self.best_metric = best_metric
if isinstance(interval, int):
self.interval = interval
elif isinstance(interval, dict):
if 'milestones' not in interval or 'interval' not in interval:
raise KeyError(
'`milestones` and `interval` must exist in the interval dict '
'if you want to use the dynamic interval evaluation strategy. '
f'But received {list(interval.keys())} in the interval dict.')
self.milestones = interval['milestones']
self.interval = interval['interval']
# check if length of interval match with the milestones
if len(self.interval) != len(self.milestones) + 1:
raise ValueError(
f'Length of `interval`(={len(self.interval)}) cannot '
f'match length of `milestones`(={len(self.milestones)}).')
# check if milestones is in order
for idx in range(len(self.milestones) - 1):
former, latter = self.milestones[idx], self.milestones[idx + 1]
if former >= latter:
raise ValueError(
'Elements in `milestones` should be in ascending order.')
else:
raise TypeError('`interval` only supports `int` or `dict`, but '
f'received {type(interval)} instead.')
if isinstance(best_metric, str):
self.best_metric = [self.best_metric]
if self.save_best_ckpt:
not_supported = set(self.best_metric) - set(
self._supported_best_metrics)
assert len(not_supported) == 0, (
f'{not_supported} is not supported for saving best ckpt')
self.metrics = build_metric(metrics)
if isinstance(metrics, dict):
self.metrics = [self.metrics]
for metric in self.metrics:
metric.prepare()
# add support for saving best ckpt
if self.save_best_ckpt:
self.rule = {}
self.compare_func = {}
self._curr_best_score = {}
self._curr_best_ckpt_path = {}
for name in self.best_metric:
if name in self.greater_keys:
self.rule[name] = 'greater'
else:
self.rule[name] = 'less'
self.compare_func[name] = self.rule_map[self.rule[name]]
self._curr_best_score[name] = self.init_value_map[
self.rule[name]]
self._curr_best_ckpt_path[name] = None
def get_current_interval(self, runner):
"""Get current evaluation interval.
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
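A small illustration of the lookup (milestones and intervals are the
hypothetical values from the config example above):
.. code-block:: python
:linenos:
from bisect import bisect_right
milestones = [500000, 750000]
intervals = [10000, 4000, 2000]
intervals[bisect_right(milestones, 400000)]  # 10000
intervals[bisect_right(milestones, 600000)]  # 4000
intervals[bisect_right(milestones, 800000)]  # 2000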
"""
if isinstance(self.interval, int):
return self.interval
else:
curr_iter = runner.iter + 1
index = bisect_right(self.milestones, curr_iter)
return self.interval[index]
def before_run(self, runner):
"""The behavior before running.
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
"""
if self.save_best_ckpt:
if runner.meta is None:
warnings.warn('runner.meta is None. Creating an empty one.')
runner.meta = dict()
runner.meta.setdefault('hook_msgs', dict())
def after_train_iter(self, runner):
"""The behavior after each train iteration.
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
"""
interval = self.get_current_interval(runner)
if not self.every_n_iters(runner, interval):
return
runner.model.eval()
batch_size = self.dataloader.batch_size
rank, ws = get_dist_info()
total_batch_size = batch_size * ws
# sample real images
max_real_num_images = max(metric.num_images - metric.num_real_feeded
for metric in self.metrics)
# define mmcv progress bar
if rank == 0 and max_real_num_images > 0:
mmcv.print_log(
f'Sample {max_real_num_images} real images for evaluation',
'mmgen')
pbar = mmcv.ProgressBar(max_real_num_images)
if max_real_num_images > 0:
for data in self.dataloader:
if 'real_img' in data:
reals = data['real_img']
# key for conditional GAN
elif 'img' in data:
reals = data['img']
else:
raise KeyError('Cannot find a key for images in data_dict. '
'Only `real_img` is supported for unconditional '
'datasets and `img` for conditional '
'datasets.')
if reals.shape[1] not in [1, 3]:
raise RuntimeError('real images should have one or three '
'channels in the channel dimension, '
'not %d' % reals.shape[1])
if reals.shape[1] == 1:
reals = reals.repeat(1, 3, 1, 1)
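# feed the real batch to every metric; `feed` returns the number of
# images actually consumed (0 once a metric has enough real samples)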
num_feed = 0
for metric in self.metrics:
num_feed_ = metric.feed(reals, 'reals')
num_feed = max(num_feed_, num_feed)
if num_feed <= 0:
break
if rank == 0:
pbar.update(num_feed)
max_num_images = max(metric.num_images for metric in self.metrics)
if rank == 0:
mmcv.print_log(
f'Sample {max_num_images} fake images for evaluation', 'mmgen')
# define mmcv progress bar
if rank == 0:
pbar = mmcv.ProgressBar(max_num_images)
# sampling fake images and directly send them to metrics
for _ in range(0, max_num_images, total_batch_size):
with torch.no_grad():
fakes = runner.model(
None,
num_batches=batch_size,
return_loss=False,
**self.sample_kwargs)
for metric in self.metrics:
# feed in fake images
metric.feed(fakes, 'fakes')
if rank == 0:
pbar.update(total_batch_size)
runner.log_buffer.clear()
# a dirty workaround to move to a new line after the progress bar
if rank == 0:
sys.stdout.write('\n')
for metric in self.metrics:
with torch.no_grad():
metric.summary()
for name, val in metric._result_dict.items():
runner.log_buffer.output[name] = val
# record best metric and save the best ckpt
if self.save_best_ckpt and name in self.best_metric:
self._save_best_ckpt(runner, val, name)
runner.log_buffer.ready = True
runner.model.train()
# clear all current states for next evaluation
for metric in self.metrics:
metric.clear()
def _save_best_ckpt(self, runner, new_score, metric_name):
"""Save checkpoint with best metric score.
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
new_score (float): New metric score.
metric_name (str): Name of metric.
"""
curr_iter = f'iter_{runner.iter + 1}'
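# update the record only if the new score beats the current best under
# this metric's comparison rule ('greater' or 'less')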
if self.compare_func[metric_name](new_score,
self._curr_best_score[metric_name]):
best_ckpt_name = f'best_{metric_name}_{curr_iter}.pth'
runner.meta['hook_msgs'][f'best_score_{metric_name}'] = new_score
if self._curr_best_ckpt_path[metric_name] and osp.isfile(
self._curr_best_ckpt_path[metric_name]):
os.remove(self._curr_best_ckpt_path[metric_name])
self._curr_best_ckpt_path[metric_name] = osp.join(
runner.work_dir, best_ckpt_name)
runner.save_checkpoint(
runner.work_dir, best_ckpt_name, create_symlink=False)
runner.meta['hook_msgs'][
f'best_ckpt_{metric_name}'] = self._curr_best_ckpt_path[
metric_name]
self._curr_best_score[metric_name] = new_score
runner.logger.info(
f'Now best checkpoint is saved as {best_ckpt_name}.')
runner.logger.info(f'Best {metric_name} is {new_score:0.4f} '
f'at {curr_iter}.')
@HOOKS.register_module()
class TranslationEvalHook(GenerativeEvalHook):
"""Evaluation Hook for Translation Models.
This evaluation hook can be used to evaluate translation models. Note
that only the ``FID`` and ``IS`` metrics are supported for distributed
training at the moment. More metrics will be supported for evaluation
during training in the future.
In our config system, you only need to add an `evaluation` field with the
detailed configurations. Below are several usage cases for different
situations. Add these lines at the end of your config file and the hook
will be used during training.
Note that this evaluation hook supports dynamic evaluation intervals,
because FID and other metrics may fluctuate frequently towards the end of
training.
# TODO: fix the online doc
#. Only use FID for evaluation
.. code-block:: python
:linenos:
evaluation = dict(
type='TranslationEvalHook',
target_domain='photo',
interval=10000,
metrics=dict(type='FID', num_images=106, bgr2rgb=True))
#. Use FID and IS simultaneously and save the best checkpoints respectively
.. code-block:: python
:linenos:
evaluation = dict(
type='TranslationEvalHook',
target_domain='photo',
interval=10000,
metrics=[
dict(type='FID', num_images=106, bgr2rgb=True),
dict(
type='IS',
num_images=106,
inception_args=dict(type='pytorch'))
],
best_metric=['fid', 'is'])
#. Use dynamic evaluation intervals
.. code-block:: python
:linenos:
# interval = 10000 if iter < 100000,
# interval = 4000, if 100000 <= iter < 200000,
# interval = 2000, if iter >= 200000
evaluation = dict(
type='TranslationEvalHook',
interval=dict(milestones=[100000, 200000],
interval=[10000, 4000, 2000]),
target_domain='zebra',
metrics=[
dict(type='FID', num_images=140, bgr2rgb=True),
dict(type='IS', num_images=140)
],
best_metric=['fid', 'is'])
Args:
target_domain (str): Target domain of output image.
"""
def __init__(self, *args, target_domain, **kwargs):
super().__init__(*args, **kwargs)
self.target_domain = target_domain
def after_train_iter(self, runner):
"""The behavior after each train iteration.
Args:
runner (``mmcv.runner.BaseRunner``): The runner.
"""
interval = self.get_current_interval(runner)
if not self.every_n_iters(runner, interval):
return
runner.model.eval()
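# infer the source domain as the first of the model's other domains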
source_domain = runner.model.module.get_other_domains(
self.target_domain)[0]
# feed real images
max_num_images = max(metric.num_images for metric in self.metrics)
for metric in self.metrics:
if metric.num_real_feeded >= metric.num_real_need:
continue
mmcv.print_log(f'Feed reals to {metric.name} metric.', 'mmgen')
# feed in real images
for data in self.dataloader:
# key for translation model
if f'img_{self.target_domain}' in data:
reals = data[f'img_{self.target_domain}']
else:
raise KeyError(
'Cannot find a key for images in data_dict.')
num_feed = metric.feed(reals, 'reals')
if num_feed <= 0:
break
mmcv.print_log(f'Sample {max_num_images} fake images for evaluation',
'mmgen')
rank, ws = get_dist_info()
# define mmcv progress bar
if rank == 0:
pbar = mmcv.ProgressBar(max_num_images)
# feed in fake images
for data in self.dataloader:
# key for translation model
if f'img_{source_domain}' in data:
with torch.no_grad():
output_dict = runner.model(
data[f'img_{source_domain}'],
test_mode=True,
target_domain=self.target_domain,
**self.sample_kwargs)
fakes = output_dict['target']
# no valid key found
else:
raise KeyError('Cannot find a key for images in data_dict.')
# sampling fake images and directly send them to metrics
# pbar update number for one proc
num_update = 0
for metric in self.metrics:
if metric.num_fake_feeded >= metric.num_fake_need:
continue
num_feed = metric.feed(fakes, 'fakes')
num_update = max(num_update, num_feed)
if num_feed <= 0:
break
if rank == 0:
if num_update > 0:
pbar.update(num_update * ws)
runner.log_buffer.clear()
# a dirty workaround to move to a new line after the progress bar
if rank == 0:
sys.stdout.write('\n')
for metric in self.metrics:
with torch.no_grad():
metric.summary()
for name, val in metric._result_dict.items():
runner.log_buffer.output[name] = val
# record best metric and save the best ckpt
if self.save_best_ckpt and name in self.best_metric:
self._save_best_ckpt(runner, val, name)
runner.log_buffer.ready = True
runner.model.train()
# clear all current states for next evaluation
for metric in self.metrics:
metric.clear()