Unverified Commit e67b3f81 authored by Wenwei Zhang, committed by GitHub

Support to train using FP16 (#132)

* Support to train using FP16

* fix type inconsistency error on naive syncBN

* resolve comments

* clean nan check
parent e4320fb4
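The file changes below all serve one goal: letting the detectors train under mmcv's mixed-precision machinery. As a point of reference, in mmdet/mmdet3d-style configs FP16 training is usually switched on with a one-line `fp16` field, which the training script turns into an `Fp16OptimizerHook`. The snippet below is a hedged illustration of that convention, not part of this commit:

```python
# Illustration only (not part of this diff): the usual mmdet-style way to
# enable mixed-precision training from a config file.
fp16 = dict(loss_scale=512.)

# The training script then typically does something along these lines
# (assumption based on mmdet's train loop, shown as comments only):
#   fp16_cfg = cfg.get('fp16', None)
#   if fp16_cfg is not None:
#       optimizer_config = Fp16OptimizerHook(
#           **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
```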
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
......@@ -46,6 +47,7 @@ class VoxelNet(SingleStage3DDetector):
return x
@torch.no_grad()
@force_fp32()
def voxelize(self, points):
"""Apply hard voxelization to points."""
voxels, coors, num_points = [], [], []
......
import torch
from mmcv.runner import auto_fp16
from torch import nn
from ..registry import MIDDLE_ENCODERS
......@@ -21,7 +22,9 @@ class PointPillarsScatter(nn.Module):
self.ny = output_shape[0]
self.nx = output_shape[1]
self.in_channels = in_channels
self.fp16_enabled = False
@auto_fp16(apply_to=('voxel_features', ))
def forward(self, voxel_features, coors, batch_size=None):
"""Foraward function to scatter features."""
# TODO: rewrite the function in a batch manner
......
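The pattern repeated throughout this diff is mmcv's FP16 convention: a module sets `self.fp16_enabled = False` in `__init__` and decorates `forward` with `@auto_fp16` (or `@force_fp32`); the decorator stays inert until `wrap_fp16_model` converts the model to half precision and flips the flag to `True`. A minimal, self-contained sketch of that interplay (the `ToyScatter` module is made up for illustration):

```python
import torch
from mmcv.runner import auto_fp16, wrap_fp16_model
from torch import nn


class ToyScatter(nn.Module):
    """Made-up stand-in for the modules touched by this commit."""

    def __init__(self):
        super().__init__()
        # auto_fp16 only casts inputs when this flag is True; it is set to
        # True for the whole model by wrap_fp16_model() before training.
        self.fp16_enabled = False
        self.linear = nn.Linear(4, 4)

    @auto_fp16(apply_to=('voxel_features', ))
    def forward(self, voxel_features):
        # Under FP16 training, voxel_features arrives here as torch.half.
        return self.linear(voxel_features)


model = ToyScatter()
wrap_fp16_model(model)  # model.half() + fp16_enabled = True on submodules
if torch.cuda.is_available():  # half-precision matmul is a GPU-side path
    model = model.cuda()
    out = model(torch.rand(2, 4, device='cuda'))  # fp32 input cast to fp16
    assert out.dtype == torch.half
```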
from mmcv.runner import auto_fp16
from torch import nn as nn
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
......@@ -49,6 +50,7 @@ class SparseEncoder(nn.Module):
self.encoder_channels = encoder_channels
self.encoder_paddings = encoder_paddings
self.stage_num = len(self.encoder_channels)
self.fp16_enabled = False
# Spconv init all weight on its own
assert isinstance(order, tuple) and len(order) == 3
......@@ -90,6 +92,7 @@ class SparseEncoder(nn.Module):
indice_key='spconv_down2',
conv_type='SparseConv3d')
@auto_fp16(apply_to=('voxel_features', ))
def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseEncoder.
......
import torch
from mmcv.runner import auto_fp16
from torch import nn as nn
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
......@@ -51,6 +52,7 @@ class SparseUNet(nn.Module):
self.decoder_channels = decoder_channels
self.decoder_paddings = decoder_paddings
self.stage_num = len(self.encoder_channels)
self.fp16_enabled = False
# Spconv init all weight on its own
assert isinstance(order, tuple) and len(order) == 3
......@@ -91,6 +93,7 @@ class SparseUNet(nn.Module):
indice_key='spconv_down2',
conv_type='SparseConv3d')
@auto_fp16(apply_to=('voxel_features', ))
def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseUNet.
......
......@@ -2,6 +2,7 @@ import numpy as np
import torch
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
constant_init, is_norm, kaiming_init)
from mmcv.runner import auto_fp16
from torch import nn as nn
from mmdet.models import NECKS
......@@ -36,6 +37,7 @@ class SECONDFPN(nn.Module):
assert len(out_channels) == len(upsample_strides) == len(in_channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.fp16_enabled = False
deblocks = []
for i, out_channel in enumerate(out_channels):
......@@ -70,6 +72,7 @@ class SECONDFPN(nn.Module):
elif is_norm(m):
constant_init(m, 1)
@auto_fp16()
def forward(self, x):
"""Forward function.
......
import torch
from mmcv.cnn import build_norm_layer
from mmcv.runner import force_fp32
from torch import nn
from mmdet3d.ops import DynamicScatter
......@@ -58,7 +59,7 @@ class PillarFeatureNet(nn.Module):
self._with_distance = with_distance
self._with_cluster_center = with_cluster_center
self._with_voxel_center = with_voxel_center
self.fp16_enabled = False
# Create PillarFeatureNet layers
self.in_channels = in_channels
feat_channels = [in_channels] + list(feat_channels)
......@@ -86,6 +87,7 @@ class PillarFeatureNet(nn.Module):
self.y_offset = self.vy / 2 + point_cloud_range[1]
self.point_cloud_range = point_cloud_range
@force_fp32(out_fp16=True)
def forward(self, features, num_points, coors):
"""Forward function.
......@@ -196,7 +198,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
point_cloud_range=point_cloud_range,
norm_cfg=norm_cfg,
mode=mode)
self.fp16_enabled = False
feat_channels = [self.in_channels] + list(feat_channels)
pfn_layers = []
# TODO: currently only support one PFNLayer
......@@ -257,6 +259,7 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
center_per_point = canvas[:, voxel_index.long()].t()
return center_per_point
@force_fp32(out_fp16=True)
def forward(self, features, coors):
"""Forward function.
......
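The voxel/pillar feature encoders go the other way: `@force_fp32(out_fp16=True)` casts their half-precision inputs up to fp32 so the cluster/center offsets and means keep full precision, then returns fp16 features to the half-precision backbone. Roughly, the decorator behaves like the simplified sketch below; mmcv's real implementation additionally resolves argument names via `apply_to` and handles keyword arguments:

```python
import functools

import torch


def force_fp32_out_fp16(forward):
    """Simplified illustration of @force_fp32(out_fp16=True), not mmcv's code."""

    @functools.wraps(forward)
    def wrapper(self, *args, **kwargs):
        if not getattr(self, 'fp16_enabled', False):
            # FP16 training not enabled: behave as a plain forward call.
            return forward(self, *args, **kwargs)
        # Cast fp16 inputs up to fp32 so the feature math stays accurate.
        args = tuple(
            a.float() if torch.is_tensor(a) and a.is_floating_point() else a
            for a in args)
        out = forward(self, *args, **kwargs)
        # Hand fp16 features back to the half-precision layers downstream.
        return out.half() if torch.is_tensor(out) else out

    return wrapper
```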
import torch
from mmcv.cnn import build_norm_layer
from mmcv.runner import auto_fp16
from torch import nn
from torch.nn import functional as F
......@@ -51,6 +52,7 @@ class VFELayer(nn.Module):
max_out=True,
cat_max=True):
super(VFELayer, self).__init__()
self.fp16_enabled = False
self.cat_max = cat_max
self.max_out = max_out
# self.units = int(out_channels / 2)
......@@ -58,6 +60,7 @@ class VFELayer(nn.Module):
self.norm = build_norm_layer(norm_cfg, out_channels)[1]
self.linear = nn.Linear(in_channels, out_channels, bias=False)
@auto_fp16(apply_to=('inputs', ), out_fp32=True)
def forward(self, inputs):
"""Forward function.
......@@ -78,6 +81,7 @@ class VFELayer(nn.Module):
"""
# [K, T, 7] tensordot [7, units] = [K, T, units]
voxel_count = inputs.shape[1]
x = self.linear(inputs)
x = self.norm(x.permute(0, 2, 1).contiguous()).permute(0, 2,
1).contiguous()
......@@ -123,6 +127,7 @@ class PFNLayer(nn.Module):
mode='max'):
super().__init__()
self.fp16_enabled = False
self.name = 'PFNLayer'
self.last_vfe = last_layer
if not self.last_vfe:
......@@ -135,6 +140,7 @@ class PFNLayer(nn.Module):
assert mode in ['max', 'avg']
self.mode = mode
@auto_fp16(apply_to=('inputs', ), out_fp32=True)
def forward(self, inputs, num_voxels=None, aligned_distance=None):
"""Forward function.
......
import torch
from mmcv.cnn import build_norm_layer
from mmcv.runner import force_fp32
from torch import nn
from mmdet3d.ops import DynamicScatter
......@@ -21,7 +22,9 @@ class HardSimpleVFE(nn.Module):
def __init__(self, num_features=4):
super(HardSimpleVFE, self).__init__()
self.num_features = num_features
self.fp16_enabled = False
@force_fp32(out_fp16=True)
def forward(self, features, num_points, coors):
"""Forward function.
......@@ -58,8 +61,10 @@ class DynamicSimpleVFE(nn.Module):
point_cloud_range=(0, -40, -3, 70.4, 40, 1)):
super(DynamicSimpleVFE, self).__init__()
self.scatter = DynamicScatter(voxel_size, point_cloud_range, True)
self.fp16_enabled = False
@torch.no_grad()
@force_fp32(out_fp16=True)
def forward(self, features, coors):
"""Forward function.
......@@ -134,6 +139,7 @@ class DynamicVFE(nn.Module):
self._with_cluster_center = with_cluster_center
self._with_voxel_center = with_voxel_center
self.return_point_feats = return_point_feats
self.fp16_enabled = False
# Need pillar (voxel) size and x/y offset in order to calculate offset
self.vx = voxel_size[0]
......@@ -209,6 +215,7 @@ class DynamicVFE(nn.Module):
center_per_point = voxel_mean[voxel_inds, ...]
return center_per_point
@force_fp32(out_fp16=True)
def forward(self,
features,
coors,
......@@ -330,6 +337,7 @@ class HardVFE(nn.Module):
self._with_cluster_center = with_cluster_center
self._with_voxel_center = with_voxel_center
self.return_point_feats = return_point_feats
self.fp16_enabled = False
# Need pillar (voxel) size and x/y offset to calculate pillar offset
self.vx = voxel_size[0]
......@@ -372,6 +380,7 @@ class HardVFE(nn.Module):
if fusion_layer is not None:
self.fusion_layer = builder.build_fusion_layer(fusion_layer)
@force_fp32(out_fp16=True)
def forward(self,
features,
num_points,
......
import torch
from mmcv.runner import force_fp32
from torch import nn as nn
from typing import List
......@@ -59,7 +60,9 @@ class Points_Sampler(nn.Module):
self.samplers = nn.ModuleList()
for fps_mod in fps_mod_list:
self.samplers.append(get_sampler_type(fps_mod)())
self.fp16_enabled = False
@force_fp32()
def forward(self, points_xyz, features):
"""forward.
......
import torch
from mmcv.cnn import NORM_LAYERS
from mmcv.runner import force_fp32
from torch import distributed as dist
from torch import nn as nn
from torch.autograd.function import Function
......@@ -42,10 +43,19 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
It is slower than `nn.SyncBatchNorm`.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fp16_enabled = False
# customized normalization layer still needs this decorator
# to force the input to be fp32 and the output to be fp16
# TODO: make mmcv fp16 utils handle customized norm layers
@force_fp32(out_fp16=True)
def forward(self, input):
assert input.dtype == torch.float32, \
f'input should be in float32 type, got {input.dtype}'
if dist.get_world_size() == 1 or not self.training:
return super().forward(input)
assert input.shape[0] > 0, 'SyncBN does not support empty inputs'
C = input.shape[1]
mean = torch.mean(input, dim=[0, 2])
......@@ -87,7 +97,17 @@ class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
It is slower than `nn.SyncBatchNorm`.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fp16_enabled = False
# customized normalization layer still needs this decorator
# to force the input to be fp32 and the output to be fp16
# TODO: make mmcv fp16 utils handle customized norm layers
@force_fp32(out_fp16=True)
def forward(self, input):
assert input.dtype == torch.float32, \
f'input should be in float32 type, got {input.dtype}'
if dist.get_world_size() == 1 or not self.training:
return super().forward(input)
......
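As the TODO above notes, mmcv's FP16 utilities do not yet special-case these customized sync BN layers, so the decorator is applied by hand: under FP16 training the half-precision activations are cast to fp32, the all-reduced batch statistics are computed in full precision, and the normalized result is returned as fp16. A hedged usage sketch follows; the registry alias and import path are assumptions based on this repository's naming and may differ:

```python
from mmcv.cnn import build_norm_layer

import mmdet3d.ops  # noqa: F401  (assumed: importing registers the custom norm layers)

# Assumed registered alias for NaiveSyncBatchNorm1d in this repo.
norm_cfg = dict(type='naiveSyncBN1d', eps=1e-3, momentum=0.01)
name, norm_layer = build_norm_layer(norm_cfg, 64)
# During FP16 training, @force_fp32(out_fp16=True) on its forward() keeps the
# synced mean/var computation in fp32 and casts the output back to fp16.
```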
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32
from torch import nn as nn
from typing import List
......@@ -21,7 +22,7 @@ class PointFPModule(nn.Module):
mlp_channels: List[int],
norm_cfg: dict = dict(type='BN2d')):
super().__init__()
self.fp16_enabled = False
self.mlps = nn.Sequential()
for i in range(len(mlp_channels) - 1):
self.mlps.add_module(
......@@ -34,6 +35,7 @@ class PointFPModule(nn.Module):
conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg))
@force_fp32()
def forward(self, target: torch.Tensor, source: torch.Tensor,
target_feats: torch.Tensor,
source_feats: torch.Tensor) -> torch.Tensor:
......
......@@ -145,7 +145,6 @@ class PointSAModuleMSG(nn.Module):
"""
new_features_list = []
xyz_flipped = points_xyz.transpose(1, 2).contiguous()
if indices is not None:
assert (indices.shape[1] == self.num_point[0])
new_xyz = gather_points(xyz_flipped, indices).transpose(
......