Commit 4040dbda authored by zhangwenwei's avatar zhangwenwei
Browse files

Refactor anchor generator and box coder

parent 148fea12
......@@ -8,8 +8,8 @@ import torch
import torch.utils.data as torch_data
from mmdet.datasets import DATASETS
from mmdet.datasets.pipelines import Compose
from ..core.bbox import box_np_ops
from .pipelines import Compose
from .utils import remove_dontcare
......
from pycocotools.coco import COCO
from mmdet3d.core.evaluation.coco_utils import getImgIds
from mmdet.datasets import DATASETS, CocoDataset
@DATASETS.register_module
class NuScenes2DDataset(CocoDataset):
    """COCO-style 2D (image-only) dataset for nuScenes."""

    CLASSES = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle',
               'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone',
               'barrier')

    def load_annotations(self, ann_file):
        """Load per-image info dicts from a COCO-format annotation file.

        Args:
            ann_file (str): Path to the COCO annotation json.

        Returns:
            list[dict]: One info dict per image; a ``filename`` key is
                added, copied from COCO's ``file_name``.
        """
        # Fall back to the full class tuple when no subset was requested.
        if not self.class_names:
            self.class_names = self.CLASSES
        self.coco = COCO(ann_file)
        # Restrict category ids to the requested classes so training on a
        # subset of CLASSES is possible; by default class_names == CLASSES.
        self.cat_ids = self.coco.getCatIds(catNms=self.class_names)
        # No +1 shift here: labels 0..N-1 are foreground, N is background.
        self.cat2label = {
            cat_id: label
            for label, cat_id in enumerate(self.cat_ids)
        }
        if len(self.cat_ids) < len(self.CLASSES):
            # Only keep images containing at least one of the chosen classes.
            self.img_ids = getImgIds(self.coco, catIds=self.cat_ids)
        else:
            self.img_ids = self.coco.getImgIds()

        def _to_info(img_id):
            # loadImgs returns a one-element list for a single id.
            info = self.coco.loadImgs([img_id])[0]
            info['filename'] = info['file_name']
            return info

        return [_to_info(img_id) for img_id in self.img_ids]
......@@ -9,8 +9,8 @@ import torch.utils.data as torch_data
from nuscenes.utils.data_classes import Box as NuScenesBox
from mmdet.datasets import DATASETS
from mmdet.datasets.pipelines import Compose
from ..core.bbox import box_np_ops
from .pipelines import Compose
@DATASETS.register_module
......
from mmdet.datasets.pipelines import Compose
from .dbsampler import DataBaseSampler, MMDataBaseSampler
from .formating import DefaultFormatBundle, DefaultFormatBundle3D
from .loading import LoadMultiViewImageFromFiles, LoadPointsFromFile
from .train_aug import (GlobalRotScale, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D)
# NOTE: the previous revision was missing the trailing comma after
# 'Collect3D', which implicitly concatenated it with 'PointShuffle' and
# duplicated entries; keep exactly one comma-terminated entry per name.
__all__ = [
    'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScale',
    'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
    'Compose', 'LoadMultiViewImageFromFiles', 'LoadPointsFromFile',
    'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
    'MMDataBaseSampler'
]
......@@ -68,7 +68,7 @@ class DataBaseSampler(object):
db_infos = pickle.load(f)
# filter database infos
from mmdet3d.apis import get_root_logger
from mmdet.apis import get_root_logger
logger = get_root_logger()
for k, v in db_infos.items():
logger.info(f'load {len(v)} {k} database infos')
......
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import to_tensor
from mmdet.datasets.registry import PIPELINES
PIPELINES._module_dict.pop('DefaultFormatBundle')
......
......@@ -3,7 +3,7 @@ import os.path as osp
import mmcv
import numpy as np
from mmdet.datasets.registry import PIPELINES
from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module
......
import numpy as np
from mmcv.utils import build_from_cfg
from mmdet3d.core.bbox import box_np_ops
from mmdet3d.utils import build_from_cfg
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import RandomFlip
from mmdet.datasets.registry import PIPELINES
from ..registry import OBJECTSAMPLERS
from .data_augment_utils import noise_per_object_v3_
......
from mmdet.utils import Registry
from mmcv.utils import Registry
OBJECTSAMPLERS = Registry('Object sampler')
import numpy as np
import torch
from mmcv.cnn import normal_init
from mmcv.cnn import bias_init_with_prob, normal_init
from mmdet3d.core import box_torch_ops, boxes3d_to_bev_torch_lidar
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
from mmdet.models import HEADS
from ..utils import bias_init_with_prob
from .second_head import SECONDHead
......@@ -31,25 +30,25 @@ class Anchor3DVeloHead(SECONDHead):
in_channels,
train_cfg,
test_cfg,
cache_anchor=False,
feat_channels=256,
use_direction_classifier=True,
encode_bg_as_zeros=False,
box_code_size=9,
anchor_generator=dict(type='AnchorGeneratorRange', ),
anchor_range=[0, -39.68, -1.78, 69.12, 39.68, -1.78],
anchor_strides=[2],
anchor_sizes=[[1.6, 3.9, 1.56]],
anchor_rotations=[0, 1.57],
anchor_custom_values=[0, 0],
anchor_generator=dict(
type='Anchor3DRangeGenerator',
range=[0, -39.68, -1.78, 69.12, 39.68, -1.78],
strides=[2],
sizes=[[1.6, 3.9, 1.56]],
rotations=[0, 1.57],
custom_values=[0, 0],
reshape_out=True,
),
assigner_per_size=False,
assign_per_class=False,
diff_rad_by_sin=True,
dir_offset=0,
dir_limit_offset=1,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
bbox_coder=dict(type='Residual3DBoxCoder', ),
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
......@@ -58,14 +57,11 @@ class Anchor3DVeloHead(SECONDHead):
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2)):
super().__init__(class_names, in_channels, train_cfg, test_cfg,
cache_anchor, feat_channels, use_direction_classifier,
feat_channels, use_direction_classifier,
encode_bg_as_zeros, box_code_size, anchor_generator,
anchor_range, anchor_strides, anchor_sizes,
anchor_rotations, anchor_custom_values,
assigner_per_size, assign_per_class, diff_rad_by_sin,
dir_offset, dir_limit_offset, target_means,
target_stds, bbox_coder, loss_cls, loss_bbox,
loss_dir)
dir_offset, dir_limit_offset, bbox_coder, loss_cls,
loss_bbox, loss_dir)
self.num_classes = num_classes
# build head layers & losses
if not self.use_sigmoid_cls:
......@@ -131,7 +127,7 @@ class Anchor3DVeloHead(SECONDHead):
scores = scores[topk_inds, :]
dir_cls_score = dir_cls_score[topk_inds]
bboxes = self.bbox_coder.decode_torch(anchors, bbox_pred,
bboxes = self.bbox_coder.decode(anchors, bbox_pred,
self.target_means,
self.target_stds)
mlvl_bboxes.append(bboxes)
......
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmcv.cnn import bias_init_with_prob, normal_init
from mmdet3d.core import (PseudoSampler, box_torch_ops,
boxes3d_to_bev_torch_lidar, build_anchor_generator,
......@@ -12,7 +10,6 @@ from mmdet3d.core import (PseudoSampler, box_torch_ops,
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
from mmdet.models import HEADS
from ..builder import build_loss
from ..utils import bias_init_with_prob
from .train_mixins import AnchorTrainMixin
......@@ -37,25 +34,24 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
in_channels,
train_cfg,
test_cfg,
cache_anchor=False,
feat_channels=256,
use_direction_classifier=True,
encode_bg_as_zeros=False,
box_code_size=7,
anchor_generator=dict(type='AnchorGeneratorRange'),
anchor_range=[0, -39.68, -1.78, 69.12, 39.68, -1.78],
anchor_strides=[2],
anchor_sizes=[[1.6, 3.9, 1.56]],
anchor_rotations=[0, 1.57],
anchor_custom_values=[],
anchor_generator=dict(
type='Anchor3DRangeGenerator',
range=[0, -39.68, -1.78, 69.12, 39.68, -1.78],
strides=[2],
sizes=[[1.6, 3.9, 1.56]],
rotations=[0, 1.57],
custom_values=[],
reshape_out=False),
assigner_per_size=False,
assign_per_class=False,
diff_rad_by_sin=True,
dir_offset=0,
dir_limit_offset=1,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
bbox_coder=dict(type='Residual3DBoxCoder'),
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
......@@ -94,29 +90,9 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
]
# build anchor generator
self.anchor_range = anchor_range
self.anchor_rotations = anchor_rotations
self.anchor_strides = anchor_strides
self.anchor_sizes = anchor_sizes
self.target_means = target_means
self.target_stds = target_stds
self.anchor_generators = []
self.anchor_generator = build_anchor_generator(anchor_generator)
# In 3D detection, the anchor stride is connected with anchor size
self.num_anchors = (
len(self.anchor_rotations) * len(self.anchor_sizes))
# if len(self.anchor_sizes) != self.anchor_strides:
# # this means different anchor in the same anchor strides
# anchor_sizes = [self.anchor_sizes]
for anchor_stride in self.anchor_strides:
anchor_generator.update(
anchor_ranges=anchor_range,
sizes=self.anchor_sizes,
stride=anchor_stride,
rotations=anchor_rotations,
custom_values=anchor_custom_values,
cache_anchor=cache_anchor)
self.anchor_generators.append(
build_anchor_generator(anchor_generator))
self.num_anchors = self.anchor_generator.num_base_anchors
self._init_layers()
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
......@@ -152,7 +128,7 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def get_anchors(self, featmap_sizes, input_metas):
def get_anchors(self, featmap_sizes, input_metas, device='cuda'):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
......@@ -161,16 +137,10 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
tuple: anchors of each image, valid flags of each image
"""
num_imgs = len(input_metas)
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = []
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(featmap_sizes[i])
if not self.assigner_per_size:
anchors = anchors.reshape(-1, anchors.size(-1))
multi_level_anchors.append(anchors)
multi_level_anchors = self.anchor_generator.grid_anchors(
featmap_sizes, device=device)
anchor_list = [multi_level_anchors for _ in range(num_imgs)]
return anchor_list
......@@ -237,9 +207,10 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
input_metas,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
anchor_list = self.get_anchors(featmap_sizes, input_metas)
assert len(featmap_sizes) == self.anchor_generator.num_levels
device = cls_scores[0].device
anchor_list = self.get_anchors(
featmap_sizes, input_metas, device=device)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = self.anchor_target_3d(
anchor_list,
......@@ -288,12 +259,14 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
assert len(cls_scores) == len(bbox_preds)
assert len(cls_scores) == len(dir_cls_preds)
num_levels = len(cls_scores)
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
device = cls_scores[0].device
mlvl_anchors = self.anchor_generators.grid_anchors(
featmap_sizes, device=device)
mlvl_anchors = [
self.anchor_generators[i].grid_anchors(
cls_scores[i].size()[-2:]).reshape(-1, self.box_code_size)
for i in range(num_levels)
anchor.reshape(-1, self.box_code_size) for anchor in mlvl_anchors
]
result_list = []
for img_id in range(len(input_metas)):
cls_score_list = [
......@@ -353,9 +326,7 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
bbox_pred = bbox_pred[thr_inds]
scores = scores[thr_inds]
dir_cls_scores = dir_cls_score[thr_inds]
bboxes = self.bbox_coder.decode_torch(anchors, bbox_pred,
self.target_means,
self.target_stds)
bboxes = self.bbox_coder.decode(anchors, bbox_pred)
bboxes_for_nms = boxes3d_to_bev_torch_lidar(bboxes)
mlvl_bboxes_for_nms.append(bboxes_for_nms)
mlvl_bboxes.append(bboxes)
......@@ -383,6 +354,7 @@ class SECONDHead(nn.Module, AnchorTrainMixin):
selected_scores = mlvl_scores[selected]
selected_label_preds = mlvl_label_preds[selected]
selected_dir_scores = mlvl_dir_scores[selected]
# TODO: move dir_offset to box coder
dir_rot = box_torch_ops.limit_period(
selected_bboxes[..., -1] - self.dir_offset,
self.dir_limit_offset, np.pi)
......
......@@ -197,9 +197,8 @@ class AnchorTrainMixin(object):
if gt_labels is not None:
labels += num_classes
if len(pos_inds) > 0:
pos_bbox_targets = self.bbox_coder.encode_torch(
sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes,
target_means, target_stds)
pos_bbox_targets = self.bbox_coder.encode(
sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
pos_dir_targets = get_direction_target(
sampling_result.pos_bboxes,
pos_bbox_targets,
......
from mmdet.models.builder import build
from mmdet.models.registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
ROI_EXTRACTORS, SHARED_HEADS)
from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
ROI_EXTRACTORS, SHARED_HEADS, build)
from .registry import FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS
......
from mmdet.utils import Registry
from mmcv.utils import Registry
VOXEL_ENCODERS = Registry('voxel_encoder')
MIDDLE_ENCODERS = Registry('middle_encoder')
......
from mmdet.models.utils import ResLayer, bias_init_with_prob
__all__ = ['bias_init_with_prob', 'ResLayer']
import numpy as np
import torch.nn as nn
def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Initialize ``module.weight`` with Xavier init and fill its bias.

    Args:
        module (nn.Module): Module exposing ``weight`` (and maybe ``bias``).
        gain (float): Scaling factor for the Xavier initialization.
        bias (float): Constant value to fill the bias with.
        distribution (str): Either 'uniform' or 'normal'.
    """
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.xavier_uniform_(module.weight, gain=gain)
    else:
        nn.init.xavier_normal_(module.weight, gain=gain)
    # Modules created with bias=False still have a `bias` attribute set to
    # None; guard against it before writing a constant into it.
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
    """Initialize ``module.weight`` from N(mean, std) and fill its bias.

    Args:
        module (nn.Module): Module exposing ``weight`` (and maybe ``bias``).
        mean (float): Mean of the normal distribution.
        std (float): Standard deviation of the normal distribution.
        bias (float): Constant value to fill the bias with.
    """
    nn.init.normal_(module.weight, mean, std)
    # Skip modules built with bias=False, whose `bias` attribute is None.
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
def uniform_init(module, a=0, b=1, bias=0):
    """Initialize ``module.weight`` from U(a, b) and fill its bias.

    Args:
        module (nn.Module): Module exposing ``weight`` (and maybe ``bias``).
        a (float): Lower bound of the uniform distribution.
        b (float): Upper bound of the uniform distribution.
        bias (float): Constant value to fill the bias with.
    """
    nn.init.uniform_(module.weight, a, b)
    # Skip modules built with bias=False, whose `bias` attribute is None.
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
def kaiming_init(module,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Initialize ``module.weight`` with Kaiming init and fill its bias.

    Args:
        module (nn.Module): Module exposing ``weight`` (and maybe ``bias``).
        mode (str): 'fan_in' or 'fan_out', forwarded to ``nn.init``.
        nonlinearity (str): Nonlinearity name used to compute the gain.
        bias (float): Constant value to fill the bias with.
        distribution (str): Either 'uniform' or 'normal'.
    """
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.kaiming_uniform_(
            module.weight, mode=mode, nonlinearity=nonlinearity)
    else:
        nn.init.kaiming_normal_(
            module.weight, mode=mode, nonlinearity=nonlinearity)
    # Skip modules built with bias=False, whose `bias` attribute is None.
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
def bias_init_with_prob(prior_prob):
    """Return the bias value whose sigmoid equals ``prior_prob``.

    Solves sigmoid(b) = prior_prob for b, i.e. b = log(p / (1 - p)),
    so a classifier initialized with this bias predicts the given prior.
    """
    odds = prior_prob / (1 - prior_prob)
    return float(np.log(odds))
from mmdet.utils import (Registry, build_from_cfg, get_model_complexity_info,
get_root_logger, print_log)
from mmcv.utils import Registry, build_from_cfg
from mmdet.utils import get_model_complexity_info, get_root_logger, print_log
from .collect_env import collect_env
__all__ = [
......
"""
CommandLine:
pytest tests/test_anchor.py
xdoctest tests/test_anchor.py zero
"""
import torch
def test_aligned_anchor_generator():
    """Regression test for ``AlignedAnchor3DRangeGenerator``.

    Builds the generator from a config (4 sizes x 2 rotations over a
    3-level feature pyramid) and checks the base-anchor count, per-level
    grid-anchor shapes, and a strided sample of the anchor values against
    hand-computed expectations.
    """
    from mmdet3d.core.anchor import build_anchor_generator
    # Run on GPU when available so the device round-trip is covered too.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    anchor_generator_cfg = dict(
        type='AlignedAnchor3DRangeGenerator',
        ranges=[[-51.2, -51.2, -1.80, 51.2, 51.2, -1.80]],
        strides=[1, 2, 4],
        sizes=[
            [0.8660, 2.5981, 1.],  # 1.5/sqrt(3)
            [0.5774, 1.7321, 1.],  # 1/sqrt(3)
            [1., 1., 1.],
            [0.4, 0.4, 1],
        ],
        custom_values=[0, 0],
        rotations=[0, 1.57],
        size_per_range=False,
        reshape_out=True)
    featmap_sizes = [(256, 256), (128, 128), (64, 64)]
    anchor_generator = build_anchor_generator(anchor_generator_cfg)
    # 4 sizes * 2 rotations = 8 base anchors per grid location.
    assert anchor_generator.num_base_anchors == 8
    # check base anchors
    # Each row is (x, y, z, w, l, h, rot, custom0, custom1); one expected
    # tensor per pyramid level holding 8 sampled anchors (see the strided
    # indexing at the bottom of this test).
    expected_grid_anchors = [
        torch.tensor(
            [[-51.0000, -51.0000, -1.8000, 0.8660, 2.5981, 1.0000, 0.0000,
              0.0000, 0.0000],
             [-51.0000, -51.0000, -1.8000, 0.4000, 0.4000, 1.0000, 1.5700,
              0.0000, 0.0000],
             [-50.6000, -51.0000, -1.8000, 0.4000, 0.4000, 1.0000, 0.0000,
              0.0000, 0.0000],
             [-50.2000, -51.0000, -1.8000, 1.0000, 1.0000, 1.0000, 1.5700,
              0.0000, 0.0000],
             [-49.8000, -51.0000, -1.8000, 1.0000, 1.0000, 1.0000, 0.0000,
              0.0000, 0.0000],
             [-49.4000, -51.0000, -1.8000, 0.5774, 1.7321, 1.0000, 1.5700,
              0.0000, 0.0000],
             [-49.0000, -51.0000, -1.8000, 0.5774, 1.7321, 1.0000, 0.0000,
              0.0000, 0.0000],
             [-48.6000, -51.0000, -1.8000, 0.8660, 2.5981, 1.0000, 1.5700,
              0.0000, 0.0000]],
            device=device),
        torch.tensor(
            [[-50.8000, -50.8000, -1.8000, 1.7320, 5.1962, 2.0000, 0.0000,
              0.0000, 0.0000],
             [-50.8000, -50.8000, -1.8000, 0.8000, 0.8000, 2.0000, 1.5700,
              0.0000, 0.0000],
             [-50.0000, -50.8000, -1.8000, 0.8000, 0.8000, 2.0000, 0.0000,
              0.0000, 0.0000],
             [-49.2000, -50.8000, -1.8000, 2.0000, 2.0000, 2.0000, 1.5700,
              0.0000, 0.0000],
             [-48.4000, -50.8000, -1.8000, 2.0000, 2.0000, 2.0000, 0.0000,
              0.0000, 0.0000],
             [-47.6000, -50.8000, -1.8000, 1.1548, 3.4642, 2.0000, 1.5700,
              0.0000, 0.0000],
             [-46.8000, -50.8000, -1.8000, 1.1548, 3.4642, 2.0000, 0.0000,
              0.0000, 0.0000],
             [-46.0000, -50.8000, -1.8000, 1.7320, 5.1962, 2.0000, 1.5700,
              0.0000, 0.0000]],
            device=device),
        torch.tensor(
            [[-50.4000, -50.4000, -1.8000, 3.4640, 10.3924, 4.0000, 0.0000,
              0.0000, 0.0000],
             [-50.4000, -50.4000, -1.8000, 1.6000, 1.6000, 4.0000, 1.5700,
              0.0000, 0.0000],
             [-48.8000, -50.4000, -1.8000, 1.6000, 1.6000, 4.0000, 0.0000,
              0.0000, 0.0000],
             [-47.2000, -50.4000, -1.8000, 4.0000, 4.0000, 4.0000, 1.5700,
              0.0000, 0.0000],
             [-45.6000, -50.4000, -1.8000, 4.0000, 4.0000, 4.0000, 0.0000,
              0.0000, 0.0000],
             [-44.0000, -50.4000, -1.8000, 2.3096, 6.9284, 4.0000, 1.5700,
              0.0000, 0.0000],
             [-42.4000, -50.4000, -1.8000, 2.3096, 6.9284, 4.0000, 0.0000,
              0.0000, 0.0000],
             [-40.8000, -50.4000, -1.8000, 3.4640, 10.3924, 4.0000, 1.5700,
              0.0000, 0.0000]],
            device=device)
    ]
    multi_level_anchors = anchor_generator.grid_anchors(
        featmap_sizes, device=device)
    # 256*256*8, 128*128*8 and 64*64*8 anchors, 9 values per anchor.
    expected_multi_level_shapes = [
        torch.Size([524288, 9]),
        torch.Size([131072, 9]),
        torch.Size([32768, 9])
    ]
    for i, single_level_anchor in enumerate(multi_level_anchors):
        assert single_level_anchor.shape == expected_multi_level_shapes[i]
        # set [:56:7] thus it could cover 8 (len(size) * len(rotations))
        # anchors on 8 location
        assert single_level_anchor[:56:7].allclose(expected_grid_anchors[i])
......@@ -70,6 +70,34 @@ def test_config_build_detector():
# _check_bbox_head(head_config, detector.bbox_head)
def test_config_build_pipeline():
    """
    Test that the train/test pipelines defined in the configs can be built.
    """
    from mmcv import Config
    from mmdet3d.datasets.pipelines import Compose
    config_dpath = _get_config_directory()
    print('Found config_dpath = {!r}'.format(config_dpath))

    import glob
    # NOTE: '**' only recurses with recursive=True; without it the glob
    # silently matched a single directory level and skipped nested configs.
    config_fpaths = list(
        glob.glob(join(config_dpath, '**', '*.py'), recursive=True))
    # _base_ files are partial configs and cannot be built on their own.
    config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
    config_names = [relpath(p, config_dpath) for p in config_fpaths]
    print('Using {} config files'.format(len(config_names)))

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        config_mod = Config.fromfile(config_fpath)
        # Building the pipelines instantiates every transform in the config.
        train_pipeline = Compose(config_mod.train_pipeline)
        test_pipeline = Compose(config_mod.test_pipeline)
        assert train_pipeline is not None
        assert test_pipeline is not None
def test_config_data_pipeline():
"""
Test whether the data pipeline is valid and can process corner cases.
......@@ -77,7 +105,7 @@ def test_config_data_pipeline():
xdoctest -m tests/test_config.py test_config_build_data_pipeline
"""
from mmcv import Config
from mmdet.datasets.pipelines import Compose
from mmdet3d.datasets.pipelines import Compose
import numpy as np
config_dpath = _get_config_directory()
......
......@@ -27,12 +27,18 @@ def parse_args():
'--validate',
action='store_true',
help='whether to evaluate the checkpoint during training')
parser.add_argument(
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
type=int,
default=1,
help='number of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='ids of gpus to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument(
'--deterministic',
......@@ -73,11 +79,14 @@ def main():
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.gpus = args.gpus
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids
else:
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
if args.autoscale_lr:
# apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8
cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment