Unverified commit e67b3f81, authored by Wenwei Zhang and committed by GitHub

Support training with FP16 (#132)

* Support training with FP16

* Fix type inconsistency error in naive SyncBN

* Resolve review comments

* Clean up NaN check
parent e4320fb4
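None of the decorated modules below behave differently until mixed-precision training is actually switched on; every `fp16_enabled` flag added in this commit starts out `False`. A minimal sketch of how that switch is usually flipped at config level, assuming the standard mmcv `Fp16OptimizerHook` mechanism (the loss-scale value is illustrative, not something this commit prescribes):

# Config fragment (illustrative): enable FP16 training.
# The training loop reads this dict and builds mmcv's Fp16OptimizerHook;
# wrap_fp16_model() then converts the model to half precision and sets
# fp16_enabled = True on every module that defines the attribute, which is
# what activates the auto_fp16 / force_fp32 decorators added below.
fp16 = dict(loss_scale=512.)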
 import torch
+from mmcv.runner import force_fp32
 from torch.nn import functional as F

 from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
@@ -46,6 +47,7 @@ class VoxelNet(SingleStage3DDetector):
         return x

     @torch.no_grad()
+    @force_fp32()
     def voxelize(self, points):
         """Apply hard voxelization to points."""
         voxels, coors, num_points = [], [], []
...
 import torch
+from mmcv.runner import auto_fp16
 from torch import nn

 from ..registry import MIDDLE_ENCODERS
@@ -21,7 +22,9 @@ class PointPillarsScatter(nn.Module):
         self.ny = output_shape[0]
         self.nx = output_shape[1]
         self.in_channels = in_channels
+        self.fp16_enabled = False

+    @auto_fp16(apply_to=('voxel_features', ))
     def forward(self, voxel_features, coors, batch_size=None):
         """Forward function to scatter features."""
         # TODO: rewrite the function in a batch manner
...
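The change above is the pattern this commit applies throughout: the module advertises `fp16_enabled` and decorates `forward` with `auto_fp16`, so the listed tensor arguments are cast to half precision only when FP16 training is enabled. A self-contained sketch of that pattern (the module and its names are hypothetical, for illustration only, not part of this commit):

import torch
from mmcv.runner import auto_fp16
from torch import nn


class ToyScatter(nn.Module):
    """Hypothetical middle encoder used only to illustrate the FP16 hooks."""

    def __init__(self, in_channels=64):
        super().__init__()
        self.in_channels = in_channels
        # Read by mmcv; wrap_fp16_model() flips it to True for FP16 training.
        self.fp16_enabled = False

    @auto_fp16(apply_to=('voxel_features', ))
    def forward(self, voxel_features, coors):
        # With fp16_enabled=True, float `voxel_features` are cast to fp16
        # before this body runs; `coors` holds integer indices and is left
        # untouched by the decorator.
        return voxel_features.sum(dim=1), coors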
+from mmcv.runner import auto_fp16
 from torch import nn as nn

 from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
@@ -49,6 +50,7 @@ class SparseEncoder(nn.Module):
         self.encoder_channels = encoder_channels
         self.encoder_paddings = encoder_paddings
         self.stage_num = len(self.encoder_channels)
+        self.fp16_enabled = False
         # Spconv init all weight on its own
         assert isinstance(order, tuple) and len(order) == 3
@@ -90,6 +92,7 @@ class SparseEncoder(nn.Module):
             indice_key='spconv_down2',
             conv_type='SparseConv3d')

+    @auto_fp16(apply_to=('voxel_features', ))
     def forward(self, voxel_features, coors, batch_size):
         """Forward of SparseEncoder.
...
 import torch
+from mmcv.runner import auto_fp16
 from torch import nn as nn

 from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
@@ -51,6 +52,7 @@ class SparseUNet(nn.Module):
         self.decoder_channels = decoder_channels
         self.decoder_paddings = decoder_paddings
         self.stage_num = len(self.encoder_channels)
+        self.fp16_enabled = False
         # Spconv init all weight on its own
         assert isinstance(order, tuple) and len(order) == 3
@@ -91,6 +93,7 @@ class SparseUNet(nn.Module):
             indice_key='spconv_down2',
             conv_type='SparseConv3d')

+    @auto_fp16(apply_to=('voxel_features', ))
     def forward(self, voxel_features, coors, batch_size):
         """Forward of SparseUNet.
...
@@ -2,6 +2,7 @@ import numpy as np
 import torch
 from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
                       constant_init, is_norm, kaiming_init)
+from mmcv.runner import auto_fp16
 from torch import nn as nn

 from mmdet.models import NECKS
@@ -36,6 +37,7 @@ class SECONDFPN(nn.Module):
         assert len(out_channels) == len(upsample_strides) == len(in_channels)
         self.in_channels = in_channels
         self.out_channels = out_channels
+        self.fp16_enabled = False

         deblocks = []
         for i, out_channel in enumerate(out_channels):
@@ -70,6 +72,7 @@ class SECONDFPN(nn.Module):
             elif is_norm(m):
                 constant_init(m, 1)

+    @auto_fp16()
     def forward(self, x):
         """Forward function.
...
 import torch
 from mmcv.cnn import build_norm_layer
+from mmcv.runner import force_fp32
 from torch import nn

 from mmdet3d.ops import DynamicScatter
@@ -58,7 +59,7 @@ class PillarFeatureNet(nn.Module):
         self._with_distance = with_distance
         self._with_cluster_center = with_cluster_center
         self._with_voxel_center = with_voxel_center
+        self.fp16_enabled = False
         # Create PillarFeatureNet layers
         self.in_channels = in_channels
         feat_channels = [in_channels] + list(feat_channels)
@@ -86,6 +87,7 @@ class PillarFeatureNet(nn.Module):
         self.y_offset = self.vy / 2 + point_cloud_range[1]
         self.point_cloud_range = point_cloud_range

+    @force_fp32(out_fp16=True)
     def forward(self, features, num_points, coors):
         """Forward function.
@@ -196,7 +198,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
             point_cloud_range=point_cloud_range,
             norm_cfg=norm_cfg,
             mode=mode)
+        self.fp16_enabled = False
         feat_channels = [self.in_channels] + list(feat_channels)
         pfn_layers = []
         # TODO: currently only support one PFNLayer
@@ -257,6 +259,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
         center_per_point = canvas[:, voxel_index.long()].t()
         return center_per_point

+    @force_fp32(out_fp16=True)
     def forward(self, features, coors):
         """Forward function.
...
 import torch
 from mmcv.cnn import build_norm_layer
+from mmcv.runner import auto_fp16
 from torch import nn
 from torch.nn import functional as F
@@ -51,6 +52,7 @@ class VFELayer(nn.Module):
                  max_out=True,
                  cat_max=True):
         super(VFELayer, self).__init__()
+        self.fp16_enabled = False
         self.cat_max = cat_max
         self.max_out = max_out
         # self.units = int(out_channels / 2)
@@ -58,6 +60,7 @@ class VFELayer(nn.Module):
         self.norm = build_norm_layer(norm_cfg, out_channels)[1]
         self.linear = nn.Linear(in_channels, out_channels, bias=False)

+    @auto_fp16(apply_to=('inputs', ), out_fp32=True)
     def forward(self, inputs):
         """Forward function.
@@ -78,6 +81,7 @@ class VFELayer(nn.Module):
         """
         # [K, T, 7] tensordot [7, units] = [K, T, units]
         voxel_count = inputs.shape[1]
         x = self.linear(inputs)
         x = self.norm(x.permute(0, 2, 1).contiguous()).permute(0, 2,
                                                                 1).contiguous()
@@ -123,6 +127,7 @@ class PFNLayer(nn.Module):
                  mode='max'):
         super().__init__()
+        self.fp16_enabled = False
         self.name = 'PFNLayer'
         self.last_vfe = last_layer
         if not self.last_vfe:
@@ -135,6 +140,7 @@ class PFNLayer(nn.Module):
         assert mode in ['max', 'avg']
         self.mode = mode

+    @auto_fp16(apply_to=('inputs', ), out_fp32=True)
     def forward(self, inputs, num_voxels=None, aligned_distance=None):
         """Forward function.
...
 import torch
 from mmcv.cnn import build_norm_layer
+from mmcv.runner import force_fp32
 from torch import nn

 from mmdet3d.ops import DynamicScatter
@@ -21,7 +22,9 @@ class HardSimpleVFE(nn.Module):
     def __init__(self, num_features=4):
         super(HardSimpleVFE, self).__init__()
         self.num_features = num_features
+        self.fp16_enabled = False

+    @force_fp32(out_fp16=True)
     def forward(self, features, num_points, coors):
         """Forward function.
@@ -58,8 +61,10 @@ class DynamicSimpleVFE(nn.Module):
                  point_cloud_range=(0, -40, -3, 70.4, 40, 1)):
         super(DynamicSimpleVFE, self).__init__()
         self.scatter = DynamicScatter(voxel_size, point_cloud_range, True)
+        self.fp16_enabled = False

     @torch.no_grad()
+    @force_fp32(out_fp16=True)
     def forward(self, features, coors):
         """Forward function.
@@ -134,6 +139,7 @@ class DynamicVFE(nn.Module):
         self._with_cluster_center = with_cluster_center
         self._with_voxel_center = with_voxel_center
         self.return_point_feats = return_point_feats
+        self.fp16_enabled = False
         # Need pillar (voxel) size and x/y offset in order to calculate offset
         self.vx = voxel_size[0]
@@ -209,6 +215,7 @@ class DynamicVFE(nn.Module):
         center_per_point = voxel_mean[voxel_inds, ...]
         return center_per_point

+    @force_fp32(out_fp16=True)
     def forward(self,
                 features,
                 coors,
@@ -330,6 +337,7 @@ class HardVFE(nn.Module):
         self._with_cluster_center = with_cluster_center
         self._with_voxel_center = with_voxel_center
         self.return_point_feats = return_point_feats
+        self.fp16_enabled = False
         # Need pillar (voxel) size and x/y offset to calculate pillar offset
         self.vx = voxel_size[0]
@@ -372,6 +380,7 @@ class HardVFE(nn.Module):
         if fusion_layer is not None:
             self.fusion_layer = builder.build_fusion_layer(fusion_layer)

+    @force_fp32(out_fp16=True)
     def forward(self,
                 features,
                 num_points,
...
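The voxel feature encoders all take the opposite direction from the scatter and backbone modules: `force_fp32(out_fp16=True)` casts their inputs back to fp32 so the per-voxel reductions run in full precision, then returns fp16 features to the rest of the half-precision network. A minimal sketch of that contract, loosely modelled on `HardSimpleVFE` above (the class and argument names here are illustrative, not part of this commit):

import torch
from mmcv.runner import force_fp32
from torch import nn


class ToyMeanVFE(nn.Module):
    """Hypothetical voxel feature encoder, for illustration only."""

    def __init__(self, num_features=4):
        super().__init__()
        self.num_features = num_features
        self.fp16_enabled = False  # set to True when FP16 training is on

    @force_fp32(out_fp16=True)
    def forward(self, features, num_points):
        # features: (num_voxels, max_points, C). When fp16_enabled is True
        # the decorator casts it to fp32 first, so the per-voxel mean is
        # computed in full precision rather than fp16.
        points_mean = features[:, :, :self.num_features].sum(
            dim=1) / num_points.type_as(features).view(-1, 1)
        # out_fp16=True casts the returned tensor back to fp16, the dtype
        # the downstream half-precision middle encoder and backbone expect.
        return points_mean.contiguous()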
 import torch
+from mmcv.runner import force_fp32
 from torch import nn as nn
 from typing import List
@@ -59,7 +60,9 @@ class Points_Sampler(nn.Module):
         self.samplers = nn.ModuleList()
         for fps_mod in fps_mod_list:
             self.samplers.append(get_sampler_type(fps_mod)())
+        self.fp16_enabled = False

+    @force_fp32()
     def forward(self, points_xyz, features):
         """forward.
...
 import torch
 from mmcv.cnn import NORM_LAYERS
+from mmcv.runner import force_fp32
 from torch import distributed as dist
 from torch import nn as nn
 from torch.autograd.function import Function
@@ -42,10 +43,19 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
     It is slower than `nn.SyncBatchNorm`.
     """

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fp16_enabled = False
+
+    # customized normalization layer still needs this decorator
+    # to force the input to be fp32 and the output to be fp16
+    # TODO: make mmcv fp16 utils handle customized norm layers
+    @force_fp32(out_fp16=True)
     def forward(self, input):
+        assert input.dtype == torch.float32, \
+            f'input should be in float32 type, got {input.dtype}'
         if dist.get_world_size() == 1 or not self.training:
             return super().forward(input)
         assert input.shape[0] > 0, 'SyncBN does not support empty inputs'
         C = input.shape[1]
         mean = torch.mean(input, dim=[0, 2])
@@ -87,7 +97,17 @@ class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
     It is slower than `nn.SyncBatchNorm`.
     """

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fp16_enabled = False
+
+    # customized normalization layer still needs this decorator
+    # to force the input to be fp32 and the output to be fp16
+    # TODO: make mmcv fp16 utils handle customized norm layers
+    @force_fp32(out_fp16=True)
     def forward(self, input):
+        assert input.dtype == torch.float32, \
+            f'input should be in float32 type, got {input.dtype}'
         if dist.get_world_size() == 1 or not self.training:
             return super().forward(input)
...
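The two SyncBN variants need this hand-written opt-in because, as the TODO above notes, mmcv's fp16 utilities do not yet cover customised normalization layers. A minimal sketch of the same opt-in on a hypothetical custom norm layer (illustrative only; it keeps the dtype assert so a missing decorator fails loudly instead of silently running the statistics in fp16):

import torch
from mmcv.runner import force_fp32
from torch import nn


class ToyNorm(nn.BatchNorm1d):
    """Hypothetical customised norm layer, for illustration only."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fp16_enabled = False  # flipped to True by wrap_fp16_model()

    # Mirror the pattern above: compute statistics in fp32, return fp16.
    @force_fp32(out_fp16=True)
    def forward(self, input):
        assert input.dtype == torch.float32, \
            f'input should be in float32 type, got {input.dtype}'
        return super().forward(input)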
 import torch
 from mmcv.cnn import ConvModule
+from mmcv.runner import force_fp32
 from torch import nn as nn
 from typing import List
@@ -21,7 +22,7 @@ class PointFPModule(nn.Module):
                  mlp_channels: List[int],
                  norm_cfg: dict = dict(type='BN2d')):
         super().__init__()
+        self.fp16_enabled = False
         self.mlps = nn.Sequential()
         for i in range(len(mlp_channels) - 1):
             self.mlps.add_module(
@@ -34,6 +35,7 @@ class PointFPModule(nn.Module):
                     conv_cfg=dict(type='Conv2d'),
                     norm_cfg=norm_cfg))

+    @force_fp32()
     def forward(self, target: torch.Tensor, source: torch.Tensor,
                 target_feats: torch.Tensor,
                 source_feats: torch.Tensor) -> torch.Tensor:
...
@@ -145,7 +145,6 @@ class PointSAModuleMSG(nn.Module):
         """
         new_features_list = []
         xyz_flipped = points_xyz.transpose(1, 2).contiguous()
-
         if indices is not None:
             assert (indices.shape[1] == self.num_point[0])
             new_xyz = gather_points(xyz_flipped, indices).transpose(
...