"vscode:/vscode.git/clone" did not exist on "6985955ddd85e6026c08375337ee356681428d4a"
Commit ce79da2e authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'add-tta' into 'master'

Support test time augmentation

See merge request open-mmlab/mmdet.3d!70
parents f6e95edd 3c5ff9fa
...@@ -100,7 +100,7 @@ class SECONDFusionFPN(SECONDFPN): ...@@ -100,7 +100,7 @@ class SECONDFusionFPN(SECONDFPN):
coors=None, coors=None,
points=None, points=None,
img_feats=None, img_feats=None,
img_meta=None): img_metas=None):
assert len(x) == len(self.in_channels) assert len(x) == len(self.in_channels)
ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)] ups = [deblock(x[i]) for i, deblock in enumerate(self.deblocks)]
...@@ -119,5 +119,5 @@ class SECONDFusionFPN(SECONDFPN): ...@@ -119,5 +119,5 @@ class SECONDFusionFPN(SECONDFPN):
coors[:, 3] / self.downsample_rates[2]) coors[:, 3] / self.downsample_rates[2])
# fusion for each point # fusion for each point
out = self.fusion_layer(img_feats, points, out, out = self.fusion_layer(img_feats, points, out,
downsample_pts_coors, img_meta) downsample_pts_coors, img_metas)
return [out] return [out]
...@@ -51,7 +51,7 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta): ...@@ -51,7 +51,7 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
@abstractmethod @abstractmethod
def forward_train(self, def forward_train(self,
x, x,
img_meta, img_metas,
proposal_list, proposal_list,
gt_bboxes, gt_bboxes,
gt_labels, gt_labels,
...@@ -64,7 +64,7 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta): ...@@ -64,7 +64,7 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
def simple_test(self, def simple_test(self,
x, x,
proposal_list, proposal_list,
img_meta, img_metas,
proposals=None, proposals=None,
rescale=False, rescale=False,
**kwargs): **kwargs):
......
...@@ -441,7 +441,7 @@ class PartA2BboxHead(nn.Module): ...@@ -441,7 +441,7 @@ class PartA2BboxHead(nn.Module):
bbox_pred, bbox_pred,
class_labels, class_labels,
class_pred, class_pred,
img_meta, img_metas,
cfg=None): cfg=None):
roi_batch_id = rois[..., 0] roi_batch_id = rois[..., 0]
roi_boxes = rois[..., 1:] # boxes without batch id roi_boxes = rois[..., 1:] # boxes without batch id
...@@ -474,8 +474,8 @@ class PartA2BboxHead(nn.Module): ...@@ -474,8 +474,8 @@ class PartA2BboxHead(nn.Module):
selected_scores = cur_cls_score[selected] selected_scores = cur_cls_score[selected]
result_list.append( result_list.append(
(img_meta[batch_id]['box_type_3d'](selected_bboxes, (img_metas[batch_id]['box_type_3d'](selected_bboxes,
self.bbox_coder.code_size), self.bbox_coder.code_size),
selected_scores, selected_label_preds)) selected_scores, selected_label_preds))
return result_list return result_list
......
...@@ -59,7 +59,7 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -59,7 +59,7 @@ class PartAggregationROIHead(Base3DRoIHead):
return hasattr(self, return hasattr(self,
'semantic_head') and self.semantic_head is not None 'semantic_head') and self.semantic_head is not None
def forward_train(self, feats_dict, voxels_dict, img_meta, proposal_list, def forward_train(self, feats_dict, voxels_dict, img_metas, proposal_list,
gt_bboxes_3d, gt_labels_3d): gt_bboxes_3d, gt_labels_3d):
"""Training forward function of PartAggregationROIHead """Training forward function of PartAggregationROIHead
...@@ -97,7 +97,7 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -97,7 +97,7 @@ class PartAggregationROIHead(Base3DRoIHead):
return losses return losses
def simple_test(self, feats_dict, voxels_dict, img_meta, proposal_list, def simple_test(self, feats_dict, voxels_dict, img_metas, proposal_list,
**kwargs): **kwargs):
"""Simple testing forward function of PartAggregationROIHead """Simple testing forward function of PartAggregationROIHead
...@@ -131,7 +131,7 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -131,7 +131,7 @@ class PartAggregationROIHead(Base3DRoIHead):
bbox_results['bbox_pred'], bbox_results['bbox_pred'],
labels_3d, labels_3d,
cls_preds, cls_preds,
img_meta, img_metas,
cfg=self.test_cfg) cfg=self.test_cfg)
bbox_results = [ bbox_results = [
......
...@@ -188,7 +188,7 @@ class DynamicVFE(nn.Module): ...@@ -188,7 +188,7 @@ class DynamicVFE(nn.Module):
coors, coors,
points=None, points=None,
img_feats=None, img_feats=None,
img_meta=None): img_metas=None):
"""Forward functions """Forward functions
Args: Args:
...@@ -198,7 +198,7 @@ class DynamicVFE(nn.Module): ...@@ -198,7 +198,7 @@ class DynamicVFE(nn.Module):
multi-modality fusion. Defaults to None. multi-modality fusion. Defaults to None.
img_feats (list[torch.Tensor], optional): Image fetures used for img_feats (list[torch.Tensor], optional): Image fetures used for
multi-modality fusion. Defaults to None. multi-modality fusion. Defaults to None.
img_meta (dict, optional): [description]. Defaults to None. img_metas (dict, optional): [description]. Defaults to None.
Returns: Returns:
tuple: If `return_point_feats` is False, returns voxel features and tuple: If `return_point_feats` is False, returns voxel features and
...@@ -237,7 +237,7 @@ class DynamicVFE(nn.Module): ...@@ -237,7 +237,7 @@ class DynamicVFE(nn.Module):
if (i == len(self.vfe_layers) - 1 and self.fusion_layer is not None if (i == len(self.vfe_layers) - 1 and self.fusion_layer is not None
and img_feats is not None): and img_feats is not None):
point_feats = self.fusion_layer(img_feats, points, point_feats, point_feats = self.fusion_layer(img_feats, points, point_feats,
img_meta) img_metas)
voxel_feats, voxel_coors = self.vfe_scatter(point_feats, coors) voxel_feats, voxel_coors = self.vfe_scatter(point_feats, coors)
if i != len(self.vfe_layers) - 1: if i != len(self.vfe_layers) - 1:
# need to concat voxel feats if it is not the last vfe # need to concat voxel feats if it is not the last vfe
...@@ -351,7 +351,7 @@ class HardVFE(nn.Module): ...@@ -351,7 +351,7 @@ class HardVFE(nn.Module):
num_points, num_points,
coors, coors,
img_feats=None, img_feats=None,
img_meta=None): img_metas=None):
"""Forward functions """Forward functions
Args: Args:
...@@ -360,7 +360,7 @@ class HardVFE(nn.Module): ...@@ -360,7 +360,7 @@ class HardVFE(nn.Module):
coors (torch.Tensor): Coordinates of voxels, shape is Mx(1+NDim). coors (torch.Tensor): Coordinates of voxels, shape is Mx(1+NDim).
img_feats (list[torch.Tensor], optional): Image fetures used for img_feats (list[torch.Tensor], optional): Image fetures used for
multi-modality fusion. Defaults to None. multi-modality fusion. Defaults to None.
img_meta (dict, optional): [description]. Defaults to None. img_metas (dict, optional): [description]. Defaults to None.
Returns: Returns:
tuple: If `return_point_feats` is False, returns voxel features and tuple: If `return_point_feats` is False, returns voxel features and
...@@ -410,12 +410,12 @@ class HardVFE(nn.Module): ...@@ -410,12 +410,12 @@ class HardVFE(nn.Module):
if (self.fusion_layer is not None and img_feats is not None): if (self.fusion_layer is not None and img_feats is not None):
voxel_feats = self.fusion_with_mask(features, mask, voxel_feats, voxel_feats = self.fusion_with_mask(features, mask, voxel_feats,
coors, img_feats, img_meta) coors, img_feats, img_metas)
return voxel_feats return voxel_feats
def fusion_with_mask(self, features, mask, voxel_feats, coors, img_feats, def fusion_with_mask(self, features, mask, voxel_feats, coors, img_feats,
img_meta): img_metas):
"""Fuse image and point features with mask. """Fuse image and point features with mask.
Args: Args:
...@@ -425,7 +425,7 @@ class HardVFE(nn.Module): ...@@ -425,7 +425,7 @@ class HardVFE(nn.Module):
voxel_feats (torch.Tensor): Features of voxels. voxel_feats (torch.Tensor): Features of voxels.
coors (torch.Tensor): Coordinates of each single voxel. coors (torch.Tensor): Coordinates of each single voxel.
img_feats (list[torch.Tensor]): Multi-scale feature maps of image. img_feats (list[torch.Tensor]): Multi-scale feature maps of image.
img_meta (list(dict)): Meta information of image and points. img_metas (list(dict)): Meta information of image and points.
Returns: Returns:
torch.Tensor: Fused features of each voxel. torch.Tensor: Fused features of each voxel.
...@@ -439,7 +439,7 @@ class HardVFE(nn.Module): ...@@ -439,7 +439,7 @@ class HardVFE(nn.Module):
point_feats = voxel_feats[mask] point_feats = voxel_feats[mask]
point_feats = self.fusion_layer(img_feats, points, point_feats, point_feats = self.fusion_layer(img_feats, points, point_feats,
img_meta) img_metas)
voxel_canvas = voxel_feats.new_zeros( voxel_canvas = voxel_feats.new_zeros(
size=(voxel_feats.size(0), voxel_feats.size(1), size=(voxel_feats.size(0), voxel_feats.size(1),
......
...@@ -27,7 +27,7 @@ def test_getitem(): ...@@ -27,7 +27,7 @@ def test_getitem():
dict(type='IndoorPointSample', num_points=5), dict(type='IndoorPointSample', num_points=5),
dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0), dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0),
dict( dict(
type='IndoorGlobalRotScale', type='IndoorGlobalRotScaleTrans',
shift_height=True, shift_height=True,
rot_range=[-1 / 36, 1 / 36], rot_range=[-1 / 36, 1 / 36],
scale_range=None), scale_range=None),
...@@ -50,11 +50,11 @@ def test_getitem(): ...@@ -50,11 +50,11 @@ def test_getitem():
gt_labels = data['gt_labels_3d']._data gt_labels = data['gt_labels_3d']._data
pts_semantic_mask = data['pts_semantic_mask']._data pts_semantic_mask = data['pts_semantic_mask']._data
pts_instance_mask = data['pts_instance_mask']._data pts_instance_mask = data['pts_instance_mask']._data
file_name = data['img_meta']._data['file_name'] file_name = data['img_metas']._data['file_name']
flip_xz = data['img_meta']._data['flip_xz'] flip_xz = data['img_metas']._data['flip_xz']
flip_yz = data['img_meta']._data['flip_yz'] flip_yz = data['img_metas']._data['flip_yz']
rot_angle = data['img_meta']._data['rot_angle'] rot_angle = data['img_metas']._data['rot_angle']
sample_idx = data['img_meta']._data['sample_idx'] sample_idx = data['img_metas']._data['sample_idx']
assert file_name == './tests/data/scannet/' \ assert file_name == './tests/data/scannet/' \
'points/scene0000_00.bin' 'points/scene0000_00.bin'
assert flip_xz is True assert flip_xz is True
......
...@@ -19,7 +19,7 @@ def test_getitem(): ...@@ -19,7 +19,7 @@ def test_getitem():
dict(type='LoadAnnotations3D'), dict(type='LoadAnnotations3D'),
dict(type='IndoorFlipData', flip_ratio_yz=1.0), dict(type='IndoorFlipData', flip_ratio_yz=1.0),
dict( dict(
type='IndoorGlobalRotScale', type='IndoorGlobalRotScaleTrans',
shift_height=True, shift_height=True,
rot_range=[-1 / 6, 1 / 6], rot_range=[-1 / 6, 1 / 6],
scale_range=[0.85, 1.15]), scale_range=[0.85, 1.15]),
...@@ -39,12 +39,12 @@ def test_getitem(): ...@@ -39,12 +39,12 @@ def test_getitem():
points = data['points']._data points = data['points']._data
gt_bboxes_3d = data['gt_bboxes_3d']._data gt_bboxes_3d = data['gt_bboxes_3d']._data
gt_labels_3d = data['gt_labels_3d']._data gt_labels_3d = data['gt_labels_3d']._data
file_name = data['img_meta']._data['file_name'] file_name = data['img_metas']._data['file_name']
flip_xz = data['img_meta']._data['flip_xz'] flip_xz = data['img_metas']._data['flip_xz']
flip_yz = data['img_meta']._data['flip_yz'] flip_yz = data['img_metas']._data['flip_yz']
scale_ratio = data['img_meta']._data['scale_ratio'] scale_ratio = data['img_metas']._data['scale_ratio']
rot_angle = data['img_meta']._data['rot_angle'] rot_angle = data['img_metas']._data['rot_angle']
sample_idx = data['img_meta']._data['sample_idx'] sample_idx = data['img_metas']._data['sample_idx']
assert file_name == './tests/data/sunrgbd' \ assert file_name == './tests/data/sunrgbd' \
'/points/000001.bin' '/points/000001.bin'
assert flip_xz is False assert flip_xz is False
......
...@@ -71,7 +71,7 @@ def test_anchor3d_head_loss(): ...@@ -71,7 +71,7 @@ def test_anchor3d_head_loss():
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda') pytest.skip('test requires GPU and torch+cuda')
bbox_head_cfg = _get_head_cfg( bbox_head_cfg = _get_head_cfg(
'second/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py') 'second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py')
from mmdet3d.models.builder import build_head from mmdet3d.models.builder import build_head
self = build_head(bbox_head_cfg) self = build_head(bbox_head_cfg)
...@@ -123,7 +123,7 @@ def test_anchor3d_head_getboxes(): ...@@ -123,7 +123,7 @@ def test_anchor3d_head_getboxes():
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda') pytest.skip('test requires GPU and torch+cuda')
bbox_head_cfg = _get_head_cfg( bbox_head_cfg = _get_head_cfg(
'second/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py') 'second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py')
from mmdet3d.models.builder import build_head from mmdet3d.models.builder import build_head
self = build_head(bbox_head_cfg) self = build_head(bbox_head_cfg)
......
...@@ -2,7 +2,8 @@ import numpy as np ...@@ -2,7 +2,8 @@ import numpy as np
import torch import torch
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.datasets.pipelines import IndoorFlipData, IndoorGlobalRotScale from mmdet3d.datasets.pipelines import (IndoorFlipData,
IndoorGlobalRotScaleTrans)
def test_indoor_flip_data(): def test_indoor_flip_data():
...@@ -69,7 +70,7 @@ def test_indoor_flip_data(): ...@@ -69,7 +70,7 @@ def test_indoor_flip_data():
def test_global_rot_scale(): def test_global_rot_scale():
np.random.seed(0) np.random.seed(0)
sunrgbd_augment = IndoorGlobalRotScale( sunrgbd_augment = IndoorGlobalRotScaleTrans(
True, rot_range=[-1 / 6, 1 / 6], scale_range=[0.85, 1.15]) True, rot_range=[-1 / 6, 1 / 6], scale_range=[0.85, 1.15])
sunrgbd_results = dict() sunrgbd_results = dict()
sunrgbd_results['points'] = np.array( sunrgbd_results['points'] = np.array(
...@@ -100,7 +101,7 @@ def test_global_rot_scale(): ...@@ -100,7 +101,7 @@ def test_global_rot_scale():
expected_sunrgbd_gt_bboxes_3d, 1e-3) expected_sunrgbd_gt_bboxes_3d, 1e-3)
np.random.seed(0) np.random.seed(0)
scannet_augment = IndoorGlobalRotScale( scannet_augment = IndoorGlobalRotScaleTrans(
True, rot_range=[-1 * 1 / 36, 1 / 36], scale_range=None) True, rot_range=[-1 * 1 / 36, 1 / 36], scale_range=None)
scannet_results = dict() scannet_results = dict()
scannet_results['points'] = np.array( scannet_results['points'] = np.array(
......
...@@ -30,7 +30,7 @@ def test_scannet_pipeline(): ...@@ -30,7 +30,7 @@ def test_scannet_pipeline():
dict(type='IndoorPointSample', num_points=5), dict(type='IndoorPointSample', num_points=5),
dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0), dict(type='IndoorFlipData', flip_ratio_yz=1.0, flip_ratio_xz=1.0),
dict( dict(
type='IndoorGlobalRotScale', type='IndoorGlobalRotScaleTrans',
shift_height=True, shift_height=True,
rot_range=[-1 / 36, 1 / 36], rot_range=[-1 / 36, 1 / 36],
scale_range=None), scale_range=None),
...@@ -113,7 +113,7 @@ def test_sunrgbd_pipeline(): ...@@ -113,7 +113,7 @@ def test_sunrgbd_pipeline():
dict(type='LoadAnnotations3D'), dict(type='LoadAnnotations3D'),
dict(type='IndoorFlipData', flip_ratio_yz=1.0), dict(type='IndoorFlipData', flip_ratio_yz=1.0),
dict( dict(
type='IndoorGlobalRotScale', type='IndoorGlobalRotScaleTrans',
shift_height=True, shift_height=True,
rot_range=[-1 / 6, 1 / 6], rot_range=[-1 / 6, 1 / 6],
scale_range=[0.85, 1.15]), scale_range=[0.85, 1.15]),
......
...@@ -16,14 +16,14 @@ def test_outdoor_aug_pipeline(): ...@@ -16,14 +16,14 @@ def test_outdoor_aug_pipeline():
dict( dict(
type='ObjectNoise', type='ObjectNoise',
num_try=100, num_try=100,
loc_noise_std=[1.0, 1.0, 0.5], translation_std=[1.0, 1.0, 0.5],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_uniform_noise=[-0.78539816, 0.78539816]), rot_range=[-0.78539816, 0.78539816]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio=0.5),
dict( dict(
type='GlobalRotScale', type='GlobalRotScaleTrans',
rot_uniform_noise=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
scaling_uniform_noise=[0.95, 1.05]), scale_ratio_range=[0.95, 1.05]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'), dict(type='PointShuffle'),
...@@ -98,7 +98,7 @@ def test_outdoor_aug_pipeline(): ...@@ -98,7 +98,7 @@ def test_outdoor_aug_pipeline():
pts_filename='tests/data/kitti/a.bin', pts_filename='tests/data/kitti/a.bin',
ann_info=dict(gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d), ann_info=dict(gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d),
bbox3d_fields=[], bbox3d_fields=[],
) img_fields=[])
output = pipeline(results) output = pipeline(results)
...@@ -133,10 +133,10 @@ def test_outdoor_velocity_aug_pipeline(): ...@@ -133,10 +133,10 @@ def test_outdoor_velocity_aug_pipeline():
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict( dict(
type='GlobalRotScale', type='GlobalRotScaleTrans',
rot_uniform_noise=[-0.3925, 0.3925], rot_range=[-0.3925, 0.3925],
scaling_uniform_noise=[0.95, 1.05], scale_ratio_range=[0.95, 1.05],
trans_normal_noise=[0, 0, 0]), translation_std=[0, 0, 0]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio=0.5),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
...@@ -197,7 +197,7 @@ def test_outdoor_velocity_aug_pipeline(): ...@@ -197,7 +197,7 @@ def test_outdoor_velocity_aug_pipeline():
pts_filename='tests/data/kitti/a.bin', pts_filename='tests/data/kitti/a.bin',
ann_info=dict(gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d), ann_info=dict(gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d),
bbox3d_fields=[], bbox3d_fields=[],
) img_fields=[])
output = pipeline(results) output = pipeline(results)
......
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import mmcv import mmcv
import torch import torch
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from tools.fuse_conv_bn import fuse_module from tools.fuse_conv_bn import fuse_module
...@@ -13,37 +14,6 @@ from mmdet.apis import multi_gpu_test, set_random_seed, single_gpu_test ...@@ -13,37 +14,6 @@ from mmdet.apis import multi_gpu_test, set_random_seed, single_gpu_test
from mmdet.core import wrap_fp16_model from mmdet.core import wrap_fp16_model
class MultipleKVAction(argparse.Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options should
be passed as comma separated values, i.e KEY=V1,V2,V3
"""
def _parse_int_float_bool(self, val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ['true', 'false']:
return True if val.lower() == 'true' else False
return val
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split('=', maxsplit=1)
val = [self._parse_int_float_bool(v) for v in val.split(',')]
if len(val) == 1:
val = val[0]
options[key] = val
setattr(namespace, self.dest, options)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='MMDet test (and eval) a model') description='MMDet test (and eval) a model')
...@@ -82,7 +52,7 @@ def parse_args(): ...@@ -82,7 +52,7 @@ def parse_args():
action='store_true', action='store_true',
help='whether to set deterministic options for CUDNN backend.') help='whether to set deterministic options for CUDNN backend.')
parser.add_argument( parser.add_argument(
'--options', nargs='+', action=MultipleKVAction, help='custom options') '--options', nargs='+', action=DictAction, help='custom options')
parser.add_argument( parser.add_argument(
'--launcher', '--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'], choices=['none', 'pytorch', 'slurm', 'mpi'],
...@@ -109,7 +79,7 @@ def main(): ...@@ -109,7 +79,7 @@ def main():
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
raise ValueError('The output file must be a pkl file.') raise ValueError('The output file must be a pkl file.')
cfg = mmcv.Config.fromfile(args.config) cfg = Config.fromfile(args.config)
# set cudnn_benchmark # set cudnn_benchmark
if cfg.get('cudnn_benchmark', False): if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
...@@ -167,7 +137,7 @@ def main(): ...@@ -167,7 +137,7 @@ def main():
rank, _ = get_dist_info() rank, _ = get_dist_info()
if rank == 0: if rank == 0:
if args.out: if args.out:
print('\nwriting results to {}'.format(args.out)) print(f'\nwriting results to {args.out}')
mmcv.dump(outputs, args.out) mmcv.dump(outputs, args.out)
kwargs = {} if args.options is None else args.options kwargs = {} if args.options is None else args.options
if args.format_only: if args.format_only:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment