Unverified Commit e67b3f81 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Support to train using FP16 (#132)

* Support to train using FP16

* fix type inconsistency error on naive syncBN

* resolve comments

* clean nan check
parent e4320fb4
# Mixed Precision Training
## Introduction
We implement mixed precision training and apply it to VoxelNets (e.g., SECOND and PointPillars).
The results are in the following tables.
**Note**: For mixed precision training, we currently do not support PointNet-based methods (e.g., VoteNet).
Mixed precision training for PointNet-based methods will be supported in a future release.
## Results
### SECOND on KITTI dataset
| Backbone |Class| Lr schd | FP32 Mem (GB) | FP16 Mem (GB) | FP32 mAP | FP16 mAP |Download |
| :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | :------: |
| [SECFPN](./hv_second_secfpn_fp16_6x8_80e_kitti-3d-car.py)| Car |cyclic 80e|5.4|2.9|79.07|78.72||
| [SECFPN](./hv_second_secfpn_fp16_6x8_80e_kitti-3d-3class.py)| 3 Class |cyclic 80e|5.4|2.9|64.41|67.4||
### PointPillars on nuScenes dataset
**Note**: With mixed precision training, we can train PointPillars with RegNet-400mf on 8 Titan XP GPUs with a batch size of 2.
Without mixed precision training, this setting would cause an out-of-memory (OOM) error.
_base_ = '../pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py'
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
fp16 = dict(loss_scale=512.)
_base_ = '../regnet/hv_pointpillars_regnet-400mf_fpn_sbn-all_4x8_2x_nus-3d.py'
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
fp16 = dict(loss_scale=512.)
_base_ = '../pointpillars/hv_pointpillars_secfpn_sbn-all_4x8_2x_nus-3d.py'
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
fp16 = dict(loss_scale=512.)
_base_ = '../second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py'
# fp16 settings
fp16 = dict(loss_scale=512.)
_base_ = '../second/hv_second_secfpn_6x8_80e_kitti-3d-car.py'
# fp16 settings
fp16 = dict(loss_scale=512.)
......@@ -8,6 +8,7 @@ class BasePointNet(nn.Module, metaclass=ABCMeta):
def __init__(self):
super(BasePointNet, self).__init__()
self.fp16_enabled = False
def init_weights(self, pretrained=None):
"""Initialize the weights of PointNet backbone."""
......
import copy
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint
from mmcv.runner import auto_fp16, load_checkpoint
from torch import nn as nn
from mmdet.models import BACKBONES, build_backbone
......@@ -86,6 +86,7 @@ class MultiBackbone(nn.Module):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
@auto_fp16()
def forward(self, points):
"""Forward pass.
......
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16
from torch import nn as nn
from mmdet3d.ops import build_sa_module
......@@ -111,6 +112,7 @@ class PointNet2SAMSG(BasePointNet):
bias=True))
sa_in_channel = aggregation_channels[sa_index]
@auto_fp16(apply_to=('points', ))
def forward(self, points):
"""Forward pass.
......
import torch
from mmcv.runner import auto_fp16
from torch import nn as nn
from mmdet3d.ops import PointFPModule, build_sa_module
......@@ -83,6 +84,7 @@ class PointNet2SASSG(BasePointNet):
fp_source_channel = cur_fp_mlps[-1]
fp_target_channel = skip_channel_list.pop()
@auto_fp16(apply_to=('points', ))
def forward(self, points):
"""Forward pass.
......
import numpy as np
import torch
from mmcv.cnn import bias_init_with_prob, normal_init
from mmcv.runner import force_fp32
from torch import nn as nn
from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period,
......@@ -79,6 +80,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
self.assign_per_class = assign_per_class
self.dir_offset = dir_offset
self.dir_limit_offset = dir_limit_offset
self.fp16_enabled = False
# build anchor generator
self.anchor_generator = build_anchor_generator(anchor_generator)
......@@ -211,38 +213,60 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
assert labels.max().item() <= self.num_classes
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(-1, self.box_code_size)
bbox_targets = bbox_targets.reshape(-1, self.box_code_size)
bbox_weights = bbox_weights.reshape(-1, self.box_code_size)
code_weight = self.train_cfg.get('code_weight', None)
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
& (labels < bg_class_ind)).nonzero().reshape(-1)
num_pos = len(pos_inds)
pos_bbox_pred = bbox_pred[pos_inds]
pos_bbox_targets = bbox_targets[pos_inds]
pos_bbox_weights = bbox_weights[pos_inds]
# dir loss
if self.use_direction_classifier:
dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).reshape(-1, 2)
dir_targets = dir_targets.reshape(-1)
dir_weights = dir_weights.reshape(-1)
pos_dir_cls_preds = dir_cls_preds[pos_inds]
pos_dir_targets = dir_targets[pos_inds]
pos_dir_weights = dir_weights[pos_inds]
if num_pos > 0:
code_weight = self.train_cfg.get('code_weight', None)
if code_weight:
bbox_weights = bbox_weights * bbox_weights.new_tensor(code_weight)
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(-1, self.box_code_size)
bbox_weights = bbox_weights * bbox_weights.new_tensor(
code_weight)
if self.diff_rad_by_sin:
bbox_pred, bbox_targets = self.add_sin_difference(
bbox_pred, bbox_targets)
pos_bbox_pred, pos_bbox_targets = self.add_sin_difference(
pos_bbox_pred, pos_bbox_targets)
loss_bbox = self.loss_bbox(
bbox_pred,
bbox_targets,
bbox_weights,
pos_bbox_pred,
pos_bbox_targets,
pos_bbox_weights,
avg_factor=num_total_samples)
# direction classification loss
loss_dir = None
if self.use_direction_classifier:
dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).reshape(-1, 2)
dir_targets = dir_targets.reshape(-1)
dir_weights = dir_weights.reshape(-1)
loss_dir = self.loss_dir(
dir_cls_preds,
dir_targets,
dir_weights,
pos_dir_cls_preds,
pos_dir_targets,
pos_dir_weights,
avg_factor=num_total_samples)
else:
loss_bbox = pos_bbox_pred.sum()
if self.use_direction_classifier:
loss_dir = pos_dir_cls_preds.sum()
return loss_cls, loss_bbox, loss_dir
......@@ -270,6 +294,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
dim=-1)
return boxes1, boxes2
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self,
cls_scores,
bbox_preds,
......
......@@ -2,6 +2,7 @@ import copy
import numpy as np
import torch
from mmcv.cnn import ConvModule, build_conv_layer, kaiming_init
from mmcv.runner import force_fp32
from torch import nn
from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
......@@ -228,7 +229,7 @@ class DCNSeperateHead(nn.Module):
return ret
@HEADS.register_module
@HEADS.register_module()
class CenterHead(nn.Module):
"""CenterHead for CenterPoint.
......@@ -292,6 +293,7 @@ class CenterHead(nn.Module):
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes]
self.fp16_enabled = False
# a shared convolution
self.shared_conv = ConvModule(
......@@ -548,6 +550,7 @@ class CenterHead(nn.Module):
inds.append(ind)
return heatmaps, anno_boxes, inds, masks
@force_fp32(apply_to=('preds_dicts'))
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead.
......
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet3d.core.bbox import bbox_overlaps_nearest_3d
......@@ -38,6 +39,7 @@ class FreeAnchor3DHead(Anchor3DHead):
self.gamma = gamma
self.alpha = alpha
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self,
cls_scores,
bbox_preds,
......
......@@ -2,6 +2,7 @@ from __future__ import division
import numpy as np
import torch
from mmcv.runner import force_fp32
from mmdet3d.core import limit_period, xywhr2xyxyr
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
......@@ -81,6 +82,7 @@ class PartA2RPNHead(Anchor3DHead):
diff_rad_by_sin, dir_offset, dir_limit_offset,
bbox_coder, loss_cls, loss_bbox, loss_dir)
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self,
cls_scores,
bbox_preds,
......
import torch
from mmcv.ops.nms import batched_nms
from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
......@@ -108,6 +109,7 @@ class SSD3DHead(VoteHead):
return seed_points, seed_features, seed_indices
@force_fp32(apply_to=('bbox_preds', ))
def loss(self,
bbox_preds,
points,
......
import numpy as np
import torch
from mmcv.runner import force_fp32
from torch import nn as nn
from torch.nn import functional as F
......@@ -78,6 +79,7 @@ class VoteHead(nn.Module):
self.vote_module = VoteModule(**vote_module_cfg)
self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
self.fp16_enabled = False
# Bbox classification and regression
self.conv_pred = BaseConvBboxHead(
......@@ -204,6 +206,7 @@ class VoteHead(nn.Module):
return results
@force_fp32(apply_to=('bbox_preds', ))
def loss(self,
bbox_preds,
points,
......
......@@ -2,6 +2,7 @@ import copy
import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import auto_fp16
from os import path as osp
from mmdet3d.core import Box3DMode, show_result
......@@ -42,6 +43,7 @@ class Base3DDetector(BaseDetector):
else:
return self.aug_test(points, img_metas, img, **kwargs)
@auto_fp16(apply_to=('img', 'points'))
def forward(self, return_loss=True, **kwargs):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
......
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet.models import DETECTORS
......@@ -44,6 +45,7 @@ class DynamicVoxelNet(VoxelNet):
return x
@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply dynamic voxelization to points.
......
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet.models import DETECTORS
......@@ -21,6 +22,7 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector):
super(DynamicMVXFasterRCNN, self).__init__(**kwargs)
@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply dynamic voxelization to points.
......
......@@ -2,6 +2,7 @@ import copy
import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import force_fp32
from os import path as osp
from torch import nn as nn
from torch.nn import functional as F
......@@ -203,6 +204,7 @@ class MVXTwoStageDetector(Base3DDetector):
return (img_feats, pts_feats)
@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply dynamic voxelization to points.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment