"vscode:/vscode.git/clone" did not exist on "cce49ba92a37c417425a81016dc607e4ea2036cf"
Commit 4dc18496 authored by chenshi3

Add support for TransFusion-Lidar Head

parent ad9c25c0
@@ -46,7 +46,7 @@ class BaseBEVBackbone(nn.Module):
            self.blocks.append(nn.Sequential(*cur_layers))
            if len(upsample_strides) > 0:
                stride = upsample_strides[idx]
-                if stride >= 1:
+                if stride > 1 or (stride == 1 and not self.model_cfg.get('USE_CONV_FOR_NO_STRIDE', False)):
                    self.deblocks.append(nn.Sequential(
                        nn.ConvTranspose2d(
                            num_filters[idx], num_upsample_filters[idx],
...
@@ -30,10 +30,11 @@ def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stri
class SparseBasicBlock(spconv.SparseModule):
    expansion = 1

-    def __init__(self, inplanes, planes, stride=1, norm_fn=None, downsample=None, indice_key=None):
+    def __init__(self, inplanes, planes, stride=1, bias=None, norm_fn=None, downsample=None, indice_key=None):
        super(SparseBasicBlock, self).__init__()

        assert norm_fn is not None
-        bias = norm_fn is not None
+        if bias is None:
+            bias = norm_fn is not None
        self.conv1 = spconv.SubMConv3d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key
...
@@ -184,6 +185,7 @@ class VoxelResBackBone8x(nn.Module):
    def __init__(self, model_cfg, input_channels, grid_size, **kwargs):
        super().__init__()
        self.model_cfg = model_cfg
+        use_bias = self.model_cfg.get('USE_BIAS', None)
        norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)

        self.sparse_shape = grid_size[::-1] + [1, 0, 0]
...
@@ -196,29 +198,29 @@ class VoxelResBackBone8x(nn.Module):
        block = post_act_block

        self.conv1 = spconv.SparseSequential(
-            SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='res1'),
-            SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='res1'),
+            SparseBasicBlock(16, 16, bias=use_bias, norm_fn=norm_fn, indice_key='res1'),
+            SparseBasicBlock(16, 16, bias=use_bias, norm_fn=norm_fn, indice_key='res1'),
        )

        self.conv2 = spconv.SparseSequential(
            # [1600, 1408, 41] <- [800, 704, 21]
            block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'),
-            SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res2'),
-            SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res2'),
+            SparseBasicBlock(32, 32, bias=use_bias, norm_fn=norm_fn, indice_key='res2'),
+            SparseBasicBlock(32, 32, bias=use_bias, norm_fn=norm_fn, indice_key='res2'),
        )

        self.conv3 = spconv.SparseSequential(
            # [800, 704, 21] <- [400, 352, 11]
            block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'),
-            SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res3'),
-            SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res3'),
+            SparseBasicBlock(64, 64, bias=use_bias, norm_fn=norm_fn, indice_key='res3'),
+            SparseBasicBlock(64, 64, bias=use_bias, norm_fn=norm_fn, indice_key='res3'),
        )

        self.conv4 = spconv.SparseSequential(
            # [400, 352, 11] <- [200, 176, 5]
            block(64, 128, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'),
-            SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res4'),
-            SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res4'),
+            SparseBasicBlock(128, 128, bias=use_bias, norm_fn=norm_fn, indice_key='res4'),
+            SparseBasicBlock(128, 128, bias=use_bias, norm_fn=norm_fn, indice_key='res4'),
        )

        last_pad = 0
...
import torch
from scipy.optimize import linear_sum_assignment
from pcdet.ops.iou3d_nms import iou3d_nms_cuda
def height_overlaps(boxes1, boxes2):
"""
Calculate height overlaps of two boxes.
"""
boxes1_top_height = (boxes1[:,2]+ boxes1[:,5]).view(-1, 1)
boxes1_bottom_height = boxes1[:,2].view(-1, 1)
boxes2_top_height = (boxes2[:,2]+boxes2[:,5]).view(1, -1)
boxes2_bottom_height = boxes2[:,2].view(1, -1)
    highest_of_bottom = torch.max(boxes1_bottom_height, boxes2_bottom_height)
    lowest_of_top = torch.min(boxes1_top_height, boxes2_top_height)
    overlaps_h = torch.clamp(lowest_of_top - highest_of_bottom, min=0)
return overlaps_h
def overlaps(boxes1, boxes2):
"""
Calculate 3D overlaps of two boxes.
"""
rows = len(boxes1)
cols = len(boxes2)
if rows * cols == 0:
return boxes1.new(rows, cols)
# height overlap
overlaps_h = height_overlaps(boxes1, boxes2)
boxes1_bev = boxes1[:,:7]
boxes2_bev = boxes2[:,:7]
# bev overlap
overlaps_bev = boxes1_bev.new_zeros(
(boxes1_bev.shape[0], boxes2_bev.shape[0])
).cuda() # (N, M)
iou3d_nms_cuda.boxes_overlap_bev_gpu(
boxes1_bev.contiguous().cuda(), boxes2_bev.contiguous().cuda(), overlaps_bev
)
# 3d overlaps
overlaps_3d = overlaps_bev.to(boxes1.device) * overlaps_h
volume1 = (boxes1[:, 3] * boxes1[:, 4] * boxes1[:, 5]).view(-1, 1)
volume2 = (boxes2[:, 3] * boxes2[:, 4] * boxes2[:, 5]).view(1, -1)
iou3d = overlaps_3d / torch.clamp(volume1 + volume2 - overlaps_3d, min=1e-8)
return iou3d
class HungarianAssigner3D:
def __init__(self, cls_cost, reg_cost, iou_cost):
self.cls_cost = cls_cost
self.reg_cost = reg_cost
self.iou_cost = iou_cost
def focal_loss_cost(self, cls_pred, gt_labels):
weight = self.cls_cost.get('weight', 0.15)
alpha = self.cls_cost.get('alpha', 0.25)
gamma = self.cls_cost.get('gamma', 2.0)
eps = self.cls_cost.get('eps', 1e-12)
cls_pred = cls_pred.sigmoid()
neg_cost = -(1 - cls_pred + eps).log() * (
1 - alpha) * cls_pred.pow(gamma)
pos_cost = -(cls_pred + eps).log() * alpha * (
1 - cls_pred).pow(gamma)
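        # cost of assigning each prediction to each GT class: the positive-sample
        # focal term it would incur minus the negative-sample term it would avoid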
cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
return cls_cost * weight
def bevbox_cost(self, bboxes, gt_bboxes, point_cloud_range):
weight = self.reg_cost.get('weight', 0.25)
pc_start = bboxes.new(point_cloud_range[0:2])
pc_range = bboxes.new(point_cloud_range[3:5]) - bboxes.new(point_cloud_range[0:2])
# normalize the box center to [0, 1]
normalized_bboxes_xy = (bboxes[:, :2] - pc_start) / pc_range
normalized_gt_bboxes_xy = (gt_bboxes[:, :2] - pc_start) / pc_range
reg_cost = torch.cdist(normalized_bboxes_xy, normalized_gt_bboxes_xy, p=1)
return reg_cost * weight
def iou3d_cost(self, bboxes, gt_bboxes):
iou = overlaps(bboxes, gt_bboxes)
weight = self.iou_cost.get('weight', 0.25)
iou_cost = - iou
return iou_cost * weight, iou
def assign(self, bboxes, gt_bboxes, gt_labels, cls_pred, point_cloud_range):
num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
# 1. assign -1 by default
assigned_gt_inds = bboxes.new_full((num_bboxes,), -1, dtype=torch.long)
assigned_labels = bboxes.new_full((num_bboxes,), -1, dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # no ground truth or no boxes: return an empty assignment
            max_overlaps = bboxes.new_zeros((num_bboxes,))
            if num_gts == 0:
                # no ground truth, assign everything to background
                assigned_gt_inds[:] = 0
            return assigned_gt_inds, max_overlaps
# 2. compute the weighted costs
cls_cost = self.focal_loss_cost(cls_pred[0].T, gt_labels)
reg_cost = self.bevbox_cost(bboxes, gt_bboxes, point_cloud_range)
iou_cost, iou = self.iou3d_cost(bboxes, gt_bboxes)
# weighted sum of above three costs
cost = cls_cost + reg_cost + iou_cost
# 3. do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu()
matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
matched_row_inds = torch.from_numpy(matched_row_inds).to(bboxes.device)
matched_col_inds = torch.from_numpy(matched_col_inds).to(bboxes.device)
# 4. assign backgrounds and foregrounds
# assign all indices to backgrounds first
assigned_gt_inds[:] = 0
# assign foregrounds based on matching results
assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
max_overlaps = torch.zeros_like(iou.max(1).values)
max_overlaps[matched_row_inds] = iou[matched_row_inds, matched_col_inds]
return assigned_gt_inds, max_overlaps
\ No newline at end of file
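For orientation, the matching in step 3 is a standard rectangular assignment problem; the snippet below is an editorial sketch on a toy cost matrix (the values are illustrative and not taken from this commit):

import torch
from scipy.optimize import linear_sum_assignment

# toy cost matrix: 4 proposals (rows) x 2 ground-truth boxes (columns),
# i.e. the sum cls_cost + reg_cost + iou_cost computed above
cost = torch.tensor([[0.9, 0.2],
                     [0.1, 0.8],
                     [0.7, 0.6],
                     [0.5, 0.4]])
row_inds, col_inds = linear_sum_assignment(cost.numpy())
print(row_inds, col_inds)  # [0 1] [1 0]: proposal 0 -> gt 1, proposal 1 -> gt 0
# as in assign(), matched proposals get foreground index col + 1, the rest stay background (0)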
import copy
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import kaiming_normal_
from ..model_utils.transfusion_utils import clip_sigmoid
from ..model_utils.basic_block_2d import BasicBlock2D
from ..model_utils.transfusion_utils import PositionEmbeddingLearned, TransformerDecoderLayer
from .target_assigner.hungarian_assigner import HungarianAssigner3D
from ...utils import loss_utils
from ..model_utils import centernet_utils
class SeparateHead_Transfusion(nn.Module):
def __init__(self, input_channels, head_channels, kernel_size, sep_head_dict, init_bias=-2.19, use_bias=False):
super().__init__()
self.sep_head_dict = sep_head_dict
for cur_name in self.sep_head_dict:
output_channels = self.sep_head_dict[cur_name]['out_channels']
num_conv = self.sep_head_dict[cur_name]['num_conv']
fc_list = []
for k in range(num_conv - 1):
fc_list.append(nn.Sequential(
nn.Conv1d(input_channels, head_channels, kernel_size, stride=1, padding=kernel_size//2, bias=use_bias),
nn.BatchNorm1d(head_channels),
nn.ReLU()
))
fc_list.append(nn.Conv1d(head_channels, output_channels, kernel_size, stride=1, padding=kernel_size//2, bias=True))
fc = nn.Sequential(*fc_list)
if 'hm' in cur_name:
fc[-1].bias.data.fill_(init_bias)
else:
for m in fc.modules():
if isinstance(m, nn.Conv2d):
kaiming_normal_(m.weight.data)
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias, 0)
self.__setattr__(cur_name, fc)
def forward(self, x):
ret_dict = {}
for cur_name in self.sep_head_dict:
ret_dict[cur_name] = self.__getattr__(cur_name)(x)
return ret_dict
class TransFusionHead(nn.Module):
"""
This module implements TransFusionHead.
The code is adapted from https://github.com/mit-han-lab/bevfusion/ with minimal modifications.
"""
def __init__(
self,
model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, voxel_size, predict_boxes_when_training=True,
):
super(TransFusionHead, self).__init__()
self.grid_size = grid_size
self.point_cloud_range = point_cloud_range
self.voxel_size = voxel_size
self.num_classes = num_class
self.model_cfg = model_cfg
self.feature_map_stride = self.model_cfg.TARGET_ASSIGNER_CONFIG.get('FEATURE_MAP_STRIDE', None)
self.dataset_name = self.model_cfg.TARGET_ASSIGNER_CONFIG.get('DATASET', 'nuScenes')
hidden_channel=self.model_cfg.HIDDEN_CHANNEL
self.num_proposals = self.model_cfg.NUM_PROPOSALS
self.bn_momentum = self.model_cfg.BN_MOMENTUM
self.nms_kernel_size = self.model_cfg.NMS_KERNEL_SIZE
num_heads = self.model_cfg.NUM_HEADS
dropout = self.model_cfg.DROPOUT
activation = self.model_cfg.ACTIVATION
ffn_channel = self.model_cfg.FFN_CHANNEL
bias = self.model_cfg.get('USE_BIAS_BEFORE_NORM', False)
loss_cls = self.model_cfg.LOSS_CONFIG.LOSS_CLS
self.use_sigmoid_cls = loss_cls.get("use_sigmoid", False)
if not self.use_sigmoid_cls:
self.num_classes += 1
self.loss_cls = loss_utils.SigmoidFocalClassificationLoss(gamma=loss_cls.gamma,alpha=loss_cls.alpha)
self.loss_cls_weight = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight']
self.loss_bbox = loss_utils.L1Loss()
self.loss_bbox_weight = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['bbox_weight']
self.loss_heatmap = loss_utils.GaussianFocalLoss()
self.loss_heatmap_weight = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['hm_weight']
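        # box encoding: (cx, cy) on the BEV feature map, z, log-dims, sin/cos of yaw, and (vx, vy)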
self.code_size = 10
# a shared convolution
self.shared_conv = nn.Conv2d(in_channels=input_channels,out_channels=hidden_channel,kernel_size=3,padding=1)
layers = []
layers.append(BasicBlock2D(hidden_channel,hidden_channel, kernel_size=3,padding=1,bias=bias))
layers.append(nn.Conv2d(in_channels=hidden_channel,out_channels=num_class,kernel_size=3,padding=1))
self.heatmap_head = nn.Sequential(*layers)
self.class_encoding = nn.Conv1d(num_class, hidden_channel, 1)
# transformer decoder layers for object query with LiDAR feature
self.decoder = TransformerDecoderLayer(hidden_channel, num_heads, ffn_channel, dropout, activation,
self_posembed=PositionEmbeddingLearned(2, hidden_channel),
cross_posembed=PositionEmbeddingLearned(2, hidden_channel),
)
# Prediction Head
heads = copy.deepcopy(self.model_cfg.SEPARATE_HEAD_CFG.HEAD_DICT)
heads['heatmap'] = dict(out_channels=self.num_classes, num_conv=self.model_cfg.NUM_HM_CONV)
self.prediction_head = SeparateHead_Transfusion(hidden_channel, 64, 1, heads, use_bias=bias)
self.init_weights()
self.bbox_assigner = HungarianAssigner3D(**self.model_cfg.TARGET_ASSIGNER_CONFIG.HUNGARIAN_ASSIGNER)
# Position Embedding for Cross-Attention, which is re-used during training
x_size = self.grid_size[0] // self.feature_map_stride
y_size = self.grid_size[1] // self.feature_map_stride
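        # cell-center coordinates of the BEV feature map, shape (1, x_size * y_size, 2);
        # gathered per query and fed to the decoder's positional embeddings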
self.bev_pos = self.create_2D_grid(x_size, y_size)
self.forward_ret_dict = {}
def create_2D_grid(self, x_size, y_size):
meshgrid = [[0, x_size - 1, x_size], [0, y_size - 1, y_size]]
# NOTE: modified
batch_x, batch_y = torch.meshgrid(
*[torch.linspace(it[0], it[1], it[2]) for it in meshgrid]
)
batch_x = batch_x + 0.5
batch_y = batch_y + 0.5
coord_base = torch.cat([batch_x[None], batch_y[None]], dim=0)[None]
coord_base = coord_base.view(1, 2, -1).permute(0, 2, 1)
return coord_base
def init_weights(self):
# initialize transformer
for m in self.decoder.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)
if hasattr(self, "query"):
nn.init.xavier_normal_(self.query)
self.init_bn_momentum()
def init_bn_momentum(self):
for m in self.modules():
if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
m.momentum = self.bn_momentum
def predict(self, inputs):
batch_size = inputs.shape[0]
lidar_feat = self.shared_conv(inputs)
lidar_feat_flatten = lidar_feat.view(
batch_size, lidar_feat.shape[1], -1
)
bev_pos = self.bev_pos.repeat(batch_size, 1, 1).to(lidar_feat.device)
# query initialization
dense_heatmap = self.heatmap_head(lidar_feat)
heatmap = dense_heatmap.detach().sigmoid()
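        # class-wise local-maximum suppression on the dense heatmap:
        # only peaks within an nms_kernel_size window survive as query candidates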
padding = self.nms_kernel_size // 2
local_max = torch.zeros_like(heatmap)
local_max_inner = F.max_pool2d(
heatmap, kernel_size=self.nms_kernel_size, stride=1, padding=0
)
local_max[:, :, padding:(-padding), padding:(-padding)] = local_max_inner
# for Pedestrian & Traffic_cone in nuScenes
if self.dataset_name == "nuScenes":
local_max[ :, 8, ] = F.max_pool2d(heatmap[:, 8], kernel_size=1, stride=1, padding=0)
local_max[ :, 9, ] = F.max_pool2d(heatmap[:, 9], kernel_size=1, stride=1, padding=0)
# for Pedestrian & Cyclist in Waymo
elif self.dataset_name == "Waymo":
local_max[ :, 1, ] = F.max_pool2d(heatmap[:, 1], kernel_size=1, stride=1, padding=0)
local_max[ :, 2, ] = F.max_pool2d(heatmap[:, 2], kernel_size=1, stride=1, padding=0)
heatmap = heatmap * (heatmap == local_max)
heatmap = heatmap.view(batch_size, heatmap.shape[1], -1)
# top num_proposals among all classes
top_proposals = heatmap.view(batch_size, -1).argsort(dim=-1, descending=True)[
..., : self.num_proposals
]
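        # top_proposals indexes the flattened (class, position) heatmap; recover class id and spatial index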
top_proposals_class = top_proposals // heatmap.shape[-1]
top_proposals_index = top_proposals % heatmap.shape[-1]
query_feat = lidar_feat_flatten.gather(
index=top_proposals_index[:, None, :].expand(-1, lidar_feat_flatten.shape[1], -1),
dim=-1,
)
self.query_labels = top_proposals_class
# add category embedding
one_hot = F.one_hot(top_proposals_class, num_classes=self.num_classes).permute(0, 2, 1)
query_cat_encoding = self.class_encoding(one_hot.float())
query_feat += query_cat_encoding
query_pos = bev_pos.gather(
index=top_proposals_index[:, None, :].permute(0, 2, 1).expand(-1, -1, bev_pos.shape[-1]),
dim=1,
)
# convert to xy
query_pos = query_pos.flip(dims=[-1])
bev_pos = bev_pos.flip(dims=[-1])
query_feat = self.decoder(
query_feat, lidar_feat_flatten, query_pos, bev_pos
)
res_layer = self.prediction_head(query_feat)
res_layer["center"] = res_layer["center"] + query_pos.permute(0, 2, 1)
res_layer["query_heatmap_score"] = heatmap.gather(
index=top_proposals_index[:, None, :].expand(-1, self.num_classes, -1),
dim=-1,
)
res_layer["dense_heatmap"] = dense_heatmap
return res_layer
def forward(self, batch_dict):
feats = batch_dict['spatial_features_2d']
res = self.predict(feats)
if not self.training:
bboxes = self.get_bboxes(res)
batch_dict['final_box_dicts'] = bboxes
else:
gt_boxes = batch_dict['gt_boxes']
gt_bboxes_3d = gt_boxes[...,:-1]
gt_labels_3d = gt_boxes[...,-1].long() - 1
loss, tb_dict = self.loss(gt_bboxes_3d, gt_labels_3d, res)
batch_dict['loss'] = loss
batch_dict['tb_dict'] = tb_dict
return batch_dict
def get_targets(self, gt_bboxes_3d, gt_labels_3d, pred_dicts):
assign_results = []
for batch_idx in range(len(gt_bboxes_3d)):
pred_dict = {}
for key in pred_dicts.keys():
pred_dict[key] = pred_dicts[key][batch_idx : batch_idx + 1]
gt_bboxes = gt_bboxes_3d[batch_idx]
valid_idx = []
# filter empty boxes
for i in range(len(gt_bboxes)):
if gt_bboxes[i][3] > 0 and gt_bboxes[i][4] > 0:
valid_idx.append(i)
assign_result = self.get_targets_single(gt_bboxes[valid_idx], gt_labels_3d[batch_idx][valid_idx], pred_dict)
assign_results.append(assign_result)
res_tuple = tuple(map(list, zip(*assign_results)))
labels = torch.cat(res_tuple[0], dim=0)
label_weights = torch.cat(res_tuple[1], dim=0)
bbox_targets = torch.cat(res_tuple[2], dim=0)
bbox_weights = torch.cat(res_tuple[3], dim=0)
num_pos = np.sum(res_tuple[4])
matched_ious = np.mean(res_tuple[5])
heatmap = torch.cat(res_tuple[6], dim=0)
return labels, label_weights, bbox_targets, bbox_weights, num_pos, matched_ious, heatmap
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d, preds_dict):
num_proposals = preds_dict["center"].shape[-1]
score = copy.deepcopy(preds_dict["heatmap"].detach())
center = copy.deepcopy(preds_dict["center"].detach())
height = copy.deepcopy(preds_dict["height"].detach())
dim = copy.deepcopy(preds_dict["dim"].detach())
rot = copy.deepcopy(preds_dict["rot"].detach())
if "vel" in preds_dict.keys():
vel = copy.deepcopy(preds_dict["vel"].detach())
else:
vel = None
boxes_dict = self.decode_bbox(score, rot, dim, center, height, vel)
bboxes_tensor = boxes_dict[0]["pred_boxes"]
gt_bboxes_tensor = gt_bboxes_3d.to(score.device)
assigned_gt_inds, ious = self.bbox_assigner.assign(
bboxes_tensor, gt_bboxes_tensor, gt_labels_3d,
score, self.point_cloud_range,
)
pos_inds = torch.nonzero(assigned_gt_inds > 0, as_tuple=False).squeeze(-1).unique()
neg_inds = torch.nonzero(assigned_gt_inds == 0, as_tuple=False).squeeze(-1).unique()
pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
if gt_bboxes_3d.numel() == 0:
assert pos_inds.numel() == 0
pos_gt_bboxes = torch.empty_like(gt_bboxes_3d).view(-1, 9)
else:
pos_gt_bboxes = gt_bboxes_3d[pos_assigned_gt_inds.long(), :]
# create target for loss computation
bbox_targets = torch.zeros([num_proposals, self.code_size]).to(center.device)
bbox_weights = torch.zeros([num_proposals, self.code_size]).to(center.device)
ious = torch.clamp(ious, min=0.0, max=1.0)
labels = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)
label_weights = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)
if gt_labels_3d is not None: # default label is -1
labels += self.num_classes
# both pos and neg have classification loss, only pos has regression and iou loss
if len(pos_inds) > 0:
pos_bbox_targets = self.encode_bbox(pos_gt_bboxes)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
if gt_labels_3d is None:
labels[pos_inds] = 1
else:
labels[pos_inds] = gt_labels_3d[pos_assigned_gt_inds]
label_weights[pos_inds] = 1.0
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
# compute dense heatmap targets
device = labels.device
target_assigner_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
feature_map_size = (self.grid_size[:2] // self.feature_map_stride)
heatmap = gt_bboxes_3d.new_zeros(self.num_classes, feature_map_size[1], feature_map_size[0])
for idx in range(len(gt_bboxes_3d)):
width = gt_bboxes_3d[idx][3]
length = gt_bboxes_3d[idx][4]
width = width / self.voxel_size[0] / self.feature_map_stride
length = length / self.voxel_size[1] / self.feature_map_stride
if width > 0 and length > 0:
radius = centernet_utils.gaussian_radius(length.view(-1), width.view(-1), target_assigner_cfg.GAUSSIAN_OVERLAP)[0]
radius = max(target_assigner_cfg.MIN_RADIUS, int(radius))
x, y = gt_bboxes_3d[idx][0], gt_bboxes_3d[idx][1]
coor_x = (x - self.point_cloud_range[0]) / self.voxel_size[0] / self.feature_map_stride
coor_y = (y - self.point_cloud_range[1]) / self.voxel_size[1] / self.feature_map_stride
center = torch.tensor([coor_x, coor_y], dtype=torch.float32, device=device)
center_int = center.to(torch.int32)
centernet_utils.draw_gaussian_to_heatmap(heatmap[gt_labels_3d[idx]], center_int, radius)
mean_iou = ious[pos_inds].sum() / max(len(pos_inds), 1)
return (labels[None], label_weights[None], bbox_targets[None], bbox_weights[None], int(pos_inds.shape[0]), float(mean_iou), heatmap[None])
def loss(self, gt_bboxes_3d, gt_labels_3d, pred_dicts, **kwargs):
labels, label_weights, bbox_targets, bbox_weights, num_pos, matched_ious, heatmap = \
self.get_targets(gt_bboxes_3d, gt_labels_3d, pred_dicts)
loss_dict = dict()
loss_all = 0
# compute heatmap loss
loss_heatmap = self.loss_heatmap(
clip_sigmoid(pred_dicts["dense_heatmap"]),
heatmap,
).sum() / max(heatmap.eq(1).float().sum().item(), 1)
loss_dict["loss_heatmap"] = loss_heatmap.item() * self.loss_heatmap_weight
loss_all += loss_heatmap * self.loss_heatmap_weight
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = pred_dicts["heatmap"].permute(0, 2, 1).reshape(-1, self.num_classes)
one_hot_targets = torch.zeros(*list(labels.shape), self.num_classes+1, dtype=cls_score.dtype, device=labels.device)
one_hot_targets.scatter_(-1, labels.unsqueeze(dim=-1).long(), 1.0)
one_hot_targets = one_hot_targets[..., :-1]
loss_cls = self.loss_cls(
cls_score, one_hot_targets, label_weights
).sum() / max(num_pos, 1)
preds = torch.cat([pred_dicts[head_name] for head_name in self.model_cfg.SEPARATE_HEAD_CFG.HEAD_ORDER], dim=1).permute(0, 2, 1)
code_weights = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['code_weights']
reg_weights = bbox_weights * bbox_weights.new_tensor(code_weights)
loss_bbox = self.loss_bbox(preds, bbox_targets)
loss_bbox = (loss_bbox * reg_weights).sum() / max(num_pos, 1)
loss_dict["loss_cls"] = loss_cls.item() * self.loss_cls_weight
loss_dict["loss_bbox"] = loss_bbox.item() * self.loss_bbox_weight
loss_all = loss_all + loss_cls * self.loss_cls_weight + loss_bbox * self.loss_bbox_weight
loss_dict[f"matched_ious"] = loss_cls.new_tensor(matched_ious)
loss_dict['loss_trans'] = loss_all
return loss_all,loss_dict
def encode_bbox(self, bboxes):
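        # counterpart of decode_bbox: map (x, y, z, dx, dy, dz, heading[, vx, vy]) boxes
        # to the regression target space used by the prediction head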
code_size = 10
targets = torch.zeros([bboxes.shape[0], code_size]).to(bboxes.device)
targets[:, 0] = (bboxes[:, 0] - self.point_cloud_range[0]) / (self.feature_map_stride * self.voxel_size[0])
targets[:, 1] = (bboxes[:, 1] - self.point_cloud_range[1]) / (self.feature_map_stride * self.voxel_size[1])
targets[:, 3:6] = bboxes[:, 3:6].log()
targets[:, 2] = bboxes[:, 2]
targets[:, 6] = torch.sin(bboxes[:, 6])
targets[:, 7] = torch.cos(bboxes[:, 6])
if code_size == 10:
targets[:, 8:10] = bboxes[:, 7:]
return targets
def decode_bbox(self, heatmap, rot, dim, center, height, vel, filter=False):
post_process_cfg = self.model_cfg.POST_PROCESSING
score_thresh = post_process_cfg.SCORE_THRESH
post_center_range = post_process_cfg.POST_CENTER_RANGE
post_center_range = torch.tensor(post_center_range).cuda().float()
# class label
final_preds = heatmap.max(1, keepdims=False).indices
final_scores = heatmap.max(1, keepdims=False).values
center[:, 0, :] = center[:, 0, :] * self.feature_map_stride * self.voxel_size[0] + self.point_cloud_range[0]
center[:, 1, :] = center[:, 1, :] * self.feature_map_stride * self.voxel_size[1] + self.point_cloud_range[1]
dim = dim.exp()
rots, rotc = rot[:, 0:1, :], rot[:, 1:2, :]
rot = torch.atan2(rots, rotc)
if vel is None:
final_box_preds = torch.cat([center, height, dim, rot], dim=1).permute(0, 2, 1)
else:
final_box_preds = torch.cat([center, height, dim, rot, vel], dim=1).permute(0, 2, 1)
predictions_dicts = []
for i in range(heatmap.shape[0]):
boxes3d = final_box_preds[i]
scores = final_scores[i]
labels = final_preds[i]
predictions_dict = {
'pred_boxes': boxes3d,
'pred_scores': scores,
'pred_labels': labels
}
predictions_dicts.append(predictions_dict)
if filter is False:
return predictions_dicts
thresh_mask = final_scores > score_thresh
mask = (final_box_preds[..., :3] >= post_center_range[:3]).all(2)
mask &= (final_box_preds[..., :3] <= post_center_range[3:]).all(2)
predictions_dicts = []
for i in range(heatmap.shape[0]):
cmask = mask[i, :]
cmask &= thresh_mask[i]
boxes3d = final_box_preds[i, cmask]
scores = final_scores[i, cmask]
labels = final_preds[i, cmask]
predictions_dict = {
'pred_boxes': boxes3d,
'pred_scores': scores,
'pred_labels': labels,
}
predictions_dicts.append(predictions_dict)
return predictions_dicts
def get_bboxes(self, preds_dicts):
batch_size = preds_dicts["heatmap"].shape[0]
batch_score = preds_dicts["heatmap"].sigmoid()
one_hot = F.one_hot(
self.query_labels, num_classes=self.num_classes
).permute(0, 2, 1)
batch_score = batch_score * preds_dicts["query_heatmap_score"] * one_hot
batch_center = preds_dicts["center"]
batch_height = preds_dicts["height"]
batch_dim = preds_dicts["dim"]
batch_rot = preds_dicts["rot"]
batch_vel = None
if "vel" in preds_dicts:
batch_vel = preds_dicts["vel"]
ret_dict = self.decode_bbox(
batch_score, batch_rot, batch_dim,
batch_center, batch_height, batch_vel,
filter=True,
)
for k in range(batch_size):
ret_dict[k]['pred_labels'] = ret_dict[k]['pred_labels'].int() + 1
return ret_dict
@@ -13,6 +13,7 @@ from .mppnet import MPPNet
from .mppnet_e2e import MPPNetE2E
from .pillarnet import PillarNet
from .voxelnext import VoxelNeXt
+from .transfusion import TransFusion

__all__ = {
    'Detector3DTemplate': Detector3DTemplate,
@@ -30,7 +31,8 @@ __all__ = {
    'MPPNet': MPPNet,
    'MPPNetE2E': MPPNetE2E,
    'PillarNet': PillarNet,
-    'VoxelNeXt': VoxelNeXt
+    'VoxelNeXt': VoxelNeXt,
+    'TransFusion': TransFusion,
}
...
from .detector3d_template import Detector3DTemplate
class TransFusion(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss(batch_dict)
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self,batch_dict):
disp_dict = {}
loss_trans, tb_dict = batch_dict['loss'],batch_dict['tb_dict']
tb_dict = {
'loss_trans': loss_trans.item(),
**tb_dict
}
loss = loss_trans
return loss, tb_dict, disp_dict
def post_processing(self, batch_dict):
post_process_cfg = self.model_cfg.POST_PROCESSING
batch_size = batch_dict['batch_size']
final_pred_dict = batch_dict['final_box_dicts']
recall_dict = {}
for index in range(batch_size):
pred_boxes = final_pred_dict[index]['pred_boxes']
recall_dict = self.generate_recall_record(
box_preds=pred_boxes,
recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
thresh_list=post_process_cfg.RECALL_THRESH_LIST
)
return final_pred_dict, recall_dict
import torch
from torch import nn
import torch.nn.functional as F
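# clamp sigmoid outputs away from 0 and 1 so the heatmap focal loss never takes log(0)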
def clip_sigmoid(x, eps=1e-4):
y = torch.clamp(x.sigmoid_(), min=eps, max=1 - eps)
return y
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, input_channel, num_pos_feats=288):
super().__init__()
self.position_embedding_head = nn.Sequential(
nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
nn.BatchNorm1d(num_pos_feats),
nn.ReLU(inplace=True),
nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))
def forward(self, xyz):
xyz = xyz.transpose(1, 2).contiguous()
position_embedding = self.position_embedding_head(xyz)
return position_embedding
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
self_posembed=None, cross_posembed=None, cross_only=False):
super().__init__()
self.cross_only = cross_only
if not self.cross_only:
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
self.activation = _get_activation_fn(activation)
self.self_posembed = self_posembed
self.cross_posembed = cross_posembed
def with_pos_embed(self, tensor, pos_embed):
return tensor if pos_embed is None else tensor + pos_embed
def forward(self, query, key, query_pos, key_pos, key_padding_mask=None, attn_mask=None):
# NxCxP to PxNxC
if self.self_posembed is not None:
query_pos_embed = self.self_posembed(query_pos).permute(2, 0, 1)
else:
query_pos_embed = None
if self.cross_posembed is not None:
key_pos_embed = self.cross_posembed(key_pos).permute(2, 0, 1)
else:
key_pos_embed = None
query = query.permute(2, 0, 1)
key = key.permute(2, 0, 1)
if not self.cross_only:
q = k = v = self.with_pos_embed(query, query_pos_embed)
query2 = self.self_attn(q, k, value=v)[0]
query = query + self.dropout1(query2)
query = self.norm1(query)
query2 = self.multihead_attn(query=self.with_pos_embed(query, query_pos_embed),
key=self.with_pos_embed(key, key_pos_embed),
value=self.with_pos_embed(key, key_pos_embed),
key_padding_mask=key_padding_mask, attn_mask=attn_mask)[0]
query = query + self.dropout2(query2)
query = self.norm2(query)
query2 = self.linear2(self.dropout(self.activation(self.linear1(query))))
query = query + self.dropout3(query2)
query = self.norm3(query)
        # PxNxC back to NxCxP
query = query.permute(1, 2, 0)
return query
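As a sanity check on the expected tensor layout, here is an editorial sketch of driving the decoder layer with random inputs (all sizes are illustrative, not from the commit):

import torch

layer = TransformerDecoderLayer(
    d_model=128, nhead=8, dim_feedforward=256,
    self_posembed=PositionEmbeddingLearned(2, 128),
    cross_posembed=PositionEmbeddingLearned(2, 128),
)
query = torch.randn(2, 128, 200)       # (batch, channels, num_proposals)
key = torch.randn(2, 128, 32 * 32)     # flattened BEV feature map
query_pos = torch.rand(2, 200, 2)      # proposal centers on the BEV grid
key_pos = torch.rand(2, 32 * 32, 2)    # BEV grid cell coordinates
out = layer(query, key, query_pos, key_pos)
print(out.shape)                       # torch.Size([2, 128, 200])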
@@ -561,3 +561,48 @@ class IouRegLossSparse(nn.Module):
        loss = loss / (mask.sum() + 1e-4)

        return loss
class L1Loss(nn.Module):
def __init__(self):
super(L1Loss, self).__init__()
def forward(self, pred, target):
if target.numel() == 0:
return pred.sum() * 0
assert pred.size() == target.size()
loss = torch.abs(pred - target)
return loss
class GaussianFocalLoss(nn.Module):
"""GaussianFocalLoss is a variant of focal loss.
More details can be found in the `paper
<https://arxiv.org/abs/1808.01244>`_
Code is modified from `kp_utils.py
<https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_ # noqa: E501
Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
not 0/1 binary target.
Args:
alpha (float): Power of prediction.
gamma (float): Power of target for negative samples.
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Loss weight of current loss.
"""
def __init__(self,
alpha=2.0,
gamma=4.0):
super(GaussianFocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, pred, target):
eps = 1e-12
pos_weights = target.eq(1)
neg_weights = (1 - target).pow(self.gamma)
pos_loss = -(pred + eps).log() * (1 - pred).pow(self.alpha) * pos_weights
neg_loss = -(1 - pred + eps).log() * pred.pow(self.alpha) * neg_weights
return pos_loss + neg_loss
\ No newline at end of file
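A quick editorial illustration of how the new GaussianFocalLoss treats its gaussian heatmap target, assuming the class above is in scope (values are made up; only the cell with target == 1 counts as a positive):

import torch

pred = torch.tensor([[0.80, 0.30],
                     [0.10, 0.05]])           # predicted heatmap, already in (0, 1)
target = torch.tensor([[1.00, 0.60],
                       [0.10, 0.00]])         # gaussian-splatted target, peak at (0, 0)

loss_map = GaussianFocalLoss()(pred, target)  # element-wise, same shape as pred
# the peak gets the positive focal term; all other cells get the negative term,
# down-weighted by (1 - target) ** gamma, so cells near the peak are penalized less
print(loss_map.sum())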
CLASS_NAMES: ['car','truck', 'construction_vehicle', 'bus', 'trailer',
'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/nuscenes_dataset.yaml
POINT_CLOUD_RANGE: [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
# sc TODO: just for debug
INFO_PATH: {
'train': [nuscenes_infos_10sweeps_train_with_cam_2d.pkl],
'test': [nuscenes_infos_10sweeps_val_with_cam_2d.pkl],
}
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
- NAME: gt_sampling
DB_INFO_PATH:
- nuscenes_dbinfos_10sweeps_withvelo.pkl
PREPARE: {
filter_by_min_points: [
'car:5','truck:5', 'construction_vehicle:5', 'bus:5', 'trailer:5',
'barrier:5', 'motorcycle:5', 'bicycle:5', 'pedestrian:5', 'traffic_cone:5'
],
}
SAMPLE_GROUPS: [
'car:2','truck:3', 'construction_vehicle:7', 'bus:4', 'trailer:6',
'barrier:2', 'motorcycle:6', 'bicycle:6', 'pedestrian:2', 'traffic_cone:2'
]
NUM_POINT_FEATURES: 5
DATABASE_WITH_FAKELIDAR: False
REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
LIMIT_WHOLE_SCENE: True
- NAME: random_world_flip
ALONG_AXIS_LIST: ['x', 'y']
- NAME: random_world_rotation
WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
- NAME: random_world_scaling
WORLD_SCALE_RANGE: [0.9, 1.1]
- NAME: random_world_translation
NOISE_TRANSLATE_STD: [0.5, 0.5, 0.5]
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': True
}
- NAME: transform_points_to_voxels
VOXEL_SIZE: [0.075, 0.075, 0.2]
MAX_POINTS_PER_VOXEL: 10
MAX_NUMBER_OF_VOXELS: {
'train': 120000,
'test': 160000
}
MODEL:
NAME: TransFusion
VFE:
NAME: MeanVFE
BACKBONE_3D:
NAME: VoxelResBackBone8x
USE_BIAS: False
MAP_TO_BEV:
NAME: HeightCompression
NUM_BEV_FEATURES: 256
BACKBONE_2D:
NAME: BaseBEVBackbone
LAYER_NUMS: [5, 5]
LAYER_STRIDES: [1, 2]
NUM_FILTERS: [128, 256]
UPSAMPLE_STRIDES: [1, 2]
NUM_UPSAMPLE_FILTERS: [256, 256]
USE_CONV_FOR_NO_STRIDE: True
DENSE_HEAD:
CLASS_AGNOSTIC: False
NAME: TransFusionHead
USE_BIAS_BEFORE_NORM: False
NUM_PROPOSALS: 200
HIDDEN_CHANNEL: 128
NUM_CLASSES: 10
NUM_HEADS: 8
NMS_KERNEL_SIZE: 3
FFN_CHANNEL: 256
DROPOUT: 0.1
BN_MOMENTUM: 0.1
ACTIVATION: relu
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: ['center', 'height', 'dim', 'rot', 'vel']
HEAD_DICT: {
'center': {'out_channels': 2, 'num_conv': 2},
'height': {'out_channels': 1, 'num_conv': 2},
'dim': {'out_channels': 3, 'num_conv': 2},
'rot': {'out_channels': 2, 'num_conv': 2},
'vel': {'out_channels': 2, 'num_conv': 2},
}
TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 8
DATASET: nuScenes
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2
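            # the weights below are passed straight to HungarianAssigner3D
            # (focal classification cost, L1 BEV-center cost, negative-IoU cost)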
HUNGARIAN_ASSIGNER:
cls_cost: {'gamma': 2.0, 'alpha': 0.25, 'weight': 0.15}
reg_cost: {'weight': 0.25}
iou_cost: {'weight': 0.25}
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'bbox_weight': 0.25,
'hm_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
LOSS_CLS:
use_sigmoid: true
gamma: 2.0
alpha: 0.25
POST_PROCESSING:
SCORE_THRESH: 0.0
POST_CENTER_RANGE: [-61.2, -61.2, -10.0, 61.2, 61.2, 10.0]
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
SCORE_THRESH: 0.1
OUTPUT_RAW_SCORE: False
EVAL_METRIC: kitti
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 4
NUM_EPOCHS: 20
OPTIMIZER: adam_onecycle
LR: 0.001
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
BETAS: [0.9, 0.999]
MOMS: [0.9, 0.8052631]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 35
HOOK:
DisableAugmentationHook:
DISABLE_AUG_LIST: ['gt_sampling']
NUM_LAST_EPOCHS: 5
\ No newline at end of file
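Editorial note, not part of the commit: assuming the config above is saved under tools/cfgs/ (its final filename is not shown here), it would be consumed by tools/train.py in the usual OpenPCDet way, e.g. python train.py --cfg_file cfgs/nuscenes_models/transfusion_lidar.yaml --batch_size 4, which is also where the cfg object passed into train_model below comes from.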
@@ -195,7 +195,8 @@ def main():
        ckpt_save_time_interval=args.ckpt_save_time_interval,
        use_logger_to_record=not args.use_tqdm_to_record,
        show_gpu_stat=not args.wo_gpu_stat,
-        use_amp=args.use_amp
+        use_amp=args.use_amp,
+        cfg=cfg
    )

    if hasattr(train_set, 'use_shared_memory') and train_set.use_shared_memory:
...
@@ -25,8 +25,9 @@ def build_optimizer(model, optim_cfg):
        flatten_model = lambda m: sum(map(flatten_model, m.children()), []) if num_children(m) else [m]
        get_layer_groups = lambda m: [nn.Sequential(*flatten_model(m))]

-        optimizer_func = partial(optim.Adam, betas=(0.9, 0.99))
+        betas = optim_cfg.get('BETAS', (0.9, 0.99))
+        betas = tuple(betas)
+        optimizer_func = partial(optim.Adam, betas=betas)
        optimizer = OptimWrapper.create(
            optimizer_func, 3e-3, get_layer_groups(model), wd=optim_cfg.WEIGHT_DECAY, true_wd=True, bn_wd=True
        )
...
@@ -151,8 +151,13 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
                start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, train_sampler=None,
                lr_warmup_scheduler=None, ckpt_save_interval=1, max_ckpt_save_num=50,
                merge_all_iters_to_one_epoch=False, use_amp=False,
-                use_logger_to_record=False, logger=None, logger_iter_interval=None, ckpt_save_time_interval=None, show_gpu_stat=False):
+                use_logger_to_record=False, logger=None, logger_iter_interval=None, ckpt_save_time_interval=None, show_gpu_stat=False, cfg=None):
    accumulated_iter = start_iter
+
+    # used by the hook that disables data augmentation for the last few epochs
+    hook_config = cfg.get('HOOK', None)
+    augment_disable_flag = False
+
    with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar:
        total_it_each_epoch = len(train_loader)
        if merge_all_iters_to_one_epoch:
@@ -170,6 +175,8 @@ def train_model(model, optimizer, train_loader, model_func, lr_scheduler, optim_
                cur_scheduler = lr_warmup_scheduler
            else:
                cur_scheduler = lr_scheduler
+
+            augment_disable_flag = disable_augmentation_hook(hook_config, dataloader_iter, total_epochs, cur_epoch, cfg, augment_disable_flag, logger)
            accumulated_iter = train_one_epoch(
                model, optimizer, train_loader, model_func,
                lr_scheduler=cur_scheduler,
@@ -245,3 +252,21 @@ def save_checkpoint(state, filename='checkpoint'):
        torch.save(state, filename, _use_new_zipfile_serialization=False)
    else:
        torch.save(state, filename)
+
+
+def disable_augmentation_hook(hook_config, dataloader, total_epochs, cur_epoch, cfg, flag, logger):
+    """
+    This hook turns off data augmentation during the last few training epochs.
+    """
+    if hook_config is not None:
+        DisableAugmentationHook = hook_config.get('DisableAugmentationHook', None)
+        if DisableAugmentationHook is not None:
+            num_last_epochs = DisableAugmentationHook.NUM_LAST_EPOCHS
+            if (total_epochs - num_last_epochs) <= cur_epoch and not flag:
+                DISABLE_AUG_LIST = DisableAugmentationHook.DISABLE_AUG_LIST
+                dataset_cfg = cfg.DATA_CONFIG
+                logger.info(f'Disable augmentations: {DISABLE_AUG_LIST}')
+                dataset_cfg.DATA_AUGMENTOR.DISABLE_AUG_LIST = DISABLE_AUG_LIST
+                dataloader._dataset.data_augmentor.disableAugmentation(dataset_cfg.DATA_AUGMENTOR)
+                flag = True
+    return flag
\ No newline at end of file