Unverified Commit e67b3f81 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Support to train using FP16 (#132)

* Support to train using FP16

* fix type inconsistency error on naive syncBN

* resolve comments

* clean nan check
parent e4320fb4
# Mixed Precision Training
## Introduction
We implement mixed precision training and apply it to VoxelNets (e.g., SECOND and PointPillars).
The results are in the following tables.
**Note**: For mixed precision training, we currently do not support PointNet-based methods (e.g., VoteNet).
Mixed precision training for PointNet-based methods will be supported in the future release.
## Results
### SECOND on KITTI dataset
| Backbone |Class| Lr schd | FP32 Mem (GB) | FP16 Mem (GB) | FP32 mAP | FP16 mAP |Download |
| :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | :------: |
| [SECFPN](./hv_second_secfpn_fp16_6x8_80e_kitti-3d-car.py)| Car |cyclic 80e|5.4|2.9|79.07|78.72||
| [SECFPN](./hv_second_secfpn_fp16_6x8_80e_kitti-3d-3class.py)| 3 Class |cyclic 80e|5.4|2.9|64.41|67.4||
### PointPillars on nuScenes dataset
**Note**: With mixed precision training, we can train PointPillars with RegNet-400mf on 8 Titan XP GPUs with a batch size of 2.
Without mixed precision training, this configuration would cause an out-of-memory (OOM) error.
# FP16 (mixed precision) config for PointPillars with an FPN neck on nuScenes.
# Inherits everything from the FP32 baseline and only adds the fp16 setting.
_base_ = '../pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py'
# 2 samples and 2 data-loading workers per GPU.
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
# Static loss scaling: gradients are multiplied by 512 before the backward
# pass to avoid FP16 underflow, then unscaled before the optimizer step.
fp16 = dict(loss_scale=512.)
# FP16 (mixed precision) config for PointPillars with a RegNet-400mf backbone
# on nuScenes. Inherits the FP32 baseline and only adds the fp16 setting.
_base_ = '../regnet/hv_pointpillars_regnet-400mf_fpn_sbn-all_4x8_2x_nus-3d.py'
# 2 samples and 2 data-loading workers per GPU.
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
# Static loss scale of 512 to prevent FP16 gradient underflow.
fp16 = dict(loss_scale=512.)
# FP16 (mixed precision) config for PointPillars with a SECOND-style FPN neck
# on nuScenes. Inherits the FP32 baseline and only adds the fp16 setting.
_base_ = '../pointpillars/hv_pointpillars_secfpn_sbn-all_4x8_2x_nus-3d.py'
# 2 samples and 2 data-loading workers per GPU.
data = dict(samples_per_gpu=2, workers_per_gpu=2)
# fp16 settings
# Static loss scale of 512 to prevent FP16 gradient underflow.
fp16 = dict(loss_scale=512.)
# FP16 (mixed precision) config for SECOND on KITTI, 3-class setting.
# Inherits the FP32 baseline and only adds the fp16 setting.
_base_ = '../second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py'
# fp16 settings
# Static loss scale of 512 to prevent FP16 gradient underflow.
fp16 = dict(loss_scale=512.)
# FP16 (mixed precision) config for SECOND on KITTI, car-only setting.
# Inherits the FP32 baseline and only adds the fp16 setting.
_base_ = '../second/hv_second_secfpn_6x8_80e_kitti-3d-car.py'
# fp16 settings
# Static loss scale of 512 to prevent FP16 gradient underflow.
fp16 = dict(loss_scale=512.)
...@@ -8,6 +8,7 @@ class BasePointNet(nn.Module, metaclass=ABCMeta): ...@@ -8,6 +8,7 @@ class BasePointNet(nn.Module, metaclass=ABCMeta):
def __init__(self): def __init__(self):
super(BasePointNet, self).__init__() super(BasePointNet, self).__init__()
self.fp16_enabled = False
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
"""Initialize the weights of PointNet backbone.""" """Initialize the weights of PointNet backbone."""
......
import copy import copy
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint from mmcv.runner import auto_fp16, load_checkpoint
from torch import nn as nn from torch import nn as nn
from mmdet.models import BACKBONES, build_backbone from mmdet.models import BACKBONES, build_backbone
...@@ -86,6 +86,7 @@ class MultiBackbone(nn.Module): ...@@ -86,6 +86,7 @@ class MultiBackbone(nn.Module):
logger = get_root_logger() logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger) load_checkpoint(self, pretrained, strict=False, logger=logger)
@auto_fp16()
def forward(self, points): def forward(self, points):
"""Forward pass. """Forward pass.
......
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import build_sa_module from mmdet3d.ops import build_sa_module
...@@ -111,6 +112,7 @@ class PointNet2SAMSG(BasePointNet): ...@@ -111,6 +112,7 @@ class PointNet2SAMSG(BasePointNet):
bias=True)) bias=True))
sa_in_channel = aggregation_channels[sa_index] sa_in_channel = aggregation_channels[sa_index]
@auto_fp16(apply_to=('points', ))
def forward(self, points): def forward(self, points):
"""Forward pass. """Forward pass.
......
import torch import torch
from mmcv.runner import auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PointFPModule, build_sa_module from mmdet3d.ops import PointFPModule, build_sa_module
...@@ -83,6 +84,7 @@ class PointNet2SASSG(BasePointNet): ...@@ -83,6 +84,7 @@ class PointNet2SASSG(BasePointNet):
fp_source_channel = cur_fp_mlps[-1] fp_source_channel = cur_fp_mlps[-1]
fp_target_channel = skip_channel_list.pop() fp_target_channel = skip_channel_list.pop()
@auto_fp16(apply_to=('points', ))
def forward(self, points): def forward(self, points):
"""Forward pass. """Forward pass.
......
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import bias_init_with_prob, normal_init from mmcv.cnn import bias_init_with_prob, normal_init
from mmcv.runner import force_fp32
from torch import nn as nn from torch import nn as nn
from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period, from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period,
...@@ -79,6 +80,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -79,6 +80,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
self.assign_per_class = assign_per_class self.assign_per_class = assign_per_class
self.dir_offset = dir_offset self.dir_offset = dir_offset
self.dir_limit_offset = dir_limit_offset self.dir_limit_offset = dir_limit_offset
self.fp16_enabled = False
# build anchor generator # build anchor generator
self.anchor_generator = build_anchor_generator(anchor_generator) self.anchor_generator = build_anchor_generator(anchor_generator)
...@@ -211,39 +213,61 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -211,39 +213,61 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
labels = labels.reshape(-1) labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1) label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes) cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
assert labels.max().item() <= self.num_classes
loss_cls = self.loss_cls( loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples) cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss # regression loss
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(-1, self.box_code_size)
bbox_targets = bbox_targets.reshape(-1, self.box_code_size) bbox_targets = bbox_targets.reshape(-1, self.box_code_size)
bbox_weights = bbox_weights.reshape(-1, self.box_code_size) bbox_weights = bbox_weights.reshape(-1, self.box_code_size)
code_weight = self.train_cfg.get('code_weight', None)
if code_weight: bg_class_ind = self.num_classes
bbox_weights = bbox_weights * bbox_weights.new_tensor(code_weight) pos_inds = ((labels >= 0)
bbox_pred = bbox_pred.permute(0, 2, 3, & (labels < bg_class_ind)).nonzero().reshape(-1)
1).reshape(-1, self.box_code_size) num_pos = len(pos_inds)
if self.diff_rad_by_sin:
bbox_pred, bbox_targets = self.add_sin_difference( pos_bbox_pred = bbox_pred[pos_inds]
bbox_pred, bbox_targets) pos_bbox_targets = bbox_targets[pos_inds]
loss_bbox = self.loss_bbox( pos_bbox_weights = bbox_weights[pos_inds]
bbox_pred,
bbox_targets, # dir loss
bbox_weights,
avg_factor=num_total_samples)
# direction classification loss
loss_dir = None
if self.use_direction_classifier: if self.use_direction_classifier:
dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).reshape(-1, 2) dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).reshape(-1, 2)
dir_targets = dir_targets.reshape(-1) dir_targets = dir_targets.reshape(-1)
dir_weights = dir_weights.reshape(-1) dir_weights = dir_weights.reshape(-1)
loss_dir = self.loss_dir( pos_dir_cls_preds = dir_cls_preds[pos_inds]
dir_cls_preds, pos_dir_targets = dir_targets[pos_inds]
dir_targets, pos_dir_weights = dir_weights[pos_inds]
dir_weights,
if num_pos > 0:
code_weight = self.train_cfg.get('code_weight', None)
if code_weight:
bbox_weights = bbox_weights * bbox_weights.new_tensor(
code_weight)
if self.diff_rad_by_sin:
pos_bbox_pred, pos_bbox_targets = self.add_sin_difference(
pos_bbox_pred, pos_bbox_targets)
loss_bbox = self.loss_bbox(
pos_bbox_pred,
pos_bbox_targets,
pos_bbox_weights,
avg_factor=num_total_samples) avg_factor=num_total_samples)
# direction classification loss
loss_dir = None
if self.use_direction_classifier:
loss_dir = self.loss_dir(
pos_dir_cls_preds,
pos_dir_targets,
pos_dir_weights,
avg_factor=num_total_samples)
else:
loss_bbox = pos_bbox_pred.sum()
if self.use_direction_classifier:
loss_dir = pos_dir_cls_preds.sum()
return loss_cls, loss_bbox, loss_dir return loss_cls, loss_bbox, loss_dir
@staticmethod @staticmethod
...@@ -270,6 +294,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin): ...@@ -270,6 +294,7 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
dim=-1) dim=-1)
return boxes1, boxes2 return boxes1, boxes2
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self, def loss(self,
cls_scores, cls_scores,
bbox_preds, bbox_preds,
......
...@@ -2,6 +2,7 @@ import copy ...@@ -2,6 +2,7 @@ import copy
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import ConvModule, build_conv_layer, kaiming_init from mmcv.cnn import ConvModule, build_conv_layer, kaiming_init
from mmcv.runner import force_fp32
from torch import nn from torch import nn
from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius, from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
...@@ -228,7 +229,7 @@ class DCNSeperateHead(nn.Module): ...@@ -228,7 +229,7 @@ class DCNSeperateHead(nn.Module):
return ret return ret
@HEADS.register_module @HEADS.register_module()
class CenterHead(nn.Module): class CenterHead(nn.Module):
"""CenterHead for CenterPoint. """CenterHead for CenterPoint.
...@@ -292,6 +293,7 @@ class CenterHead(nn.Module): ...@@ -292,6 +293,7 @@ class CenterHead(nn.Module):
self.loss_bbox = build_loss(loss_bbox) self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder) self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes] self.num_anchor_per_locs = [n for n in num_classes]
self.fp16_enabled = False
# a shared convolution # a shared convolution
self.shared_conv = ConvModule( self.shared_conv = ConvModule(
...@@ -548,6 +550,7 @@ class CenterHead(nn.Module): ...@@ -548,6 +550,7 @@ class CenterHead(nn.Module):
inds.append(ind) inds.append(ind)
return heatmaps, anno_boxes, inds, masks return heatmaps, anno_boxes, inds, masks
@force_fp32(apply_to=('preds_dicts'))
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs): def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead. """Loss function for CenterHead.
......
import torch import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core.bbox import bbox_overlaps_nearest_3d from mmdet3d.core.bbox import bbox_overlaps_nearest_3d
...@@ -38,6 +39,7 @@ class FreeAnchor3DHead(Anchor3DHead): ...@@ -38,6 +39,7 @@ class FreeAnchor3DHead(Anchor3DHead):
self.gamma = gamma self.gamma = gamma
self.alpha = alpha self.alpha = alpha
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self, def loss(self,
cls_scores, cls_scores,
bbox_preds, bbox_preds,
......
...@@ -2,6 +2,7 @@ from __future__ import division ...@@ -2,6 +2,7 @@ from __future__ import division
import numpy as np import numpy as np
import torch import torch
from mmcv.runner import force_fp32
from mmdet3d.core import limit_period, xywhr2xyxyr from mmdet3d.core import limit_period, xywhr2xyxyr
from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
...@@ -81,6 +82,7 @@ class PartA2RPNHead(Anchor3DHead): ...@@ -81,6 +82,7 @@ class PartA2RPNHead(Anchor3DHead):
diff_rad_by_sin, dir_offset, dir_limit_offset, diff_rad_by_sin, dir_offset, dir_limit_offset,
bbox_coder, loss_cls, loss_bbox, loss_dir) bbox_coder, loss_cls, loss_bbox, loss_dir)
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self, def loss(self,
cls_scores, cls_scores,
bbox_preds, bbox_preds,
......
import torch import torch
from mmcv.ops.nms import batched_nms from mmcv.ops.nms import batched_nms
from mmcv.runner import force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes, from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
...@@ -108,6 +109,7 @@ class SSD3DHead(VoteHead): ...@@ -108,6 +109,7 @@ class SSD3DHead(VoteHead):
return seed_points, seed_features, seed_indices return seed_points, seed_features, seed_indices
@force_fp32(apply_to=('bbox_preds', ))
def loss(self, def loss(self,
bbox_preds, bbox_preds,
points, points,
......
import numpy as np import numpy as np
import torch import torch
from mmcv.runner import force_fp32
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
...@@ -78,6 +79,7 @@ class VoteHead(nn.Module): ...@@ -78,6 +79,7 @@ class VoteHead(nn.Module):
self.vote_module = VoteModule(**vote_module_cfg) self.vote_module = VoteModule(**vote_module_cfg)
self.vote_aggregation = build_sa_module(vote_aggregation_cfg) self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
self.fp16_enabled = False
# Bbox classification and regression # Bbox classification and regression
self.conv_pred = BaseConvBboxHead( self.conv_pred = BaseConvBboxHead(
...@@ -204,6 +206,7 @@ class VoteHead(nn.Module): ...@@ -204,6 +206,7 @@ class VoteHead(nn.Module):
return results return results
@force_fp32(apply_to=('bbox_preds', ))
def loss(self, def loss(self,
bbox_preds, bbox_preds,
points, points,
......
...@@ -2,6 +2,7 @@ import copy ...@@ -2,6 +2,7 @@ import copy
import mmcv import mmcv
import torch import torch
from mmcv.parallel import DataContainer as DC from mmcv.parallel import DataContainer as DC
from mmcv.runner import auto_fp16
from os import path as osp from os import path as osp
from mmdet3d.core import Box3DMode, show_result from mmdet3d.core import Box3DMode, show_result
...@@ -42,6 +43,7 @@ class Base3DDetector(BaseDetector): ...@@ -42,6 +43,7 @@ class Base3DDetector(BaseDetector):
else: else:
return self.aug_test(points, img_metas, img, **kwargs) return self.aug_test(points, img_metas, img, **kwargs)
@auto_fp16(apply_to=('img', 'points'))
def forward(self, return_loss=True, **kwargs): def forward(self, return_loss=True, **kwargs):
"""Calls either forward_train or forward_test depending on whether """Calls either forward_train or forward_test depending on whether
return_loss=True. return_loss=True.
......
import torch import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from mmdet.models import DETECTORS from mmdet.models import DETECTORS
...@@ -44,6 +45,7 @@ class DynamicVoxelNet(VoxelNet): ...@@ -44,6 +45,7 @@ class DynamicVoxelNet(VoxelNet):
return x return x
@torch.no_grad() @torch.no_grad()
@force_fp32()
def voxelize(self, points): def voxelize(self, points):
"""Apply dynamic voxelization to points. """Apply dynamic voxelization to points.
......
import torch import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from mmdet.models import DETECTORS from mmdet.models import DETECTORS
...@@ -21,6 +22,7 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector): ...@@ -21,6 +22,7 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector):
super(DynamicMVXFasterRCNN, self).__init__(**kwargs) super(DynamicMVXFasterRCNN, self).__init__(**kwargs)
@torch.no_grad() @torch.no_grad()
@force_fp32()
def voxelize(self, points): def voxelize(self, points):
"""Apply dynamic voxelization to points. """Apply dynamic voxelization to points.
......
...@@ -2,6 +2,7 @@ import copy ...@@ -2,6 +2,7 @@ import copy
import mmcv import mmcv
import torch import torch
from mmcv.parallel import DataContainer as DC from mmcv.parallel import DataContainer as DC
from mmcv.runner import force_fp32
from os import path as osp from os import path as osp
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
...@@ -203,6 +204,7 @@ class MVXTwoStageDetector(Base3DDetector): ...@@ -203,6 +204,7 @@ class MVXTwoStageDetector(Base3DDetector):
return (img_feats, pts_feats) return (img_feats, pts_feats)
@torch.no_grad() @torch.no_grad()
@force_fp32()
def voxelize(self, points): def voxelize(self, points):
"""Apply dynamic voxelization to points. """Apply dynamic voxelization to points.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment