"vscode:/vscode.git/clone" did not exist on "432443c2c2bf732d37ddee7a1937bcbefd02edb5"
Commit 088919f6 authored by pangjm's avatar pangjm
Browse files

Update single base version

parent 108fc9e1
import torch
from ._functions import Scatter
from torch.nn.parallel._functions import Scatter as OrigScatter
from mmdet.datasets.utils import DataContainer  (import renamed from detkit.datasets.utils in this commit)
def scatter(inputs, target_gpus, dim=0):
......
from argparse import ArgumentParser
from multiprocessing import Pool
import matplotlib.pyplot as plt
import numpy as np
import copy
import os
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
def generate_area_range(splitRng=32, stop_size=128):
    """Build the list of COCO-style area ranges for size-bucketed evaluation.

    Starts with the four standard COCO buckets (all / small / medium / large,
    areas in px^2) and appends fine-grained buckets of side-length width
    ``splitRng`` up to ``stop_size``, plus one final open-ended bucket.
    """
    # The four canonical COCO buckets come first.
    ranges = [[0**2, 1e5**2], [0**2, 32**2], [32**2, 96**2], [96**2, 1e5**2]]
    lower = 0
    while lower < stop_size:
        upper = lower + splitRng
        ranges.append([lower * lower, upper * upper])
        lower = upper
    # Open-ended bucket covering everything above `stop_size`.
    ranges.append([lower * lower, 1e5**2])
    return ranges
def print_summarize(iouThr=None,
iouThrs=None,
precision=None,
recall=None,
areaRng_id=4,
areaRngs=None,
maxDets_id=2,
maxDets=None):
assert (precision is not None) or (recall is not None)
iStr = ' {:<18} {} @[ IoU={:<9} | size={:>5}-{:>5} | maxDets={:>3d} ] = {:0.3f}'
titleStr = 'Average Precision' if precision is not None else 'Average Recall'
typeStr = '(AP)' if precision is not None else '(AR)'
iouStr = '{:0.2f}:{:0.2f}'.format(iouThrs[0], iouThrs[-1]) \
if iouThr is None else '{:0.2f}'.format(iouThr)
aind = [areaRng_id]
mind = [maxDets_id]
if precision is not None:
# dimension of precision: [TxRxKxAxM]
s = precision
# IoU
if iouThr is not None:
t = np.where(iouThr == iouThrs)[0]
s = s[t]
s = s[:, :, :, aind, mind]
else:
# dimension of recall: [TxKxAxM]
s = recall
if iouThr is not None:
t = np.where(iouThr == iouThrs)[0]
s = s[t]
s = s[:, :, aind, mind]
if len(s[s > -1]) == 0:
mean_s = -1
else:
mean_s = np.mean(s[s > -1])
print(
iStr.format(
titleStr, typeStr, iouStr, np.sqrt(areaRngs[areaRng_id][0]),
np.sqrt(areaRngs[areaRng_id][1])
if np.sqrt(areaRngs[areaRng_id][1]) < 999 else 'max',
maxDets[maxDets_id], mean_s))
def eval_results(res_file, ann_file, res_types, splitRng):
    """Run COCO evaluation on a result file and print a per-area breakdown.

    For each result type, runs the standard COCOeval summary and then prints
    one AP line per area range produced by ``generate_area_range(splitRng)``.
    """
    for res_type in res_types:
        assert res_type in ['proposal', 'bbox', 'segm', 'keypoints']
    area_ranges = generate_area_range(splitRng)
    coco_gt = COCO(ann_file)
    coco_dt = coco_gt.loadRes(res_file)
    img_ids = coco_gt.getImgIds()
    for res_type in res_types:
        # Proposals are evaluated with the bbox metric, class-agnostically.
        iou_type = 'bbox' if res_type == 'proposal' else res_type
        coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
        coco_eval.params.imgIds = img_ids
        if res_type == 'proposal':
            coco_eval.params.useCats = 0
        coco_eval.params.maxDets = [100, 300, 1000]
        coco_eval.params.areaRng = area_ranges
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        precisions = coco_eval.eval['precision']
        recalls = coco_eval.eval['recall']  # kept for parity; not used below
        for area_id in range(len(area_ranges)):
            print_summarize(None, coco_eval.params.iouThrs, precisions, None,
                            area_id, area_ranges, 2,
                            coco_eval.params.maxDets)
def makeplot(rs, ps, outDir, class_name):
    """Save stacked precision-recall plots (COCO error-analysis style).

    ``rs`` is the recall grid; ``ps`` stacks one precision plane per error
    type (C75, C50, Loc, Sim, Oth, BG, FN).  One PNG is written per area
    bucket into ``outDir``.
    """
    # One fill color per error type (first two rows are white placeholders).
    colors = np.vstack([
        np.ones((2, 3)),
        np.array([.31, .51, .74]),
        np.array([.75, .31, .30]),
        np.array([.36, .90, .38]),
        np.array([.50, .39, .64]),
        np.array([1, .6, 0]),
    ])
    area_names = ['all', 'small', 'medium', 'large']
    curve_types = ['C75', 'C50', 'Loc', 'Sim', 'Oth', 'BG', 'FN']
    for area_idx, area_name in enumerate(area_names):
        area_ps = ps[..., area_idx, 0]
        figure_tile = class_name + '-' + area_name
        aps = [plane.mean() for plane in area_ps]
        # Average over categories when the plane still has a category axis.
        curves = [
            plane.mean(axis=1) if plane.ndim > 1 else plane
            for plane in area_ps
        ]
        # Baseline of zeros so the first band fills from the x-axis up.
        curves.insert(0, np.zeros(curves[0].shape))
        fig = plt.figure()
        ax = plt.subplot(111)
        for t, curve_type in enumerate(curve_types):
            ax.plot(rs, curves[t + 1], color=[0, 0, 0], linewidth=0.5)
            ax.fill_between(
                rs,
                curves[t],
                curves[t + 1],
                color=colors[t],
                label=str('[{:.3f}'.format(aps[t]) + ']' + curve_type))
        plt.xlabel('recall')
        plt.ylabel('precision')
        plt.xlim(0, 1.)
        plt.ylim(0, 1.)
        plt.title(figure_tile)
        plt.legend()
        fig.savefig(outDir + '/{}.png'.format(figure_tile))
        plt.close(fig)
def analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type):
    """Compute relaxed precision arrays for one category (error analysis).

    Evaluates category ``catId`` twice at a loose IoU of 0.1: once ignoring
    confusion with same-supercategory classes, once ignoring confusion with
    any class.  Returns ``(k, {'ps_supercategory': ..., 'ps_allcategory': ...})``
    where each value is a precision slab indexed as [R x A x M].
    """
    nm = cocoGt.loadCats(catId)[0]
    print('--------------analyzing {}-{}---------------'.format(
        k + 1, nm['name']))
    ps_ = {}
    # Restrict the detections to this category only (on a deep copy).
    dt = copy.deepcopy(cocoDt)
    nm = cocoGt.loadCats(catId)[0]
    imgIds = cocoGt.getImgIds()
    dt_anns = dt.dataset['annotations']
    select_dt_anns = []
    for ann in dt_anns:
        if ann['category_id'] == catId:
            select_dt_anns.append(ann)
    dt.dataset['annotations'] = select_dt_anns
    dt.createIndex()
    # compute precision but ignore superclass confusion
    gt = copy.deepcopy(cocoGt)
    child_catIds = gt.getCatIds(supNms=[nm['supercategory']])
    for idx, ann in enumerate(gt.dataset['annotations']):
        if (ann['category_id'] in child_catIds
                and ann['category_id'] != catId):
            # Relabel same-supercategory GT as this category but mark it
            # ignored, so sibling-class confusion is not penalized.
            gt.dataset['annotations'][idx]['ignore'] = 1
            gt.dataset['annotations'][idx]['iscrowd'] = 1
            gt.dataset['annotations'][idx]['category_id'] = catId
    cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
    cocoEval.params.imgIds = imgIds
    cocoEval.params.maxDets = [100]
    cocoEval.params.iouThrs = [.1]  # loose IoU: localization errors forgiven
    cocoEval.params.useCats = 1
    cocoEval.evaluate()
    cocoEval.accumulate()
    ps_supercategory = cocoEval.eval['precision'][0, :, k, :, :]
    ps_['ps_supercategory'] = ps_supercategory
    # compute precision but ignore any class confusion
    gt = copy.deepcopy(cocoGt)
    for idx, ann in enumerate(gt.dataset['annotations']):
        if ann['category_id'] != catId:
            gt.dataset['annotations'][idx]['ignore'] = 1
            gt.dataset['annotations'][idx]['iscrowd'] = 1
            gt.dataset['annotations'][idx]['category_id'] = catId
    cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
    cocoEval.params.imgIds = imgIds
    cocoEval.params.maxDets = [100]
    cocoEval.params.iouThrs = [.1]
    cocoEval.params.useCats = 1
    cocoEval.evaluate()
    cocoEval.accumulate()
    ps_allcategory = cocoEval.eval['precision'][0, :, k, :, :]
    ps_['ps_allcategory'] = ps_allcategory
    return k, ps_
def analyze_results(res_file, ann_file, res_types, out_dir):
    """Run COCO error analysis and save per-class / overall PR plots.

    For each result type, evaluates at IoU thresholds [.75, .5, .1], then
    augments the precision tensor with four extra planes (supercategory
    confusion, any-class confusion, background, false negatives) computed
    per category in parallel, and renders plots via ``makeplot``.
    """
    for res_type in res_types:
        assert res_type in ['bbox', 'segm']
    directory = os.path.dirname(out_dir + '/')
    if not os.path.exists(directory):
        print('-------------create {}-----------------'.format(out_dir))
        os.makedirs(directory)
    cocoGt = COCO(ann_file)
    cocoDt = cocoGt.loadRes(res_file)
    imgIds = cocoGt.getImgIds()
    for res_type in res_types:
        iou_type = res_type
        cocoEval = COCOeval(
            copy.deepcopy(cocoGt), copy.deepcopy(cocoDt), iou_type)
        cocoEval.params.imgIds = imgIds
        # Three thresholds: C75, C50 and a loose 0.1 (localization-only).
        cocoEval.params.iouThrs = [.75, .5, .1]
        cocoEval.params.maxDets = [100]
        cocoEval.evaluate()
        cocoEval.accumulate()
        ps = cocoEval.eval['precision']
        # Reserve 4 extra precision planes: Sim, Oth, BG and FN errors.
        ps = np.vstack([ps, np.zeros((4, *ps.shape[1:]))])
        catIds = cocoGt.getCatIds()
        recThrs = cocoEval.params.recThrs
        # Per-category analyses are independent, so fan out across processes.
        with Pool(processes=48) as pool:
            args = [(k, cocoDt, cocoGt, catId, iou_type)
                    for k, catId in enumerate(catIds)]
            # NOTE(review): this local shadows the enclosing function name.
            analyze_results = pool.starmap(analyze_individual_category, args)
        for k, catId in enumerate(catIds):
            nm = cocoGt.loadCats(catId)[0]
            print('--------------saving {}-{}---------------'.format(
                k + 1, nm['name']))
            analyze_result = analyze_results[k]
            assert k == analyze_result[0]
            ps_supercategory = analyze_result[1]['ps_supercategory']
            ps_allcategory = analyze_result[1]['ps_allcategory']
            # compute precision but ignore superclass confusion
            ps[3, :, k, :, :] = ps_supercategory
            # compute precision but ignore any class confusion
            ps[4, :, k, :, :] = ps_allcategory
            # fill in background and false negative errors and plot
            ps[ps == -1] = 0
            ps[5, :, k, :, :] = (ps[4, :, k, :, :] > 0)
            ps[6, :, k, :, :] = 1.0
            makeplot(recThrs, ps[:, :, k], out_dir, nm['name'])
        makeplot(recThrs, ps, out_dir, 'all')
def main():
    """CLI entry point: plain COCO evaluation or full error analysis."""
    parser = ArgumentParser(description='COCO Evaluation')
    parser.add_argument('result', help='result file path')
    parser.add_argument(
        '--ann',
        default='/mnt/SSD/dataset/coco/annotations/instances_minival2017.json',
        help='annotation file path')
    parser.add_argument(
        '--types', type=str, nargs='+', default=['bbox'], help='result types')
    parser.add_argument(
        '--analyze', action='store_true', help='whether to analyze results')
    parser.add_argument(
        '--out_dir',
        type=str,
        default=None,
        help='dir to save analyze result images')
    parser.add_argument(
        '--splitRng',
        type=int,
        default=32,
        help='range to split area in evaluation')
    args = parser.parse_args()
    if args.analyze:
        # Error analysis needs somewhere to write the plots.
        assert args.out_dir is not None
        analyze_results(
            args.result, args.ann, args.types, out_dir=args.out_dir)
    else:
        eval_results(args.result, args.ann, args.types, splitRng=args.splitRng)


if __name__ == '__main__':
    main()
# model settings
# Faster R-CNN with ResNet-50 + FPN.
model = dict(
    # ImageNet-pretrained backbone weights (cluster-local path).
    pretrained=
    '/mnt/lustre/pangjiangmiao/initmodel/pytorch/resnet50-19c8e357.pth',
    backbone=dict(
        type='resnet',
        depth=50,
        num_stages=4,
        # Feature maps from all four stages feed the FPN.
        out_indices=(0, 1, 2, 3),
        # Freeze the stem and first stage during training.
        frozen_stages=1,
        style='fb'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        coarsest_stride=32,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        # One stride per FPN level.
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        use_sigmoid_cls=True),
    roi_block=dict(
        type='SingleLevelRoI',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCRoIHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        # 80 COCO classes + background.
        num_classes=81,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False))
# Training/testing hyper-parameters for RPN and the R-CNN head.
meta_params = dict(
    # RPN training: anchor assignment and sampling.
    rpn_train_cfg = dict(
        pos_fraction=0.5,
        pos_balance_sampling=False,
        neg_pos_ub=256,
        allowed_border=0,
        anchor_batch_size=256,
        pos_iou_thr=0.7,
        neg_iou_thr=0.3,
        neg_balance_thr=0,
        min_pos_iou=1e-3,
        # -1: use the default weighting for positive samples.
        pos_weight=-1,
        smoothl1_beta=1 / 9.0,
        debug=False),
    # RPN proposal generation at test time.
    rpn_test_cfg = dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    # R-CNN head training: proposal assignment and sampling.
    rcnn_train_cfg = dict(
        pos_iou_thr=0.5,
        neg_iou_thr=0.5,
        crowd_thr=1.1,
        roi_batch_size=512,
        add_gt_as_proposals=True,
        pos_fraction=0.25,
        pos_balance_sampling=False,
        neg_pos_ub=512,
        neg_balance_thr=0,
        pos_weight=-1,
        debug=False),
    # R-CNN head test-time post-processing.
    rcnn_test_cfg = dict(score_thr=1e-3, max_per_img=100, nms_thr=0.5)
)
# dataset settings
data_root = '/mnt/lustre/pangjiangmiao/dataset/coco/'
img_norm_cfg = dict(
    # ImageNet mean/std, RGB order.
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True)
img_per_gpu = 1
data_workers = 2
train_dataset = dict(
    ann_file=data_root + 'annotations/instances_train2017.json',
    img_prefix=data_root + 'train2017/',
    img_scale=(1333, 800),
    img_norm_cfg=img_norm_cfg,
    # Pad so both image sides are divisible by 32 (FPN requirement).
    size_divisor=32,
    flip_ratio=0.5)
test_dataset = dict(
    ann_file=data_root + 'annotations/instances_val2017.json',
    img_prefix=data_root + 'val2017/',
    img_scale=(1333, 800),
    img_norm_cfg=img_norm_cfg,
    size_divisor=32)
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
grad_clip_config = dict(grad_clip=True, max_norm=35, norm_type=2)
# learning policy
lr_policy = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.333,
    # Decay LR after epochs 8 and 11 (standard "1x" schedule).
    step=[8, 11])
max_epoch = 12
checkpoint_config = dict(interval=1)
dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
# logging settings
log_level = 'INFO'
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')),
    ])
# yapf:enable
work_dir = './model/r50_fpn_frcnn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
# Mask R-CNN with ResNet-50 + FPN (adds mask_block / mask_head).
model = dict(
    # ImageNet-pretrained backbone weights (cluster-local path).
    pretrained=
    '/mnt/lustre/pangjiangmiao/initmodel/pytorch/resnet50-19c8e357.pth',
    backbone=dict(
        type='resnet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='fb'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        coarsest_stride=32,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        use_sigmoid_cls=True),
    roi_block=dict(
        type='SingleLevelRoI',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCRoIHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        # 80 COCO classes + background.
        num_classes=81,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False),
    # RoI extraction for the mask branch (larger 14x14 output).
    mask_block=dict(
        type='SingleLevelRoI',
        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    mask_head=dict(
        type='FCNMaskHead',
        num_convs=4,
        in_channels=256,
        conv_out_channels=256,
        num_classes=81))
# Training/testing hyper-parameters (Mask R-CNN variant adds mask options).
meta_params = dict(
    # RPN training: anchor assignment and sampling.
    rpn_train_cfg=dict(
        pos_fraction=0.5,
        pos_balance_sampling=False,
        neg_pos_ub=256,
        allowed_border=0,
        anchor_batch_size=256,
        pos_iou_thr=0.7,
        neg_iou_thr=0.3,
        neg_balance_thr=0,
        min_pos_iou=1e-3,
        # -1: use the default weighting for positive samples.
        pos_weight=-1,
        smoothl1_beta=1 / 9.0,
        debug=False),
    # RPN proposal generation at test time.
    rpn_test_cfg=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    # R-CNN head training; mask targets are 28x28.
    rcnn_train_cfg=dict(
        mask_size=28,
        pos_iou_thr=0.5,
        neg_iou_thr=0.5,
        crowd_thr=1.1,
        roi_batch_size=512,
        add_gt_as_proposals=True,
        pos_fraction=0.25,
        pos_balance_sampling=False,
        neg_pos_ub=512,
        neg_balance_thr=0,
        pos_weight=-1,
        debug=False),
    # Test-time post-processing; masks binarized at 0.5.
    rcnn_test_cfg=dict(
        score_thr=1e-3, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5))
# dataset settings
data_root = '/mnt/lustre/pangjiangmiao/dataset/coco/'
# ImageNet mean/std, RGB order.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_per_gpu = 1
data_workers = 2
train_dataset = dict(
    # Load instance masks for Mask R-CNN training.
    with_mask=True,
    ann_file=data_root + 'annotations/instances_train2017.json',
    img_prefix=data_root + 'train2017/',
    img_scale=(1333, 800),
    img_norm_cfg=img_norm_cfg,
    # Pad so both image sides are divisible by 32 (FPN requirement).
    size_divisor=32,
    flip_ratio=0.5)
test_dataset = dict(
    ann_file=data_root + 'annotations/instances_val2017.json',
    img_prefix=data_root + 'val2017/',
    img_scale=(1333, 800),
    img_norm_cfg=img_norm_cfg,
    size_divisor=32)
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
grad_clip_config = dict(grad_clip=True, max_norm=35, norm_type=2)
# learning policy
lr_policy = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.333,
    # Decay LR after epochs 8 and 11 (standard "1x" schedule).
    step=[8, 11])
max_epoch = 12
checkpoint_config = dict(interval=1)
dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
# logging settings
log_level = 'INFO'
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')),
    ])
# yapf:enable
work_dir = './model/r50_fpn_mask_rcnn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
# RPN-only model (proposal generation) with ResNet-50 + FPN.
model = dict(
    # ImageNet-pretrained backbone weights (cluster-local path).
    pretrained=
    '/mnt/lustre/pangjiangmiao/initmodel/pytorch/resnet50-19c8e357.pth',
    backbone=dict(
        type='resnet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='fb'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        coarsest_stride=32,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        use_sigmoid_cls=True))
# RPN-only training/testing hyper-parameters.
meta_params = dict(
    # Anchor assignment and sampling for RPN training.
    rpn_train_cfg=dict(
        pos_fraction=0.5,
        pos_balance_sampling=False,
        neg_pos_ub=256,
        allowed_border=0,
        anchor_batch_size=256,
        pos_iou_thr=0.7,
        neg_iou_thr=0.3,
        neg_balance_thr=0,
        min_pos_iou=1e-3,
        # -1: use the default weighting for positive samples.
        pos_weight=-1,
        smoothl1_beta=1 / 9.0,
        debug=False),
    # Proposal generation at test time.
    rpn_test_cfg=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0))
# dataset settings
data_root = '/mnt/lustre/pangjiangmiao/dataset/coco/'
# ImageNet mean/std, RGB order.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_per_gpu = 1
data_workers = 2
train_dataset = dict(
    ann_file=data_root + 'annotations/instances_train2017.json',
    img_prefix=data_root + 'train2017/',
    img_scale=(1333, 800),
    img_norm_cfg=img_norm_cfg,
    # Pad so both image sides are divisible by 32 (FPN requirement).
    size_divisor=32,
    flip_ratio=0.5)
test_dataset = dict(
    ann_file=data_root + 'annotations/instances_val2017.json',
    img_prefix=data_root + 'val2017/',
    img_scale=(1333, 800),
    img_norm_cfg=img_norm_cfg,
    size_divisor=32,
    # Skip loading GT annotations when testing proposals.
    test_mode=True)
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
grad_clip_config = dict(grad_clip=True, max_norm=35, norm_type=2)
# learning policy
lr_policy = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.333,
    # Decay LR after epochs 8 and 11 (standard "1x" schedule).
    step=[8, 11])
max_epoch = 12
checkpoint_config = dict(interval=1)
dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
# logging settings
log_level = 'INFO'
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log')),
    ])
# yapf:enable
work_dir = './model/r50_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
import os.path as osp
import sys
sys.path.append(osp.abspath(osp.join(__file__, '../../')))
sys.path.append('/mnt/lustre/pangjiangmiao/sensenet_folder/mmcv')
import argparse
import numpy as np
import torch
import mmcv
from mmcv import Config
from mmcv.torchpack import load_checkpoint, parallel_test
from mmdet.core import _data_func, results2json
from mmdet.datasets import CocoDataset
from mmdet.datasets.data_engine import build_data
from mmdet.models import Detector
def parse_args():
    """Parse command-line options for the test script."""
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--world_size', default=1, type=int)
    parser.add_argument('--out', help='output result file')
    parser.add_argument(
        '--out_json', action='store_true', help='get json output file')
    return parser.parse_args()
# Parse CLI arguments at import time so module-level `args` is visible to main().
args = parse_args()
def main():
    """Entry point: load config, restore the checkpoint and run testing.

    Reads the module-level ``args``; dumps raw outputs to ``args.out`` and
    optionally a COCO-format JSON next to it.
    """
    cfg = Config.fromfile(args.config)
    # The checkpoint supplies all weights; skip loading pretrained ones.
    cfg.model['pretrained'] = None
    # Testing always runs one image per GPU.
    # (Bug fix: was `cfg.img_per_gpu == 1`, a no-op comparison.)
    cfg.img_per_gpu = 1
    # Build the dataset up front: both the multi-GPU branch and the JSON
    # dump below need it (it was undefined in the single-GPU path before).
    test_dataset = CocoDataset(**cfg.test_dataset)
    if args.world_size == 1:
        # TODO verify this part
        args.dist = False
        args.img_per_gpu = cfg.img_per_gpu
        args.data_workers = cfg.data_workers
        # Bug fix: `meta_params` was referenced as a bare (undefined) name.
        model = Detector(**cfg.model, **cfg.meta_params)
        load_checkpoint(model, args.checkpoint)
        test_loader = build_data(cfg.test_dataset, args)
        # Bug fix: DataParallel expects a list of device ids, not an int.
        model = torch.nn.DataParallel(model, device_ids=[0])
        # TODO write single_test
        outputs = single_test(test_loader, model)
    else:
        model = dict(cfg.model, **cfg.meta_params)
        outputs = parallel_test(Detector, model, args.checkpoint,
                                test_dataset, _data_func,
                                range(args.world_size))
    if args.out:
        mmcv.dump(outputs, args.out, protocol=4)
        if args.out_json:
            results2json(test_dataset, outputs, args.out + '.json')


if __name__ == '__main__':
    main()
from __future__ import division
import argparse
import sys
import os.path as osp
sys.path.append(osp.abspath(osp.join(__file__, '../../')))
sys.path.append('/mnt/lustre/pangjiangmiao/sensenet_folder/mmcv')
import torch
import torch.multiprocessing as mp
from mmcv import Config
from mmcv.torchpack import Runner
from mmdet.core import (batch_processor, init_dist, broadcast_params,
DistOptimizerStepperHook, DistSamplerSeedHook)
from mmdet.datasets.data_engine import build_data
from mmdet.models import Detector
from mmdet.nn.parallel import MMDataParallel
def parse_args():
    """Parse command-line options for the train/val script."""
    parser = argparse.ArgumentParser(description='MMDet train val detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--validate', action='store_true', help='validate')
    parser.add_argument(
        '--dist', action='store_true', help='distributed training or not')
    parser.add_argument('--world_size', default=1, type=int)
    parser.add_argument('--rank', default=0, type=int)
    return parser.parse_args()
# Parse CLI arguments at import time so module-level `args` is visible to main().
args = parse_args()
def main():
    """Entry point: build data loaders, the detector and the runner, then train."""
    # Fetch config information FIRST: the distributed branch below reads
    # cfg.dist_params.  (Bug fix: `cfg` was previously used at init_dist()
    # time but only assigned afterwards -> NameError whenever --dist was set.)
    cfg = Config.fromfile(args.config)
    # Enable distributed training or not
    if args.dist:
        print('Enable distributed training.')
        mp.set_start_method("spawn", force=True)
        init_dist(args.world_size, args.rank, **cfg.dist_params)
    else:
        print('Disabled distributed training.')
    # TODO more flexible
    args.img_per_gpu = cfg.img_per_gpu
    args.data_workers = cfg.data_workers
    # prepare training loader
    train_loader = [build_data(cfg.train_dataset, args)]
    if args.validate:
        val_loader = build_data(cfg.val_dataset, args)
        train_loader.append(val_loader)
    # build model
    model = Detector(**cfg.model, **cfg.meta_params)
    if args.dist:
        model = model.cuda()
        broadcast_params(model)
    else:
        # Bug fix: DataParallel-style wrappers expect a list of device ids,
        # not a bare int.
        device_ids = [args.rank % torch.cuda.device_count()]
        model = MMDataParallel(model, device_ids=device_ids).cuda()
    # register hooks
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    # In distributed mode, wrap gradient clipping in the optimizer hook.
    optimizer_stepper = DistOptimizerStepperHook(
        **cfg.grad_clip_config) if args.dist else cfg.grad_clip_config
    runner.register_training_hooks(cfg.lr_policy, optimizer_stepper,
                                   cfg.checkpoint_config, cfg.log_config)
    if args.dist:
        runner.register_hook(DistSamplerSeedHook())
    # Resume takes precedence over plain checkpoint loading.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(train_loader, cfg.workflow, cfg.max_epoch, args=args)


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment