"include/vscode:/vscode.git/clone" did not exist on "d955e83bee3919b871616223b777bab2f04942d9"
Commit b7536f78 authored by limm's avatar limm
Browse files

add a to another part of mmgeneration code

parent 57e0e891
Pipeline #2777 canceled with stages
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import multiprocessing as mp
import os
import os.path as osp
import platform
import time
import warnings
import cv2
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
from mmgen import __version__
from mmgen.apis import set_random_seed, train_model
from mmgen.datasets import build_dataset
from mmgen.models import build_model
from mmgen.utils import collect_env, get_root_logger
cv2.setNumThreads(0)
def parse_args():
parser = argparse.ArgumentParser(description='Train a GAN model')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
type=int,
help='(Deprecated, please use --gpu-id) number of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--diff_seed',
action='store_true',
help='Whether or not set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def setup_multi_processes(cfg):
# set multi-process start method as `fork` to speed up the training
if platform.system() != 'Windows':
mp_start_method = cfg.get('mp_start_method', 'fork')
mp.set_start_method(mp_start_method)
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = cfg.get('opencv_num_threads', 0)
cv2.setNumThreads(opencv_num_threads)
# setup OMP threads
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
if ('OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1):
omp_num_threads = 1
warnings.warn(
f'Setting OMP_NUM_THREADS environment variable for each process '
f'to be {omp_num_threads} in default, to avoid your system being '
f'overloaded, please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
# setup MKL threads
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
mkl_num_threads = 1
warnings.warn(
f'Setting MKL_NUM_THREADS environment variable for each process '
f'to be {mkl_num_threads} in default, to avoid your system being '
f'overloaded, please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
setup_multi_processes(cfg)
# import modules from string list.
if cfg.get('custom_imports', None):
from mmcv.utils import import_modules_from_strings
import_modules_from_strings(**cfg['custom_imports'])
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
'single GPU mode in non-distributed training. '
'Use `gpus=1` now.')
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids[0:1]
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
'Because we only support single GPU mode in '
'non-distributed training. Use the first GPU '
'in `gpu_ids` now.')
if args.gpus is None and args.gpu_ids is None:
cfg.gpu_ids = [args.gpu_id]
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# re-set gpu_ids with distributed training mode
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}, '
f'use_rank_shift: {args.diff_seed}')
set_random_seed(
args.seed,
deterministic=args.deterministic,
use_rank_shift=args.diff_seed)
cfg.seed = args.seed
meta['seed'] = args.seed
meta['exp_name'] = osp.basename(args.config)
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
val_dataset.pipeline = cfg.data.val.pipeline
datasets.append(build_dataset(val_dataset))
if cfg.checkpoint_config is not None:
# save mmgen version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(mmgen_version=__version__ +
get_git_hash()[:7])
train_model(
model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
if __name__ == '__main__':
main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import pickle
import sys
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv import Config, print_log
# yapf: disable
sys.path.append(osp.abspath(osp.join(__file__, '../../..'))) # isort:skip # noqa
from mmgen.core.evaluation.metric_utils import extract_inception_features # isort:skip # noqa
from mmgen.datasets import (UnconditionalImageDataset, build_dataloader, # isort:skip # noqa
build_dataset) # isort:skip # noqa
from mmgen.models.architectures import InceptionV3 # isort:skip # noqa
# yapf: enable
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Pre-calculate inception data and save it in pkl file')
parser.add_argument(
'--imgsdir', type=str, default=None, help='the dir containing images.')
parser.add_argument(
'--data-cfg',
type=str,
default=None,
help='the config file for test data pipeline')
parser.add_argument(
'--pklname', type=str, help='the name of inception pkl')
parser.add_argument(
'--pkl-dir',
type=str,
default='work_dirs/inception_pkl',
help='path to save pkl file')
parser.add_argument(
'--pipeline-cfg',
type=str,
default=None,
help=('config file containing dataset pipeline. If None, the default'
' pipeline will be adopted'))
parser.add_argument(
'--flip', action='store_true', help='whether to flip real images')
parser.add_argument(
'--size',
type=int,
nargs='+',
default=(299, 299),
help='image size in the data pipeline')
parser.add_argument(
'--batch-size',
type=int,
default=25,
help='batch size used in extracted features')
parser.add_argument(
'--num-samples',
type=int,
default=50000,
help=('the number of total samples, if input -1, '
'automaticly use all samples in the subset'))
parser.add_argument(
'--no-shuffle',
action='store_true',
help='not use shuffle in data loader')
parser.add_argument(
'--subset',
default='test',
help='which subset and corresponding pipeline to use')
parser.add_argument(
'--inception-style',
choices=['stylegan', 'pytorch'],
default='pytorch',
help='which inception network to use')
parser.add_argument(
'--inception-pth',
type=str,
default='work_dirs/cache/inception-2015-12-05.pt')
args = parser.parse_args()
# dataset pipeline (only be used when args.imgsdir is not None)
if args.pipeline_cfg is not None:
pipeline = Config.fromfile(args.pipeline_cfg)['inception_pipeline']
elif args.imgsdir is not None:
if isinstance(args.size, list) and len(args.size) == 2:
size = args.size
elif isinstance(args.size, list) and len(args.size) == 1:
size = (args.size[0], args.size[0])
elif isinstance(args.size, int):
size = (args.size, args.size)
else:
raise TypeError(
f'args.size mush be int or tuple but got {args.size}')
pipeline = [
dict(type='LoadImageFromFile', key='real_img'),
dict(
type='Resize', keys=['real_img'], scale=size,
keep_ratio=False),
dict(
type='Normalize',
keys=['real_img'],
mean=[127.5] * 3,
std=[127.5] * 3,
to_rgb=True), # default to RGB images
dict(type='Collect', keys=['real_img'], meta_keys=[]),
dict(type='ImageToTensor', keys=['real_img'])
]
# insert flip aug
if args.flip:
pipeline.insert(
1,
dict(type='Flip', keys=['real_img'], direction='horizontal'))
# build dataloader
if args.imgsdir is not None:
dataset = UnconditionalImageDataset(args.imgsdir, pipeline)
elif args.data_cfg is not None:
# Please make sure the dataset will sample images in `RGB` order.
data_config = Config.fromfile(args.data_cfg)
subset_config = data_config.data.get(args.subset, None)
print_log(subset_config, 'mmgen')
dataset = build_dataset(subset_config)
else:
raise RuntimeError('Please provide imgsdir or data_cfg')
data_loader = build_dataloader(
dataset, args.batch_size, 4, dist=False, shuffle=(not args.no_shuffle))
mmcv.mkdir_or_exist(args.pkl_dir)
# build inception network
if args.inception_style == 'stylegan':
inception = torch.jit.load(args.inception_pth).eval().cuda()
inception = nn.DataParallel(inception)
print_log('Adopt Inception network in StyleGAN', 'mmgen')
else:
inception = nn.DataParallel(
InceptionV3([3], resize_input=True, normalize_input=False).cuda())
inception.eval()
if args.num_samples == -1:
print_log('Use all samples in subset', 'mmgen')
num_samples = len(dataset)
else:
num_samples = args.num_samples
features = extract_inception_features(data_loader, inception, num_samples,
args.inception_style).numpy()
# sanity check for the number of features
assert features.shape[
0] == num_samples, 'the number of features != num_samples'
print_log(f'Extract {num_samples} features', 'mmgen')
mean = np.mean(features, 0)
cov = np.cov(features, rowvar=False)
with open(osp.join(args.pkl_dir, args.pklname), 'wb') as f:
pickle.dump(
{
'mean': mean,
'cov': cov,
'size': num_samples,
'name': args.pklname
}, f)
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import sys
import mmcv
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../../..'))) # isort:skip # noqa
from mmgen.apis import set_random_seed # isort:skip # noqa
from mmgen.models import build_model # isort:skip # noqa
# yapf: enable
def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a GAN model')
parser.add_argument('config', help='evaluation config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--samples-path',
type=str,
default=None,
help='path to store images. If not given, remove it after evaluation\
finished')
parser.add_argument(
'--save-prev-res',
action='store_true',
help='whether to store the results from previous stages')
parser.add_argument(
'--num-samples',
type=int,
default=10,
help='the number of synthesized samples')
args = parser.parse_args()
return args
def _tensor2img(img):
img = img[0].permute(1, 2, 0)
img = ((img + 1) / 2 * 255).to(torch.uint8)
return img.cpu().numpy()
@torch.no_grad()
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# set random seeds
if args.seed is not None:
set_random_seed(args.seed, deterministic=args.deterministic)
# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
model.eval()
# load ckpt
mmcv.print_log(f'Loading ckpt from {args.checkpoint}', 'mmgen')
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
# add dp wrapper
model = MMDataParallel(model, device_ids=[0])
pbar = mmcv.ProgressBar(args.num_samples)
for sample_iter in range(args.num_samples):
outputs = model(None, num_batches=1, get_prev_res=args.save_prev_res)
# store results from previous stages
if args.save_prev_res:
fake_img = outputs['fake_img']
prev_res_list = outputs['prev_res_list']
prev_res_list.append(fake_img)
for i, img in enumerate(prev_res_list):
img = _tensor2img(img)
mmcv.imwrite(
img,
os.path.join(args.samples_path, f'stage{i}',
f'rand_sample_{sample_iter}.png'))
# just store the final result
else:
img = _tensor2img(outputs)
mmcv.imwrite(
img,
os.path.join(args.samples_path,
f'rand_sample_{sample_iter}.png'))
pbar.update()
# change the line after pbar
sys.stdout.write('\n')
if __name__ == '__main__':
main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import shutil
import sys
import mmcv
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from torchvision.utils import save_image
from mmgen.apis import set_random_seed
from mmgen.core import build_metric
from mmgen.core.evaluation import make_metrics_table, make_vanilla_dataloader
from mmgen.datasets import build_dataloader, build_dataset
from mmgen.models import build_model
from mmgen.models.translation_models import BaseTranslationModel
from mmgen.utils import get_root_logger
def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a GAN model')
parser.add_argument('config', help='evaluation config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--target-domain', type=str, default=None, help='Desired image domain')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--batch-size', type=int, default=1, help='batch size of dataloader')
parser.add_argument(
'--samples-path',
type=str,
default=None,
help='path to store images. If not given, remove it after evaluation\
finished')
parser.add_argument(
'--sample-model',
type=str,
default='ema',
help='use which mode (ema/orig) in sampling')
parser.add_argument(
'--eval',
nargs='*',
type=str,
default=None,
help='select the metrics you want to access')
parser.add_argument(
'--online',
action='store_true',
help='whether to use online mode for evaluation')
args = parser.parse_args()
return args
@torch.no_grad()
def single_gpu_evaluation(model,
data_loader,
metrics,
logger,
basic_table_info,
batch_size,
samples_path=None,
**kwargs):
"""Evaluate model with a single gpu.
This method evaluate model with a single gpu and displays eval progress
bar.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): PyTorch data loader.
metrics (list): List of metric objects.
logger (Logger): logger used to record results of evaluation.
basic_table_info (dict): Dictionary containing the basic information \
of the metric table include training configuration and ckpt.
batch_size (int): Batch size of images fed into metrics.
samples_path (str): Used to save generated images. If it's none, we'll
give it a default directory and delete it after finishing the
evaluation. Default to None.
kwargs (dict): Other arguments.
"""
# decide samples path
delete_samples_path = False
if samples_path:
mmcv.mkdir_or_exist(samples_path)
else:
temp_path = './work_dirs/temp_samples'
# if temp_path exists, add suffix
suffix = 1
samples_path = temp_path
while os.path.exists(samples_path):
samples_path = temp_path + '_' + str(suffix)
suffix += 1
os.makedirs(samples_path)
delete_samples_path = True
# sample images
num_exist = len(
list(
mmcv.scandir(
samples_path, suffix=('.jpg', '.png', '.jpeg', '.JPEG'))))
if basic_table_info['num_samples'] > 0:
max_num_images = basic_table_info['num_samples']
else:
max_num_images = max(metric.num_images for metric in metrics)
num_needed = max(max_num_images - num_exist, 0)
if num_needed > 0:
mmcv.print_log(f'Sample {num_needed} fake images for evaluation',
'mmgen')
# define mmcv progress bar
pbar = mmcv.ProgressBar(num_needed)
# select key to fetch fake images
target_domain = basic_table_info['target_domain']
source_domain = basic_table_info['source_domain']
# if no images, `num_needed` should be zero
data_loader_iter = iter(data_loader)
for begin in range(0, num_needed, batch_size):
end = min(begin + batch_size, max_num_images)
# for translation model, we feed them images from dataloader
data_batch = next(data_loader_iter)
output_dict = model(
data_batch[f'img_{source_domain}'],
test_mode=True,
target_domain=target_domain)
fakes = output_dict['target']
pbar.update(end - begin)
for i in range(end - begin):
images = fakes[i:i + 1]
images = ((images + 1) / 2)
images = images[:, [2, 1, 0], ...]
images = images.clamp_(0, 1)
image_name = str(begin + i) + '.png'
save_image(images, os.path.join(samples_path, image_name))
if num_needed > 0:
sys.stdout.write('\n')
# return if only save sampled images
if len(metrics) == 0:
return
# empty cache to release GPU memory
torch.cuda.empty_cache()
fake_dataloader = make_vanilla_dataloader(samples_path, batch_size)
for metric in metrics:
mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
metric.prepare()
# feed in real images
for data in data_loader:
reals = data[f'img_{target_domain}']
num_left = metric.feed(reals, 'reals')
if num_left <= 0:
break
# feed in fake images
for data in fake_dataloader:
fakes = data['real_img']
num_left = metric.feed(fakes, 'fakes')
if num_left <= 0:
break
metric.summary()
table_str = make_metrics_table(basic_table_info['train_cfg'],
basic_table_info['ckpt'],
basic_table_info['sample_model'], metrics)
logger.info('\n' + table_str)
if delete_samples_path:
shutil.rmtree(samples_path)
@torch.no_grad()
def single_gpu_online_evaluation(model, data_loader, metrics, logger,
basic_table_info, batch_size, **kwargs):
"""Evaluate model with a single gpu in online mode.
This method evaluate model with a single gpu and displays eval progress
bar. Different form `single_gpu_evaluation`, this function will not save
the images or read images from disks. Namely, there do not exist any IO
operations in this function. Thus, in general, `online` mode will achieve a
faster evaluation. However, this mode will take much more memory cost.
Therefore this evaluation function is recommended to evaluate your model
with a single metric.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): PyTorch data loader.
metrics (list): List of metric objects.
logger (Logger): logger used to record results of evaluation.
basic_table_info (dict): Dictionary containing the basic information \
of the metric table include training configuration and ckpt.
batch_size (int): Batch size of images fed into metrics.
kwargs (dict): Other arguments.
"""
# sample images
max_num_images = 0 if len(metrics) == 0 else max(metric.num_fake_need
for metric in metrics)
pbar = mmcv.ProgressBar(max_num_images)
# select key to fetch images
target_domain = basic_table_info['target_domain']
source_domain = basic_table_info['source_domain']
for metric in metrics:
mmcv.print_log(f'Evaluate with {metric.name} metric.', 'mmgen')
metric.prepare()
# feed reals and fakes
data_loader_iter = iter(data_loader)
for begin in range(0, max_num_images, batch_size):
end = min(begin + batch_size, max_num_images)
# for translation model, we feed them images from dataloader
data_batch = next(data_loader_iter)
output_dict = model(
data_batch[f'img_{source_domain}'],
test_mode=True,
target_domain=target_domain)
fakes = output_dict['target']
reals = data_batch[f'img_{target_domain}']
pbar.update(end - begin)
for metric in metrics:
metric.feed(reals, 'reals')
metric.feed(fakes, 'fakes')
for metric in metrics:
metric.summary()
table_str = make_metrics_table(basic_table_info['train_cfg'],
basic_table_info['ckpt'],
basic_table_info['sample_model'], metrics)
logger.info('\n' + table_str)
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
dirname = os.path.dirname(args.checkpoint)
ckpt = os.path.basename(args.checkpoint)
if 'http' in args.checkpoint:
log_path = None
else:
log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
log_path = os.path.join(dirname, log_name)
logger = get_root_logger(
log_file=log_path, log_level=cfg.log_level, file_mode='a')
logger.info('evaluation')
# set random seeds
if args.seed is not None:
set_random_seed(args.seed, deterministic=args.deterministic)
# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
assert isinstance(model, BaseTranslationModel)
# sanity check for models without ema
if not model.use_ema:
args.sample_model = 'orig'
mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
model.eval()
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
model = MMDataParallel(model, device_ids=[0])
# build metrics
if args.eval:
if args.eval[0] == 'none':
# only sample images
metrics = []
assert args.num_samples is not None and args.num_samples > 0
else:
metrics = [
build_metric(cfg.metrics[metric]) for metric in args.eval
]
else:
metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]
# get source domain and target domain
target_domain = args.target_domain
if target_domain is None:
target_domain = model.module._default_domain
source_domain = model.module.get_other_domains(target_domain)[0]
basic_table_info = dict(
train_cfg=os.path.basename(cfg._filename),
ckpt=ckpt,
sample_model=args.sample_model,
source_domain=source_domain,
target_domain=target_domain)
# build the dataloader
if len(metrics) == 0:
basic_table_info['num_samples'] = args.num_samples
data_loader = None
else:
basic_table_info['num_samples'] = -1
if cfg.data.get('test', None):
dataset = build_dataset(cfg.data.test)
else:
dataset = build_dataset(cfg.data.train)
data_loader = build_dataloader(
dataset,
samples_per_gpu=args.batch_size,
workers_per_gpu=cfg.data.get('val_workers_per_gpu',
cfg.data.workers_per_gpu),
dist=False,
shuffle=True)
if args.online:
single_gpu_online_evaluation(model, data_loader, metrics, logger,
basic_table_info, args.batch_size)
else:
single_gpu_evaluation(model, data_loader, metrics, logger,
basic_table_info, args.batch_size,
args.samples_path)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment