Commit b7536f78 authored by limm's avatar limm
Browse files

add another part of mmgeneration code

parent 57e0e891
Pipeline #2777 canceled with stages
Import:
- configs/ada/metafile.yml
- configs/biggan/metafile.yml
- configs/cyclegan/metafile.yml
- configs/dcgan/metafile.yml
- configs/ggan/metafile.yml
- configs/improved_ddpm/metafile.yml
- configs/lsgan/metafile.yml
- configs/pggan/metafile.yml
- configs/pix2pix/metafile.yml
- configs/positional_encoding_in_gans/metafile.yml
- configs/sagan/metafile.yml
- configs/singan/metafile.yml
- configs/sngan_proj/metafile.yml
- configs/styleganv1/metafile.yml
- configs/styleganv2/metafile.yml
- configs/styleganv3/metafile.yml
- configs/wgan-gp/metafile.yml
-r requirements/runtime.txt
-r requirements/tests.txt
click
docutils==0.16.0
m2r
mmcls==0.18.0
myst-parser
opencv-python!=4.5.5.62,!=4.5.5.64
# Skip problematic opencv-python versions
# MMCV depends opencv-python instead of headless, thus we install opencv-python
# Due to a bug from upstream, we skip these two versions
# https://github.com/opencv/opencv-python/issues/602
# https://github.com/opencv/opencv/issues/21366
# It seems to be fixed in https://github.com/opencv/opencv/pull/21382
opencv-python
prettytable
-e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
scipy
sphinx==4.0.2
sphinx-copybutton
sphinx_markdown_tables
mmcls>=0.18.0
mmcv-full>=1.3.0,<=1.8.0
mmcv
torch
torchvision
mmcls
ninja
numpy
prettytable
requests
scikit-image
scipy
tqdm
yapf
coverage < 7.0.0
# codecov
flake8
interrogate
isort==4.3.21
pytest
pytest-runner
[bdist_wheel]
universal=1
[aliases]
test=pytest
[yapf]
based_on_style=pep8
blank_line_before_nested_class_or_def=true
split_before_expression_after_opening_paren=true
[isort]
line_length=79
multi_line_output=0
extra_standard_library=argparse,inspect,contextlib,hashlib,subprocess,unittest,tempfile,copy,pkg_resources,logging,pickle,platform,setuptools,abc,collections,functools,os,math,time,warnings,random,shutil,sys
known_first_party=mmgen
known_third_party=PIL,click,clip,cv2,imageio,mmcls,mmcv,numpy,prettytable,pytest,pytorch_sphinx_theme,recommonmark,requests,scipy,torch,torchvision,tqdm,ts
no_lines_before=STDLIB,LOCALFOLDER
default_section=THIRDPARTY
import os
import os.path as osp
import re
import shutil
import sys
import warnings
from setuptools import find_packages, setup
import torch
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
CUDAExtension)
def readme():
    """Return the full text of the project's README.md (UTF-8)."""
    with open('README.md', encoding='utf-8') as readme_file:
        return readme_file.read()
def get_version():
    """Return ``__version__`` defined in ``mmgen/version.py``.

    The version file is executed in an isolated namespace so the package
    itself does not have to be importable at build time.
    """
    version_path = 'mmgen/version.py'
    namespace = {}
    with open(version_path, 'r') as version_file:
        exec(compile(version_file.read(), version_path, 'exec'), namespace)
    return namespace['__version__']
def parse_requirements(filename='requirements.txt', with_version=True):
    """Parse the package dependencies listed in a requirements file but strips
    specific versioning information.

    Args:
        filename (str): Path to the requirements file.
        with_version (bool): If True, keep version specifiers in the
            returned items. Default: True.

    Returns:
        List[str]: List of requirements items. Empty if ``filename`` does
            not exist.
    """

    def parse_line(line):
        """Parse information from a line in a requirements text file."""
        if line.startswith('-r '):
            # Allow specifying requirements in other files
            target = line.split(' ')[1]
            for info in parse_require_file(target):
                yield info
        else:
            info = {'line': line}
            if line.startswith('-e '):
                # Editable VCS install: the package name follows '#egg='.
                info['package'] = line.split('#egg=')[1]
            elif '@git+' in line:
                info['package'] = line
            else:
                # Remove versioning from the package.
                # NOTE: longer operators come first so that e.g. '>=' is
                # not matched as '>'; previously '<=', '<', '!=' and '~='
                # were missing, so lines like 'coverage < 7.0.0' kept the
                # specifier inside the package name.
                ops = ['>=', '==', '~=', '!=', '<=', '>', '<']
                pat = '(' + '|'.join(ops) + ')'
                parts = re.split(pat, line, maxsplit=1)
                parts = [p.strip() for p in parts]
                info['package'] = parts[0]
                if len(parts) > 1:
                    op, rest = parts[1:]
                    if ';' in rest:
                        # Handle platform specific dependencies
                        # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
                        version, platform_deps = map(str.strip,
                                                     rest.split(';'))
                        info['platform_deps'] = platform_deps
                    else:
                        version = rest  # NOQA
                    info['version'] = (op, version)
            yield info

    def parse_require_file(fpath):
        """Yield parsed entries for every non-comment line of ``fpath``."""
        with open(fpath, 'r') as f:
            for line in f.readlines():
                line = line.strip()
                if line and not line.startswith('#'):
                    for info in parse_line(line):
                        yield info

    def gen_packages_items():
        """Reassemble parsed entries into requirement strings."""
        if osp.exists(filename):
            for info in parse_require_file(filename):
                parts = [info['package']]
                if with_version and 'version' in info:
                    parts.extend(info['version'])
                if not sys.version.startswith('3.4'):
                    # apparently package_deps are broken in 3.4
                    platform_deps = info.get('platform_deps')
                    if platform_deps is not None:
                        parts.append(';' + platform_deps)
                item = ''.join(parts)
                yield item

    return list(gen_packages_items())
def make_cuda_ext(name, module, sources, sources_cuda=None):
    """Build a C++/CUDA extension description for setuptools.

    Args:
        name (str): Extension name (without the module prefix).
        module (str): Dotted module path the extension lives under; source
            paths are resolved relative to it.
        sources (list[str]): C++ source files relative to ``module``.
        sources_cuda (list[str], optional): Extra CUDA sources, compiled
            only when CUDA is available or ``FORCE_CUDA=1`` is set.
            Defaults to no extra sources.

    Returns:
        setuptools.Extension: A ``CUDAExtension`` or ``CppExtension``.
    """
    # Avoid the mutable-default-argument pitfall: the previous signature
    # used `sources_cuda=[]` and mutated both the default and the caller's
    # `sources` list in place.
    if sources_cuda is None:
        sources_cuda = []

    define_macros = []
    extra_compile_args = {'cxx': []}

    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        # Build a new list instead of extending `sources` in place.
        sources = sources + sources_cuda
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)
def add_mim_extension():
    """Add extra files that are required to support MIM into the package.

    These files will be added by creating a symlink to the originals if the
    package is installed in `editable` mode (e.g. pip install -e .), or by
    copying from the originals otherwise.
    """
    # Determine the installation mode from the command line.
    if 'develop' in sys.argv:
        # installed by `pip install -e .`
        mode = 'symlink'
    elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv:
        # installed by `pip install .`
        # or create source distribution by `python setup.py sdist`
        mode = 'copy'
    else:
        return

    filenames = ['tools', 'configs', 'demo', 'model-index.yml']
    repo_path = osp.dirname(__file__)
    mim_path = osp.join(repo_path, 'mmgen', '.mim')
    os.makedirs(mim_path, exist_ok=True)

    for filename in filenames:
        if not osp.exists(filename):
            continue
        src_path = osp.join(repo_path, filename)
        tar_path = osp.join(mim_path, filename)

        # Remove any stale file, symlink or directory at the target.
        if osp.isfile(tar_path) or osp.islink(tar_path):
            os.remove(tar_path)
        elif osp.isdir(tar_path):
            shutil.rmtree(tar_path)

        if mode == 'symlink':
            src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
            os.symlink(src_relpath, tar_path)
        elif mode == 'copy':
            if osp.isfile(src_path):
                shutil.copyfile(src_path, tar_path)
            elif osp.isdir(src_path):
                shutil.copytree(src_path, tar_path)
            else:
                warnings.warn(f'Cannot copy file {src_path}.')
        else:
            raise ValueError(f'Invalid mode {mode}')
if __name__ == '__main__':
    # Materialize MIM support files (tools/configs/...) before packaging.
    add_mim_extension()
    setup(
        name='mmgen',
        version=get_version(),
        description='MMGeneration',
        long_description=readme(),
        long_description_content_type='text/markdown',
        packages=find_packages(exclude=('configs', 'tools', 'demo')),
        # Ship prebuilt shared objects for the custom ops.
        package_data={'mmgen.ops': ['*/*.so']},
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.6',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
            'Programming Language :: Python :: 3.9',
        ],
        url='https://github.com/open-mmlab/mmgen',
        author='MMGeneration Contributors',
        author_email='openmmlab@gmail.com',
        license='Apache License 2.0',
        include_package_data=True,
        install_requires=parse_requirements('requirements.txt'),
        # Use torch's BuildExtension so CUDA extensions compile correctly.
        cmdclass={'build_ext': BuildExtension},
        extras_require={
            'all': parse_requirements('requirements.txt'),
            'tests': parse_requirements('requirements/tests.txt'),
            'mim': parse_requirements('requirements/mminstall.txt'),
        },
        zip_safe=False)
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
import mmcv
# `torch-model-archiver` is an optional dependency: fall back to None so
# the CLI can raise an actionable error at run time instead of import time.
try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None
def mmgen2torchserver(config_file: str,
                      checkpoint_file: str,
                      output_folder: str,
                      model_name: str,
                      model_version: str = '1.0',
                      model_type: str = 'unconditional',
                      force: bool = False):
    """Converts MMGeneration model (config + checkpoint) to TorchServe `.mar`.

    Args:
        config_file (str): Path of config file. The config should in
            MMGeneration format.
        checkpoint_file (str): Path of checkpoint. The checkpoint should in
            MMGeneration checkpoint format.
        output_folder (str): Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
        model_name (str): Name of the generated ``'mar'`` file. If not None,
            used for naming the `{model_name}.mar` file that will be created
            under `output_folder`. If None, `{Path(checkpoint_file).stem}`
            will be used.
        model_version (str, optional): Model's version. Defaults to '1.0'.
        model_type (str, optional): Type of the model to be convert. Handler
            named ``{model_type}_handler`` would be used to generate ``mar``
            file. Defaults to 'unconditional'.
        force (bool, optional): If True, existing `{model_name}.mar` will be
            overwritten. Default to False.
    """
    mmcv.mkdir_or_exist(output_folder)
    config = mmcv.Config.fromfile(config_file)

    with TemporaryDirectory() as tmpdir:
        # The archiver needs the (resolved) config as a plain file.
        config.dump(f'{tmpdir}/config.py')

        # The handler lives next to this script and is selected by type.
        handler_path = f'{Path(__file__).parent}/mmgen_{model_type}_handler.py'
        args = Namespace(
            **{
                'model_file': f'{tmpdir}/config.py',
                'serialized_file': checkpoint_file,
                'handler': handler_path,
                'model_name': model_name or Path(checkpoint_file).stem,
                'version': model_version,
                'export_path': output_folder,
                'force': force,
                'requirements_file': None,
                'extra_files': None,
                'runtime': 'python',
                'archive_format': 'default'
            })
        manifest = ModelExportUtils.generate_manifest_json(args)
        package_model(args, manifest)
def parse_args():
    """Build the CLI for the MMGeneration -> TorchServe converter."""
    parser = ArgumentParser(
        description='Convert MMGeneration models to TorchServe `.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        type=str,
        required=True,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        type=str,
        default=None,
        help='If not None, used for naming the `{model_name}.mar`'
        'file that will be created under `output_folder`.'
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-type',
        type=str,
        default='unconditional',
        help='Which model type and handler to be used.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='1.0',
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()

    if package_model is None:
        # Fail fast with an actionable message when the optional
        # `torch-model-archiver` dependency is missing. (The original
        # message concatenated to 'required.Try:' without a space.)
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mmgen2torchserver(args.config, args.checkpoint, args.output_folder,
                      args.model_name, args.model_version, args.model_type,
                      args.force)
# Copyright (c) OpenMMLab. All rights reserved.
import os
import numpy as np
import torch
from ts.torch_handler.base_handler import BaseHandler
from mmgen.apis import init_model
class MMGenUnconditionalHandler(BaseHandler):
    """TorchServe handler for unconditional MMGeneration models."""

    def initialize(self, context):
        """Load the archived config + checkpoint into an inference model."""
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        if torch.cuda.is_available():
            # Pin the model to the GPU assigned by TorchServe.
            self.device = torch.device(
                self.map_location + ':' + str(properties.get('gpu_id')))
        else:
            self.device = torch.device(self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_model(self.config_file, checkpoint, self.device)
        self.initialized = True

    def preprocess(self, data, *args, **kwargs):
        """Decode byte-valued fields of the first request into strings."""
        decoded = dict()
        # `data` type is `list[dict]`; only the first request is used.
        for key, value in data[0].items():
            if isinstance(value, bytearray):
                decoded[key] = value.decode()
        return decoded

    def inference(self, data, *args, **kwargs):
        """Sample one image from noise with the requested weights."""
        sample_model = data['sample_model']
        print(sample_model)
        results = self.model.sample_from_noise(
            None, num_batches=1, sample_model=sample_model, **kwargs)
        return results

    def postprocess(self, data):
        """Convert sampled tensors to raw uint8 byte payloads."""
        outputs = []
        for img in data:
            # Map values from [-1, 1] to [0, 1], reverse the channel
            # order, and lay out as HxWxC uint8 bytes.
            img = (img + 1) / 2
            img = img[[2, 1, 0], ...]
            img = img.clamp_(0, 1)
            img = (img * 255).permute(1, 2, 0)
            img_np = img.detach().cpu().numpy().astype(np.uint8)
            outputs.append(img_np.tobytes())
        return outputs
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
import numpy as np
import requests
from PIL import Image
def parse_args():
    """Build the CLI for the TorchServe test client."""
    parser = ArgumentParser()
    parser.add_argument('model_name', help='The model name in the server')
    parser.add_argument(
        '--inference-addr',
        default='127.0.0.1:8080',
        help='Address and port of the inference server')
    parser.add_argument(
        '--img-path',
        type=str,
        default='demo.png',
        help='Path to save generated image.')
    parser.add_argument(
        '--img-size', type=int, default=128, help='Size of the output image.')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema/orig',
        help='Which model you want to use.')
    return parser.parse_args()
def save_results(contents, img_path, img_size):
    """Save one image, or a horizontal strip of images, to ``img_path``.

    Args:
        contents (bytes | list[bytes]): Raw RGB payload(s); each one is
            assumed to be an ``img_size`` x ``img_size`` image.
        img_path (str): Destination file path.
        img_size (int): Height/width of each square image.
    """
    if not isinstance(contents, list):
        # Single payload: decode and write it directly.
        Image.frombytes('RGB', (img_size, img_size), contents).save(img_path)
        return

    # Multiple payloads: concatenate side by side before saving.
    arrays = [
        np.array(Image.frombytes('RGB', (img_size, img_size), content))
        for content in contents
    ]
    Image.fromarray(np.concatenate(arrays, axis=1)).save(img_path)
def main(args):
    """Request sample(s) from the inference server and save them."""
    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name

    if args.sample_model == 'ema/orig':
        # Query both the EMA and the original weights; save side by side.
        cont_ema = requests.post(url, {'sample_model': 'ema'}).content
        cont_orig = requests.post(url, {'sample_model': 'orig'}).content
        save_results([cont_ema, cont_orig], args.img_path, args.img_size)
        return

    response = requests.post(url, {'sample_model': args.sample_model})
    save_results(response.content, args.img_path, args.img_size)


if __name__ == '__main__':
    args = parse_args()
    main(args)
#!/usr/bin/env bash
# Distributed testing launcher.
# Usage: dist_test.sh CONFIG CHECKPOINT GPUS [extra test.py args...]
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# Quote "$0", the positional paths and the pass-through args so that
# file names or arguments containing spaces are not word-split.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    "$(dirname "$0")/test.py" \
    "$CONFIG" \
    "$CHECKPOINT" \
    --launcher pytorch \
    "${@:4}"
#!/usr/bin/env bash
# Distributed training launcher.
# Usage: dist_train.sh CONFIG GPUS [extra train.py args...]
CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# Quote "$0", the config path and the pass-through args so that
# file names or arguments containing spaces are not word-split.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    "$(dirname "$0")/train.py" \
    "$CONFIG" \
    --seed 0 \
    --launcher pytorch "${@:3}"
#!/usr/bin/env bash
# Single-machine evaluation wrapper around tools/evaluation.py.
# Usage: eval.sh CONFIG CKPT [extra evaluation.py args...]
set -x

CONFIG=$1
CKPT=$2
# Capture extra args as an array so argument boundaries (and embedded
# spaces) survive; `PY_ARGS=${@:3}` would flatten them into one string.
PY_ARGS=("${@:3}")

PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -u tools/evaluation.py "${CONFIG}" "${CKPT}" --launcher="none" "${PY_ARGS[@]}"
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmgen.apis import set_random_seed
from mmgen.core import build_metric, offline_evaluation, online_evaluation
from mmgen.datasets import build_dataloader, build_dataset
from mmgen.models import build_model
from mmgen.utils import get_root_logger
# Metrics that support multi-GPU (distributed) evaluation.
_distributed_metrics = ['FID', 'IS']


def parse_args():
    """Build the command-line interface for the evaluation script."""
    parser = argparse.ArgumentParser(description='Evaluate a Generation model')
    parser.add_argument('config', help='evaluation config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--batch-size', type=int, default=10, help='batch size of dataloader')
    parser.add_argument(
        '--samples-path',
        type=str,
        default=None,
        help='path to store images. If not given, remove it after evaluation\
        finished')
    parser.add_argument(
        '--sample-model',
        type=str,
        default='ema',
        choices=['ema', 'orig'],
        help='use which mode (ema/orig) in sampling')
    parser.add_argument(
        '--eval',
        nargs='*',
        type=str,
        default=None,
        help='select the metrics you want to access')
    parser.add_argument(
        '--online',
        action='store_true',
        help='whether to use online mode for evaluation')
    parser.add_argument(
        '--num-samples',
        type=int,
        default=-1,
        help='The number of images to be sampled for evaluation.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file.')
    parser.add_argument(
        '--sample-cfg',
        nargs='+',
        action=DictAction,
        help='Other customized kwargs for sampling function')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    # torch.distributed.launch passes the rank via --local_rank; mirror it
    # into the environment for the distributed helpers.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def main():
    """Evaluate a generative model: build model/metrics/data and dispatch to
    online or offline evaluation."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        rank = 0
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        rank, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
        # Fixed typo in the message ('distrbuted' -> 'distributed').
        assert args.online or world_size == 1, (
            'We only support online mode for distributed evaluation.')

    dirname = os.path.dirname(args.checkpoint)
    ckpt = os.path.basename(args.checkpoint)

    # Remote checkpoints have no local directory to log into.
    if 'http' in args.checkpoint:
        log_path = None
    else:
        log_name = ckpt.split('.')[0] + '_eval_log' + '.txt'
        log_path = os.path.join(dirname, log_name)

    logger = get_root_logger(
        log_file=log_path, log_level=cfg.log_level, file_mode='a')
    logger.info('evaluation')

    # set random seeds
    if args.seed is not None:
        if rank == 0:
            mmcv.print_log(f'set random seed to {args.seed}', 'mmgen')
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # sanity check for models without ema
    if not model.use_ema:
        args.sample_model = 'orig'
    mmcv.print_log(f'Sampling model: {args.sample_model}', 'mmgen')
    model.eval()

    if args.eval:
        if args.eval[0] == 'none':
            # only sample images, no metric computation
            metrics = []
            assert args.num_samples is not None and args.num_samples > 0
        else:
            metrics = [
                build_metric(cfg.metrics[metric]) for metric in args.eval
            ]
    else:
        metrics = [build_metric(cfg.metrics[metric]) for metric in cfg.metrics]

    # check metrics for dist evaluation
    if distributed and metrics:
        for metric in metrics:
            assert metric.name in _distributed_metrics, (
                f'We only support {_distributed_metrics} for multi gpu '
                f'evaluation, but receive {args.eval}.')

    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    basic_table_info = dict(
        train_cfg=os.path.basename(cfg._filename),
        ckpt=ckpt,
        sample_model=args.sample_model)

    if len(metrics) == 0:
        # Sampling-only run: no dataloader is needed.
        basic_table_info['num_samples'] = args.num_samples
        data_loader = None
    else:
        basic_table_info['num_samples'] = -1
        # build the dataloader; prefer test, then val, then train split
        if cfg.data.get('test', None) and cfg.data.test.get('imgs_root', None):
            dataset = build_dataset(cfg.data.test)
        elif cfg.data.get('val', None) and cfg.data.val.get('imgs_root', None):
            dataset = build_dataset(cfg.data.val)
        elif cfg.data.get('train', None):
            # we assume that the train part should work well
            dataset = build_dataset(cfg.data.train)
        else:
            raise RuntimeError('There is no valid dataset config to run, '
                               'please check your dataset configs.')

        # The default loader config
        loader_cfg = dict(
            samples_per_gpu=args.batch_size,
            workers_per_gpu=cfg.data.get('val_workers_per_gpu',
                                         cfg.data.workers_per_gpu),
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            shuffle=True)
        # The overall dataloader settings
        loader_cfg.update({
            k: v
            for k, v in cfg.data.items() if k not in [
                'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
                'test_dataloader'
            ]
        })
        # specific config for test loader
        test_loader_cfg = {**loader_cfg, **cfg.data.get('test_dataloader', {})}
        data_loader = build_dataloader(dataset, **test_loader_cfg)

    if args.sample_cfg is None:
        args.sample_cfg = dict()

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    # online mode will not save samples
    if args.online and len(metrics) > 0:
        online_evaluation(model, data_loader, metrics, logger,
                          basic_table_info, args.batch_size, **args.sample_cfg)
    else:
        offline_evaluation(model, data_loader, metrics, logger,
                           basic_table_info, args.batch_size,
                           args.samples_path, **args.sample_cfg)


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmcv import Config, DictAction
def parse_args():
    """Parse the config path and optional config overrides."""
    parser = argparse.ArgumentParser(description='Print the whole config')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    return parser.parse_args()
def main():
    """Load the config, apply overrides and print the merged result."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list (e.g. plugins) before printing
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    print(f'Config:\n{cfg.pretty_text}')


if __name__ == '__main__':
    main()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import subprocess
from datetime import datetime
import torch
def parse_args():
    """Parse the input and output checkpoint paths."""
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('in_file', help='input checkpoint filename')
    parser.add_argument('out_file', help='output checkpoint filename')
    return parser.parse_args()
def process_checkpoint(in_file, out_file):
    """Strip training-only state from a checkpoint and publish it.

    The optimizer state is removed to shrink the file, the result is saved
    to ``out_file``, then renamed to
    ``<out_file minus '.pth'>_<timestamp>-<sha256 prefix>.pth``.

    Args:
        in_file (str): Input checkpoint path.
        out_file (str): Output checkpoint path; also the base of the final
            published filename.
    """
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)

    now = datetime.now()
    time = now.strftime('%Y%m%d_%H%M%S')
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    # NOTE: `out_file.rstrip('.pth')` would strip any trailing '.', 'p',
    # 't', 'h' characters (e.g. 'net.pth' -> 'ne'); remove the exact
    # suffix instead.
    if out_file.endswith('.pth'):
        base = out_file[:-len('.pth')]
    else:
        base = out_file
    final_file = base + f'_{time}-{sha[:8]}.pth'
    # Rename synchronously; the previous fire-and-forget Popen(['mv', ...])
    # could race with callers that expect the final file to exist.
    subprocess.run(['mv', out_file, final_file], check=True)
def main():
    """Command-line entry point: publish the given checkpoint."""
    args = parse_args()
    process_checkpoint(args.in_file, args.out_file)


if __name__ == '__main__':
    main()
#!/usr/bin/env bash
# Slurm evaluation launcher (single GPU by default).
# Usage: slurm_eval.sh PARTITION JOB_NAME CONFIG CKPT [extra args...]
set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CKPT=$4
GPUS=${GPUS:-1}
GPUS_PER_NODE=${GPUS_PER_NODE:-1}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
# Capture extra args as an array so argument boundaries (and embedded
# spaces) survive; `PY_ARGS=${@:5}` would flatten them into one string.
PY_ARGS=("${@:5}")
# SRUN_ARGS stays a plain string and is intentionally left unquoted below
# so multiple srun options can be passed in one variable.
SRUN_ARGS=${SRUN_ARGS:-""}

PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/evaluation.py "${CONFIG}" "${CKPT}" --launcher="none" "${PY_ARGS[@]}"
#!/usr/bin/env bash
# Slurm evaluation launcher (8 GPUs by default, slurm launcher).
# Usage: slurm_eval_multi.sh PARTITION JOB_NAME CONFIG CKPT [extra args...]
set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CKPT=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
# Capture extra args as an array so argument boundaries (and embedded
# spaces) survive; `PY_ARGS=${@:5}` would flatten them into one string.
PY_ARGS=("${@:5}")
# SRUN_ARGS stays a plain string and is intentionally left unquoted below
# so multiple srun options can be passed in one variable.
SRUN_ARGS=${SRUN_ARGS:-""}

PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/evaluation.py "${CONFIG}" "${CKPT}" --launcher="slurm" "${PY_ARGS[@]}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment