Unverified commit 106b17e7, authored by Sun Jiahao, committed by GitHub

[Fix] Fix amp bug (#2452)

* fix amp

* init scale 4096 & fix link

* fix pre-commit

* fix interval

* fix pp & remove fp16
parent 19abb93a
_base_ = './pointpillars_hv_fpn_sbn-all_8xb4-2x_nus-3d.py'
train_dataloader = dict(batch_size=2, num_workers=2)
# schedule settings
-optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=512.)
+optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=4096.)

_base_ = './pointpillars_hv_secfpn_sbn-all_8xb4-2x_nus-3d.py'
train_dataloader = dict(batch_size=2, num_workers=2)
# schedule settings
-optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=512.)
+optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=4096.)

_base_ = 'second_hv_secfpn_8xb6-80e_kitti-3d-3class.py'
# schedule settings
-optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=512.)
+optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=4096.)

_base_ = 'second_hv_secfpn_8xb6-80e_kitti-3d-car.py'
# schedule settings
-optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=512.)
+optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=4096.)
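Background on the change above: with MMEngine's AmpOptimWrapper, loss_scale multiplies the loss before backward so that small fp16 gradients do not underflow, and the commit raises the static scale from 512 to 4096. A minimal config sketch of the two common settings (the optimizer line is illustrative only; in these configs it is inherited from the _base_ file):

# Fixed loss scale, as set in this commit.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    loss_scale=4096.,
    optimizer=dict(type='AdamW', lr=1e-3))  # placeholder optimizer

# Alternative: let the grad scaler adjust the scale at runtime.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    loss_scale='dynamic',
    optimizer=dict(type='AdamW', lr=1e-3))  # placeholder optimizer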
@@ -10,7 +10,6 @@ class BasePointNet(BaseModule, metaclass=ABCMeta):
    def __init__(self, init_cfg=None, pretrained=None):
        super(BasePointNet, self).__init__(init_cfg)
-       self.fp16_enabled = False
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
...
@@ -5,6 +5,8 @@ from typing import List, Tuple
import numpy as np
import torch
from mmdet.models.utils import multi_apply
+from mmdet.utils.memory import cast_tensor_type
+from mmengine.runner import amp
from torch import Tensor
from torch import nn as nn
@@ -92,7 +94,6 @@ class Anchor3DHead(Base3DDenseHead, AnchorTrainMixin):
        warnings.warn(
            'dir_offset and dir_limit_offset will be depressed and be '
            'incorporated into box coder in the future')
-       self.fp16_enabled = False
        # build anchor generator
        self.prior_generator = TASK_UTILS.build(anchor_generator)
@@ -112,7 +113,6 @@ class Anchor3DHead(Base3DDenseHead, AnchorTrainMixin):
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_dir = MODELS.build(loss_dir)
-       self.fp16_enabled = False
        self._init_layers()
        self._init_assigner_sampler()
@@ -411,11 +411,12 @@ class Anchor3DHead(Base3DDenseHead, AnchorTrainMixin):
            num_total_pos + num_total_neg if self.sampling else num_total_pos)
        # num_total_samples = None
+       with amp.autocast(enabled=False):
            losses_cls, losses_bbox, losses_dir = multi_apply(
                self._loss_by_feat_single,
-               cls_scores,
-               bbox_preds,
-               dir_cls_preds,
+               cast_tensor_type(cls_scores, dst_type=torch.float32),
+               cast_tensor_type(bbox_preds, dst_type=torch.float32),
+               cast_tensor_type(dir_cls_preds, dst_type=torch.float32),
                labels_list,
                label_weights_list,
                bbox_targets_list,
...
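This hunk is the substance of the amp fix for anchor heads: under autocast the head's predictions may be fp16, so they are cast back to float32 and the losses are computed with autocast disabled. A stripped-down sketch of the same pattern, with hypothetical helper names (the commit itself uses mmdet's cast_tensor_type, which also handles nested lists and dicts):

from mmengine.runner import amp

def loss_by_feat(self, cls_scores, bbox_preds, dir_cls_preds, *targets):
    with amp.autocast(enabled=False):  # leave the mixed-precision region
        cls_scores = [s.float() for s in cls_scores]  # fp16 -> fp32
        bbox_preds = [p.float() for p in bbox_preds]
        dir_cls_preds = [d.float() for d in dir_cls_preds]
        # hypothetical helper standing in for the multi_apply call above
        return self._compute_losses(cls_scores, bbox_preds, dir_cls_preds,
                                    *targets)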
@@ -176,7 +176,6 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
-       self.fp16_enabled = False
        self.background_label = (
            num_classes if background_label is None else background_label)
        # background_label should be either 0 or num_classes
...
@@ -317,7 +317,6 @@ class CenterHead(BaseModule):
        self.loss_bbox = MODELS.build(loss_bbox)
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.num_anchor_per_locs = [n for n in num_classes]
-       self.fp16_enabled = False
        # a shared convolution
        self.shared_conv = ConvModule(
...
@@ -214,8 +214,6 @@ class GroupFree3DHead(BaseModule):
        self.fps_module = Points_Sampler([self.num_proposal])
        self.points_obj_cls = PointsObjClsModule(self.in_channels)
-       self.fp16_enabled = False
-
        # initial candidate prediction
        self.conv_pred = BaseConvBboxHead(
            **pred_layer_cfg,
...
@@ -99,7 +99,6 @@ class VoteHead(BaseModule):
        self.vote_module = VoteModule(**vote_module_cfg)
        self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
-       self.fp16_enabled = False
        # Bbox classification and regression
        self.conv_pred = BaseConvBboxHead(
...
@@ -57,8 +57,6 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
        Returns:
            Tensor: Has shape (N, C) or (N, C, L), same shape as input.
        """
-       assert input.dtype == torch.float32, \
-           f'input should be in float32 type, got {input.dtype}'
        using_dist = dist.is_available() and dist.is_initialized()
        if (not using_dist) or dist.get_world_size() == 1 \
                or not self.training:
...
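Removing the assert above is what lets this sync-BN variant run under AMP: with autocast active, the preceding layer typically hands it an fp16 tensor, so a hard float32 check would fail on every mixed-precision step. A toy reproduction of the dtype situation (hypothetical layer, CUDA assumed):

import torch
from torch import nn

linear = nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device='cuda')
with torch.autocast('cuda', dtype=torch.float16):
    y = linear(x)
print(y.dtype)  # torch.float16 -- exactly what the removed assert rejected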
@@ -22,7 +22,6 @@ class PointPillarsScatter(nn.Module):
        self.ny = output_shape[0]
        self.nx = output_shape[1]
        self.in_channels = in_channels
-       self.fp16_enabled = False

    def forward(self, voxel_features, coors, batch_size=None):
        """Foraward function to scatter features."""
...
@@ -4,6 +4,7 @@ from typing import List, Tuple
import torch
from mmcv.ops import points_in_boxes_all, three_interpolate, three_nn
from mmdet.models.losses import sigmoid_focal_loss, smooth_l1_loss
+from mmengine.runner import amp
from torch import Tensor
from torch import nn as nn
@@ -68,7 +69,6 @@ class SparseEncoder(nn.Module):
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.stage_num = len(self.encoder_channels)
-       self.fp16_enabled = False
        self.return_middle_feats = return_middle_feats
        # Spconv init all weight on its own
@@ -111,6 +111,7 @@ class SparseEncoder(nn.Module):
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

+   @amp.autocast(enabled=False)
    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseEncoder.
...
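Rather than casting at every call site, the sparse encoder opts out of AMP wholesale: applied as a decorator, autocast(enabled=False) keeps the whole forward in full precision, presumably because the spconv kernels used here are not safe in fp16. A hedged sketch of the decorator pattern (class and layer are hypothetical stand-ins):

import torch
from torch import nn
from mmengine.runner import amp

class Fp32Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Linear(16, 16)  # stand-in for the sparse convs

    @amp.autocast(enabled=False)  # same decorator as in the diff
    def forward(self, x):
        return self.layers(x.float())  # upcast any fp16 input first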
@@ -61,7 +61,6 @@ class SparseUNet(BaseModule):
        self.decoder_channels = decoder_channels
        self.decoder_paddings = decoder_paddings
        self.stage_num = len(self.encoder_channels)
-       self.fp16_enabled = False
        # Spconv init all weight on its own
        assert isinstance(order, tuple) and len(order) == 3
...
@@ -38,7 +38,6 @@ class SECONDFPN(BaseModule):
        assert len(out_channels) == len(upsample_strides) == len(in_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
-       self.fp16_enabled = False

        deblocks = []
        for i, out_channel in enumerate(out_channels):
...
@@ -59,7 +59,6 @@ class PillarFeatureNet(nn.Module):
        self._with_distance = with_distance
        self._with_cluster_center = with_cluster_center
        self._with_voxel_center = with_voxel_center
-       self.fp16_enabled = False
        # Create PillarFeatureNet layers
        self.in_channels = in_channels
        feat_channels = [in_channels] + list(feat_channels)
@@ -209,7 +208,6 @@ class DynamicPillarFeatureNet(PillarFeatureNet):
            norm_cfg=norm_cfg,
            mode=mode,
            legacy=legacy)
-       self.fp16_enabled = False
        feat_channels = [self.in_channels] + list(feat_channels)
        pfn_layers = []
        # TODO: currently only support one PFNLayer
...
@@ -52,7 +52,6 @@ class VFELayer(nn.Module):
                 max_out=True,
                 cat_max=True):
        super(VFELayer, self).__init__()
-       self.fp16_enabled = False
        self.cat_max = cat_max
        self.max_out = max_out
        # self.units = int(out_channels / 2)
@@ -127,7 +126,6 @@ class PFNLayer(nn.Module):
                 mode='max'):
        super().__init__()
-       self.fp16_enabled = False
        self.name = 'PFNLayer'
        self.last_vfe = last_layer
        if not self.last_vfe:
...
@@ -23,7 +23,6 @@ class HardSimpleVFE(nn.Module):
    def __init__(self, num_features: int = 4) -> None:
        super(HardSimpleVFE, self).__init__()
        self.num_features = num_features
-       self.fp16_enabled = False

    def forward(self, features: Tensor, num_points: Tensor, coors: Tensor,
                *args, **kwargs) -> Tensor:
@@ -62,7 +61,6 @@ class DynamicSimpleVFE(nn.Module):
                 point_cloud_range=(0, -40, -3, 70.4, 40, 1)):
        super(DynamicSimpleVFE, self).__init__()
        self.scatter = DynamicScatter(voxel_size, point_cloud_range, True)
-       self.fp16_enabled = False

    @torch.no_grad()
    def forward(self, features, coors, *args, **kwargs):
@@ -141,7 +139,6 @@ class DynamicVFE(nn.Module):
        self._with_cluster_center = with_cluster_center
        self._with_voxel_center = with_voxel_center
        self.return_point_feats = return_point_feats
-       self.fp16_enabled = False
        # Need pillar (voxel) size and x/y offset in order to calculate offset
        self.vx = voxel_size[0]
@@ -340,7 +337,6 @@ class HardVFE(nn.Module):
        self._with_cluster_center = with_cluster_center
        self._with_voxel_center = with_voxel_center
        self.return_point_feats = return_point_feats
-       self.fp16_enabled = False
        # Need pillar (voxel) size and x/y offset to calculate pillar offset
        self.vx = voxel_size[0]
...
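All of the remaining hunks delete the same line. The fp16_enabled flag belonged to the pre-MMEngine mixed-precision machinery, where (in mmcv < 2.0) decorators such as auto_fp16 and force_fp32 only took effect once a hook flipped the attribute to True; with AMP now driven by AmpOptimWrapper and torch autocast, it is dead state, so the commit strips it everywhere. For contrast, the legacy pattern being retired (old-style mmcv API, reconstructed from memory):

from mmcv.runner import auto_fp16  # mmcv < 2.0 only
from torch import nn

class OldStyleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fp16_enabled = False  # flipped to True by the fp16 hook

    @auto_fp16(apply_to=('x', ))  # a no-op unless fp16_enabled is True
    def forward(self, x):
        return x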