Unverified commit 02ac3e17, authored by Shaoshuai Shi and committed by GitHub

Support multi-modal 3D detection on NuScenes #1339

Add support for multi-modal (camera + LiDAR) NuScenes detection
parents ad9c25c0 fcfa0773
......@@ -13,6 +13,8 @@ from .mppnet import MPPNet
from .mppnet_e2e import MPPNetE2E
from .pillarnet import PillarNet
from .voxelnext import VoxelNeXt
from .transfusion import TransFusion
from .bevfusion import BevFusion
__all__ = {
'Detector3DTemplate': Detector3DTemplate,
......@@ -30,7 +32,9 @@ __all__ = {
'MPPNet': MPPNet,
'MPPNetE2E': MPPNetE2E,
'PillarNet': PillarNet,
    'VoxelNeXt': VoxelNeXt,
'TransFusion': TransFusion,
'BevFusion': BevFusion,
}
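
# For reference, OpenPCDet instantiates detectors by name through this registry.
# A minimal sketch of the usual dispatch helper (illustrative; assumes the
# standard build_detector signature used elsewhere in this package):
def build_detector(model_cfg, num_class, dataset):
    model = __all__[model_cfg.NAME](
        model_cfg=model_cfg, num_class=num_class, dataset=dataset
    )
    return model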
......
from .detector3d_template import Detector3DTemplate
from .. import backbones_image, view_transforms
from ..backbones_image import img_neck
from ..backbones_2d import fuser
class BevFusion(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
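        # Build order matters: the LiDAR branch (vfe -> backbone_3d -> map_to_bev)
        # runs first, then the camera branch (image_backbone -> neck -> vtransform);
        # the fuser merges the two BEV maps before the shared 2D backbone and heads.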
self.module_topology = [
'vfe', 'backbone_3d', 'map_to_bev_module', 'pfe',
            'image_backbone', 'neck', 'vtransform', 'fuser',
'backbone_2d', 'dense_head', 'point_head', 'roi_head'
]
self.module_list = self.build_networks()
    def build_neck(self, model_info_dict):
if self.model_cfg.get('NECK', None) is None:
return None, model_info_dict
neck_module = img_neck.__all__[self.model_cfg.NECK.NAME](
model_cfg=self.model_cfg.NECK
)
model_info_dict['module_list'].append(neck_module)
return neck_module, model_info_dict
    def build_vtransform(self, model_info_dict):
if self.model_cfg.get('VTRANSFORM', None) is None:
return None, model_info_dict
vtransform_module = view_transforms.__all__[self.model_cfg.VTRANSFORM.NAME](
model_cfg=self.model_cfg.VTRANSFORM
)
model_info_dict['module_list'].append(vtransform_module)
return vtransform_module, model_info_dict
def build_image_backbone(self, model_info_dict):
if self.model_cfg.get('IMAGE_BACKBONE', None) is None:
return None, model_info_dict
image_backbone_module = backbones_image.__all__[self.model_cfg.IMAGE_BACKBONE.NAME](
model_cfg=self.model_cfg.IMAGE_BACKBONE
)
image_backbone_module.init_weights()
model_info_dict['module_list'].append(image_backbone_module)
return image_backbone_module, model_info_dict
def build_fuser(self, model_info_dict):
if self.model_cfg.get('FUSER', None) is None:
return None, model_info_dict
fuser_module = fuser.__all__[self.model_cfg.FUSER.NAME](
model_cfg=self.model_cfg.FUSER
)
model_info_dict['module_list'].append(fuser_module)
model_info_dict['num_bev_features'] = self.model_cfg.FUSER.OUT_CHANNEL
return fuser_module, model_info_dict
def forward(self, batch_dict):
        for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss(batch_dict)
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
    def get_training_loss(self, batch_dict):
        disp_dict = {}
        loss_trans, tb_dict = batch_dict['loss'], batch_dict['tb_dict']
tb_dict = {
'loss_trans': loss_trans.item(),
**tb_dict
}
loss = loss_trans
return loss, tb_dict, disp_dict
def post_processing(self, batch_dict):
post_process_cfg = self.model_cfg.POST_PROCESSING
batch_size = batch_dict['batch_size']
final_pred_dict = batch_dict['final_box_dicts']
recall_dict = {}
for index in range(batch_size):
pred_boxes = final_pred_dict[index]['pred_boxes']
recall_dict = self.generate_recall_record(
box_preds=pred_boxes,
recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
thresh_list=post_process_cfg.RECALL_THRESH_LIST
)
return final_pred_dict, recall_dict
from .detector3d_template import Detector3DTemplate
class TransFusion(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss(batch_dict)
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
    def get_training_loss(self, batch_dict):
        disp_dict = {}
        loss_trans, tb_dict = batch_dict['loss'], batch_dict['tb_dict']
tb_dict = {
'loss_trans': loss_trans.item(),
**tb_dict
}
loss = loss_trans
return loss, tb_dict, disp_dict
def post_processing(self, batch_dict):
post_process_cfg = self.model_cfg.POST_PROCESSING
batch_size = batch_dict['batch_size']
final_pred_dict = batch_dict['final_box_dicts']
recall_dict = {}
for index in range(batch_size):
pred_boxes = final_pred_dict[index]['pred_boxes']
recall_dict = self.generate_recall_record(
box_preds=pred_boxes,
recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
thresh_list=post_process_cfg.RECALL_THRESH_LIST
)
return final_pred_dict, recall_dict
import torch
from torch import nn
import torch.nn.functional as F
def clip_sigmoid(x, eps=1e-4):
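    # Clamp the sigmoid output away from exact 0 and 1 so downstream log() calls
    # (e.g. in the Gaussian focal loss) stay finite; note sigmoid_() is in-place.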
y = torch.clamp(x.sigmoid_(), min=eps, max=1 - eps)
return y
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, input_channel, num_pos_feats=288):
super().__init__()
self.position_embedding_head = nn.Sequential(
nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
nn.BatchNorm1d(num_pos_feats),
nn.ReLU(inplace=True),
nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))
def forward(self, xyz):
xyz = xyz.transpose(1, 2).contiguous()
position_embedding = self.position_embedding_head(xyz)
return position_embedding
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
self_posembed=None, cross_posembed=None, cross_only=False):
super().__init__()
self.cross_only = cross_only
if not self.cross_only:
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
self.activation = _get_activation_fn(activation)
self.self_posembed = self_posembed
self.cross_posembed = cross_posembed
def with_pos_embed(self, tensor, pos_embed):
return tensor if pos_embed is None else tensor + pos_embed
def forward(self, query, key, query_pos, key_pos, key_padding_mask=None, attn_mask=None):
# NxCxP to PxNxC
if self.self_posembed is not None:
query_pos_embed = self.self_posembed(query_pos).permute(2, 0, 1)
else:
query_pos_embed = None
if self.cross_posembed is not None:
key_pos_embed = self.cross_posembed(key_pos).permute(2, 0, 1)
else:
key_pos_embed = None
query = query.permute(2, 0, 1)
key = key.permute(2, 0, 1)
if not self.cross_only:
q = k = v = self.with_pos_embed(query, query_pos_embed)
query2 = self.self_attn(q, k, value=v)[0]
query = query + self.dropout1(query2)
query = self.norm1(query)
query2 = self.multihead_attn(query=self.with_pos_embed(query, query_pos_embed),
key=self.with_pos_embed(key, key_pos_embed),
value=self.with_pos_embed(key, key_pos_embed),
key_padding_mask=key_padding_mask, attn_mask=attn_mask)[0]
query = query + self.dropout2(query2)
query = self.norm2(query)
query2 = self.linear2(self.dropout(self.activation(self.linear1(query))))
query = query + self.dropout3(query2)
query = self.norm3(query)
        # PxNxC back to NxCxP
query = query.permute(1, 2, 0)
return query
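
# A minimal shape check for this decoder layer (hypothetical sizes): it consumes
# and returns N x C x P tensors, and d_model must match the width of the learned
# positional embeddings built from 2-D BEV coordinates.
d_model = 128
layer = TransformerDecoderLayer(
    d_model, nhead=8,
    self_posembed=PositionEmbeddingLearned(2, d_model),
    cross_posembed=PositionEmbeddingLearned(2, d_model),
)
query = torch.randn(2, d_model, 200)          # N x C x P object queries
key = torch.randn(2, d_model, 32 * 88)        # flattened BEV feature map
query_pos = torch.rand(2, 200, 2)             # (x, y) BEV position per query
key_pos = torch.rand(2, 32 * 88, 2)
out = layer(query, key, query_pos, key_pos)   # -> (2, 128, 200)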
from .depth_lss import DepthLSSTransform
__all__ = {
'DepthLSSTransform': DepthLSSTransform,
}
\ No newline at end of file
import torch
from torch import nn
from pcdet.ops.bev_pool import bev_pool
def gen_dx_bx(xbound, ybound, zbound):
dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]])
bx = torch.Tensor([row[0] + row[2] / 2.0 for row in [xbound, ybound, zbound]])
nx = torch.LongTensor(
[(row[1] - row[0]) / row[2] for row in [xbound, ybound, zbound]]
)
return dx, bx, nx
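
# Concrete check (values match the DepthLSSTransform config later in this diff):
# dx, bx, nx = gen_dx_bx([-54.0, 54.0, 0.3], [-54.0, 54.0, 0.3], [-10.0, 10.0, 20.0])
# dx -> [0.30, 0.30, 20.0]      voxel size per axis
# bx -> [-53.85, -53.85, 0.0]   center of the first voxel
# nx -> [360, 360, 1]           grid resolution per axis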
class DepthLSSTransform(nn.Module):
"""
    This module implements LSS, which lifts image features into 3D and then splats them onto BEV features.
This code is adapted from https://github.com/mit-han-lab/bevfusion/ with minimal modifications.
"""
def __init__(self, model_cfg):
super().__init__()
self.model_cfg = model_cfg
in_channel = self.model_cfg.IN_CHANNEL
out_channel = self.model_cfg.OUT_CHANNEL
self.image_size = self.model_cfg.IMAGE_SIZE
self.feature_size = self.model_cfg.FEATURE_SIZE
xbound = self.model_cfg.XBOUND
ybound = self.model_cfg.YBOUND
zbound = self.model_cfg.ZBOUND
self.dbound = self.model_cfg.DBOUND
downsample = self.model_cfg.DOWNSAMPLE
dx, bx, nx = gen_dx_bx(xbound, ybound, zbound)
self.dx = nn.Parameter(dx, requires_grad=False)
self.bx = nn.Parameter(bx, requires_grad=False)
self.nx = nn.Parameter(nx, requires_grad=False)
self.C = out_channel
self.frustum = self.create_frustum()
self.D = self.frustum.shape[0]
self.dtransform = nn.Sequential(
nn.Conv2d(1, 8, 1),
nn.BatchNorm2d(8),
nn.ReLU(True),
nn.Conv2d(8, 32, 5, stride=4, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(True),
nn.Conv2d(32, 64, 5, stride=2, padding=2),
nn.BatchNorm2d(64),
nn.ReLU(True),
)
self.depthnet = nn.Sequential(
nn.Conv2d(in_channel + 64, in_channel, 3, padding=1),
nn.BatchNorm2d(in_channel),
nn.ReLU(True),
nn.Conv2d(in_channel, in_channel, 3, padding=1),
nn.BatchNorm2d(in_channel),
nn.ReLU(True),
nn.Conv2d(in_channel, self.D + self.C, 1),
)
if downsample > 1:
assert downsample == 2, downsample
self.downsample = nn.Sequential(
nn.Conv2d(out_channel, out_channel, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU(True),
nn.Conv2d(out_channel, out_channel, 3, stride=downsample, padding=1, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU(True),
nn.Conv2d(out_channel, out_channel, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU(True),
)
else:
self.downsample = nn.Identity()
def create_frustum(self):
iH, iW = self.image_size
fH, fW = self.feature_size
ds = torch.arange(*self.dbound, dtype=torch.float).view(-1, 1, 1).expand(-1, fH, fW)
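        # DBOUND = [start, stop, step]; e.g. [1.0, 60.0, 0.5] yields D = 118 depth bins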
D, _, _ = ds.shape
xs = torch.linspace(0, iW - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW)
ys = torch.linspace(0, iH - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW)
frustum = torch.stack((xs, ys, ds), -1)
return nn.Parameter(frustum, requires_grad=False)
def get_geometry(self, camera2lidar_rots, camera2lidar_trans, intrins, post_rots, post_trans, **kwargs):
camera2lidar_rots = camera2lidar_rots.to(torch.float)
camera2lidar_trans = camera2lidar_trans.to(torch.float)
intrins = intrins.to(torch.float)
post_rots = post_rots.to(torch.float)
post_trans = post_trans.to(torch.float)
B, N, _ = camera2lidar_trans.shape
# undo post-transformation
# B x N x D x H x W x 3
points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)
points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3).matmul(points.unsqueeze(-1))
# cam_to_lidar
points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3], points[:, :, :, :, :, 2:3]), 5)
combine = camera2lidar_rots.matmul(torch.inverse(intrins))
points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
points += camera2lidar_trans.view(B, N, 1, 1, 1, 3)
if "extra_rots" in kwargs:
extra_rots = kwargs["extra_rots"]
points = extra_rots.view(B, 1, 1, 1, 1, 3, 3).repeat(1, N, 1, 1, 1, 1, 1) \
.matmul(points.unsqueeze(-1)).squeeze(-1)
if "extra_trans" in kwargs:
extra_trans = kwargs["extra_trans"]
points += extra_trans.view(B, 1, 1, 1, 1, 3).repeat(1, N, 1, 1, 1, 1)
return points
def bev_pool(self, geom_feats, x):
geom_feats = geom_feats.to(torch.float)
x = x.to(torch.float)
B, N, D, H, W, C = x.shape
Nprime = B * N * D * H * W
# flatten x
x = x.reshape(Nprime, C)
# flatten indices
geom_feats = ((geom_feats - (self.bx - self.dx / 2.0)) / self.dx).long()
geom_feats = geom_feats.view(Nprime, 3)
batch_ix = torch.cat([torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long) for ix in range(B)])
geom_feats = torch.cat((geom_feats, batch_ix), 1)
# filter out points that are outside box
kept = (
(geom_feats[:, 0] >= 0)
& (geom_feats[:, 0] < self.nx[0])
& (geom_feats[:, 1] >= 0)
& (geom_feats[:, 1] < self.nx[1])
& (geom_feats[:, 2] >= 0)
& (geom_feats[:, 2] < self.nx[2])
)
x = x[kept]
geom_feats = geom_feats[kept]
x = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0], self.nx[1])
# collapse Z
final = torch.cat(x.unbind(dim=2), 1)
return final
def get_cam_feats(self, x, d):
B, N, C, fH, fW = x.shape
d = d.view(B * N, *d.shape[2:])
x = x.view(B * N, C, fH, fW)
d = self.dtransform(d)
x = torch.cat([d, x], dim=1)
x = self.depthnet(x)
depth = x[:, : self.D].softmax(dim=1)
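        # LSS "lift": outer product of the per-pixel depth distribution and the
        # per-pixel context features spreads each C-dim feature across D depth bins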
x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)
x = x.view(B, N, self.C, self.D, fH, fW)
x = x.permute(0, 1, 3, 4, 5, 2)
return x
def forward(self, batch_dict):
"""
Args:
batch_dict:
image_fpn (list[tensor]): image features after image neck
Returns:
batch_dict:
spatial_features_img (tensor): bev features from image modality
"""
x = batch_dict['image_fpn']
x = x[0]
BN, C, H, W = x.size()
        img = x.view(BN // 6, 6, C, H, W)  # BN = batch_size * 6 surround-view cameras on nuScenes
camera_intrinsics = batch_dict['camera_intrinsics']
camera2lidar = batch_dict['camera2lidar']
img_aug_matrix = batch_dict['img_aug_matrix']
lidar_aug_matrix = batch_dict['lidar_aug_matrix']
lidar2image = batch_dict['lidar2image']
intrins = camera_intrinsics[..., :3, :3]
post_rots = img_aug_matrix[..., :3, :3]
post_trans = img_aug_matrix[..., :3, 3]
camera2lidar_rots = camera2lidar[..., :3, :3]
camera2lidar_trans = camera2lidar[..., :3, 3]
points = batch_dict['points']
batch_size = BN // 6
depth = torch.zeros(batch_size, img.shape[1], 1, *self.image_size).to(points[0].device)
for b in range(batch_size):
batch_mask = points[:,0] == b
cur_coords = points[batch_mask][:, 1:4]
cur_img_aug_matrix = img_aug_matrix[b]
cur_lidar_aug_matrix = lidar_aug_matrix[b]
cur_lidar2image = lidar2image[b]
# inverse aug
cur_coords -= cur_lidar_aug_matrix[:3, 3]
cur_coords = torch.inverse(cur_lidar_aug_matrix[:3, :3]).matmul(
cur_coords.transpose(1, 0)
)
# lidar2image
cur_coords = cur_lidar2image[:, :3, :3].matmul(cur_coords)
cur_coords += cur_lidar2image[:, :3, 3].reshape(-1, 3, 1)
# get 2d coords
dist = cur_coords[:, 2, :]
cur_coords[:, 2, :] = torch.clamp(cur_coords[:, 2, :], 1e-5, 1e5)
cur_coords[:, :2, :] /= cur_coords[:, 2:3, :]
# do image aug
cur_coords = cur_img_aug_matrix[:, :3, :3].matmul(cur_coords)
cur_coords += cur_img_aug_matrix[:, :3, 3].reshape(-1, 3, 1)
cur_coords = cur_coords[:, :2, :].transpose(1, 2)
            # swap to (y, x) so the coords index (row, col) of the image
            cur_coords = cur_coords[..., [1, 0]]
# filter points outside of images
on_img = (
(cur_coords[..., 0] < self.image_size[0])
& (cur_coords[..., 0] >= 0)
& (cur_coords[..., 1] < self.image_size[1])
& (cur_coords[..., 1] >= 0)
)
for c in range(on_img.shape[0]):
masked_coords = cur_coords[c, on_img[c]].long()
masked_dist = dist[c, on_img[c]]
depth[b, c, 0, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist
extra_rots = lidar_aug_matrix[..., :3, :3]
extra_trans = lidar_aug_matrix[..., :3, 3]
geom = self.get_geometry(
camera2lidar_rots, camera2lidar_trans, intrins, post_rots,
post_trans, extra_rots=extra_rots, extra_trans=extra_trans,
)
# use points depth to assist the depth prediction in images
x = self.get_cam_feats(img, depth)
x = self.bev_pool(geom, x)
x = self.downsample(x)
# convert bev features from (b, c, x, y) to (b, c, y, x)
x = x.permute(0, 1, 3, 2)
batch_dict['spatial_features_img'] = x
return batch_dict
\ No newline at end of file
from .bev_pool import bev_pool
\ No newline at end of file
import torch
from . import bev_pool_ext
__all__ = ["bev_pool"]
class QuickCumsum(torch.autograd.Function):
@staticmethod
def forward(ctx, x, geom_feats, ranks):
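        # cumsum trick: running-sum over all points, keep only the last entry of
        # each rank (voxel), then difference adjacent kept entries to recover
        # per-voxel sums without an explicit scatter-add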
x = x.cumsum(0)
kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
kept[:-1] = ranks[1:] != ranks[:-1]
x, geom_feats = x[kept], geom_feats[kept]
x = torch.cat((x[:1], x[1:] - x[:-1]))
# save kept for backward
ctx.save_for_backward(kept)
# no gradient for geom_feats
ctx.mark_non_differentiable(geom_feats)
return x, geom_feats
@staticmethod
def backward(ctx, gradx, gradgeom):
(kept,) = ctx.saved_tensors
back = torch.cumsum(kept, 0)
back[kept] -= 1
val = gradx[back]
return val, None, None
class QuickCumsumCuda(torch.autograd.Function):
@staticmethod
def forward(ctx, x, geom_feats, ranks, B, D, H, W):
kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
kept[1:] = ranks[1:] != ranks[:-1]
interval_starts = torch.where(kept)[0].int()
interval_lengths = torch.zeros_like(interval_starts)
interval_lengths[:-1] = interval_starts[1:] - interval_starts[:-1]
interval_lengths[-1] = x.shape[0] - interval_starts[-1]
geom_feats = geom_feats.int()
out = bev_pool_ext.bev_pool_forward(
x,
geom_feats,
interval_lengths,
interval_starts,
B,
D,
H,
W,
)
ctx.save_for_backward(interval_starts, interval_lengths, geom_feats)
ctx.saved_shapes = B, D, H, W
return out
@staticmethod
def backward(ctx, out_grad):
interval_starts, interval_lengths, geom_feats = ctx.saved_tensors
B, D, H, W = ctx.saved_shapes
out_grad = out_grad.contiguous()
x_grad = bev_pool_ext.bev_pool_backward(
out_grad,
geom_feats,
interval_lengths,
interval_starts,
B,
D,
H,
W,
)
return x_grad, None, None, None, None, None, None
def bev_pool(feats, coords, B, D, H, W):
assert feats.shape[0] == coords.shape[0]
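    # encode (x, y, z, batch) as a single scalar rank per point so that sorting
    # places all points falling into the same BEV cell in one contiguous run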
ranks = (
coords[:, 0] * (W * D * B)
+ coords[:, 1] * (D * B)
+ coords[:, 2] * B
+ coords[:, 3]
)
indices = ranks.argsort()
feats, coords, ranks = feats[indices], coords[indices], ranks[indices]
x = QuickCumsumCuda.apply(feats, coords, ranks, B, D, H, W)
x = x.permute(0, 4, 1, 2, 3).contiguous()
return x
#include <torch/torch.h>
#include <c10/cuda/CUDAGuard.h>
// CUDA function declarations
void bev_pool(int b, int d, int h, int w, int n, int c, int n_intervals, const float* x,
const int* geom_feats, const int* interval_starts, const int* interval_lengths, float* out);
void bev_pool_grad(int b, int d, int h, int w, int n, int c, int n_intervals, const float* out_grad,
const int* geom_feats, const int* interval_starts, const int* interval_lengths, float* x_grad);
/*
Function: pillar pooling (forward, cuda)
Args:
x : input features, FloatTensor[n, c]
geom_feats : input coordinates, IntTensor[n, 4]
    interval_lengths : number of points in each pooled interval, IntTensor[n_intervals]
    interval_starts : starting index of each pooled interval, IntTensor[n_intervals]
Return:
out : output features, FloatTensor[b, d, h, w, c]
*/
at::Tensor bev_pool_forward(
const at::Tensor _x,
const at::Tensor _geom_feats,
const at::Tensor _interval_lengths,
const at::Tensor _interval_starts,
int b, int d, int h, int w
) {
int n = _x.size(0);
int c = _x.size(1);
int n_intervals = _interval_lengths.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_x));
const float* x = _x.data_ptr<float>();
const int* geom_feats = _geom_feats.data_ptr<int>();
const int* interval_lengths = _interval_lengths.data_ptr<int>();
const int* interval_starts = _interval_starts.data_ptr<int>();
auto options =
torch::TensorOptions().dtype(_x.dtype()).device(_x.device());
at::Tensor _out = torch::zeros({b, d, h, w, c}, options);
float* out = _out.data_ptr<float>();
bev_pool(
b, d, h, w, n, c, n_intervals, x,
geom_feats, interval_starts, interval_lengths, out
);
return _out;
}
/*
Function: pillar pooling (backward, cuda)
Args:
    out_grad : gradient of the output features, FloatTensor[b, d, h, w, c]
    geom_feats : input coordinates, IntTensor[n, 4]
    interval_lengths : number of points in each pooled interval, IntTensor[n_intervals]
    interval_starts : starting index of each pooled interval, IntTensor[n_intervals]
Return:
    x_grad : gradient of the input features, FloatTensor[n, c]
*/
at::Tensor bev_pool_backward(
const at::Tensor _out_grad,
const at::Tensor _geom_feats,
const at::Tensor _interval_lengths,
const at::Tensor _interval_starts,
int b, int d, int h, int w
) {
int n = _geom_feats.size(0);
int c = _out_grad.size(4);
int n_intervals = _interval_lengths.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_out_grad));
const float* out_grad = _out_grad.data_ptr<float>();
const int* geom_feats = _geom_feats.data_ptr<int>();
const int* interval_lengths = _interval_lengths.data_ptr<int>();
const int* interval_starts = _interval_starts.data_ptr<int>();
auto options =
torch::TensorOptions().dtype(_out_grad.dtype()).device(_out_grad.device());
at::Tensor _x_grad = torch::zeros({n, c}, options);
float* x_grad = _x_grad.data_ptr<float>();
bev_pool_grad(
b, d, h, w, n, c, n_intervals, out_grad,
geom_feats, interval_starts, interval_lengths, x_grad
);
return _x_grad;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("bev_pool_forward", &bev_pool_forward,
"bev_pool_forward");
m.def("bev_pool_backward", &bev_pool_backward,
"bev_pool_backward");
}
#include <stdio.h>
#include <stdlib.h>
/*
Function: pillar pooling
Args:
b : batch size
d : depth of the feature map
h : height of pooled feature map
w : width of pooled feature map
n : number of input points
c : number of channels
    n_intervals : number of unique voxels (pooling intervals)
    x : input features, FloatTensor[n, c]
    geom_feats : input coordinates, IntTensor[n, 4]
    interval_lengths : number of points in each pooled interval, IntTensor[n_intervals]
    interval_starts : starting index of each pooled interval, IntTensor[n_intervals]
out : output features, FloatTensor[b, d, h, w, c]
*/
__global__ void bev_pool_kernel(int b, int d, int h, int w, int n, int c, int n_intervals,
const float *__restrict__ x,
const int *__restrict__ geom_feats,
const int *__restrict__ interval_starts,
const int *__restrict__ interval_lengths,
float* __restrict__ out) {
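  // one thread per (interval, channel) pair: each thread sums the feature values
  // of all points in its interval into a single output voxel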
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int index = idx / c;
int cur_c = idx % c;
if (index >= n_intervals) return;
int interval_start = interval_starts[index];
int interval_length = interval_lengths[index];
const int* cur_geom_feats = geom_feats + interval_start * 4;
const float* cur_x = x + interval_start * c + cur_c;
float* cur_out = out + cur_geom_feats[3] * d * h * w * c +
cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c +
cur_geom_feats[1] * c + cur_c;
float psum = 0;
for(int i = 0; i < interval_length; i++){
psum += cur_x[i * c];
}
*cur_out = psum;
}
/*
Function: pillar pooling backward
Args:
b : batch size
d : depth of the feature map
h : height of pooled feature map
w : width of pooled feature map
n : number of input points
c : number of channels
    n_intervals : number of unique voxels (pooling intervals)
    out_grad : gradient of the BEV fmap from top, FloatTensor[b, d, h, w, c]
    geom_feats : input coordinates, IntTensor[n, 4]
    interval_lengths : number of points in each pooled interval, IntTensor[n_intervals]
    interval_starts : starting index of each pooled interval, IntTensor[n_intervals]
    x_grad : gradient of the image fmap, FloatTensor[n, c]
*/
__global__ void bev_pool_grad_kernel(int b, int d, int h, int w, int n, int c, int n_intervals,
const float *__restrict__ out_grad,
const int *__restrict__ geom_feats,
const int *__restrict__ interval_starts,
const int *__restrict__ interval_lengths,
float* __restrict__ x_grad) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int index = idx / c;
int cur_c = idx % c;
if (index >= n_intervals) return;
int interval_start = interval_starts[index];
int interval_length = interval_lengths[index];
const int* cur_geom_feats = geom_feats + interval_start * 4;
float* cur_x_grad = x_grad + interval_start * c + cur_c;
const float* cur_out_grad = out_grad + cur_geom_feats[3] * d * h * w * c +
cur_geom_feats[2] * h * w * c + cur_geom_feats[0] * w * c +
cur_geom_feats[1] * c + cur_c;
for(int i = 0; i < interval_length; i++){
cur_x_grad[i * c] = *cur_out_grad;
}
}
void bev_pool(int b, int d, int h, int w, int n, int c, int n_intervals, const float* x,
const int* geom_feats, const int* interval_starts, const int* interval_lengths, float* out) {
bev_pool_kernel<<<(int)ceil(((double)n_intervals * c / 256)), 256>>>(
b, d, h, w, n, c, n_intervals, x, geom_feats, interval_starts, interval_lengths, out
);
}
void bev_pool_grad(int b, int d, int h, int w, int n, int c, int n_intervals, const float* out_grad,
const int* geom_feats, const int* interval_starts, const int* interval_lengths, float* x_grad) {
bev_pool_grad_kernel<<<(int)ceil(((double)n_intervals * c / 256)), 256>>>(
b, d, h, w, n, c, n_intervals, out_grad, geom_feats, interval_starts, interval_lengths, x_grad
);
}
......@@ -57,6 +57,24 @@ def rotate_points_along_z(points, angle):
return points_rot.numpy() if is_numpy else points_rot
def angle2matrix(angle):
"""
Args:
angle: angle along z-axis, angle increases x ==> y
Returns:
rot_matrix: (3x3 Tensor) rotation matrix
"""
cosa = torch.cos(angle)
sina = torch.sin(angle)
rot_matrix = torch.tensor([
[cosa, -sina, 0],
[sina, cosa, 0],
[ 0, 0, 1]
])
return rot_matrix
def mask_points_by_range(points, limit_range):
mask = (points[:, 0] >= limit_range[0]) & (points[:, 0] <= limit_range[3]) \
& (points[:, 1] >= limit_range[1]) & (points[:, 1] <= limit_range[4])
......
......@@ -560,4 +560,49 @@ class IouRegLossSparse(nn.Module):
loss += (1. - iou).sum()
loss = loss / (mask.sum() + 1e-4)
return loss
class L1Loss(nn.Module):
def __init__(self):
super(L1Loss, self).__init__()
def forward(self, pred, target):
if target.numel() == 0:
            return pred.sum() * 0  # zero-valued loss that keeps the autograd graph connected
assert pred.size() == target.size()
loss = torch.abs(pred - target)
return loss
class GaussianFocalLoss(nn.Module):
"""GaussianFocalLoss is a variant of focal loss.
More details can be found in the `paper
<https://arxiv.org/abs/1808.01244>`_
Code is modified from `kp_utils.py
<https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_ # noqa: E501
Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
not 0/1 binary target.
Args:
alpha (float): Power of prediction.
gamma (float): Power of target for negative samples.
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Loss weight of current loss.
"""
def __init__(self,
alpha=2.0,
gamma=4.0):
super(GaussianFocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, pred, target):
eps = 1e-12
pos_weights = target.eq(1)
neg_weights = (1 - target).pow(self.gamma)
pos_loss = -(pred + eps).log() * (1 - pred).pow(self.alpha) * pos_weights
neg_loss = -(1 - pred + eps).log() * pred.pow(self.alpha) * neg_weights
return pos_loss + neg_loss
\ No newline at end of file
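
# In equation form (my reading of the forward pass above, with prediction p and
# Gaussian heatmap target t):
#     L(p, t) = -1[t == 1] * (1 - p)^alpha * log(p + eps)
#               - (1 - t)^gamma * p^alpha * log(1 - p + eps)
# i.e. the penalty-reduced pixel-wise focal loss of CornerNet, returned
# per-element without any reduction.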
......@@ -10,3 +10,4 @@ tqdm
torchvision
SharedArray
opencv-python
pyquaternion
\ No newline at end of file
......@@ -117,5 +117,13 @@ if __name__ == '__main__':
],
),
make_cuda_ext(
name="bev_pool_ext",
module="pcdet.ops.bev_pool",
sources=[
"src/bev_pool.cpp",
"src/bev_pool_cuda.cu",
],
),
],
)
CLASS_NAMES: ['car','truck', 'construction_vehicle', 'bus', 'trailer',
'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/nuscenes_dataset.yaml
POINT_CLOUD_RANGE: [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
CAMERA_CONFIG:
USE_CAMERA: True
IMAGE:
FINAL_DIM: [256,704]
RESIZE_LIM_TRAIN: [0.38, 0.55]
RESIZE_LIM_TEST: [0.48, 0.48]
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
- NAME: random_world_flip
ALONG_AXIS_LIST: ['x', 'y']
- NAME: random_world_rotation
WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
- NAME: random_world_scaling
WORLD_SCALE_RANGE: [0.9, 1.1]
- NAME: random_world_translation
NOISE_TRANSLATE_STD: [0.5, 0.5, 0.5]
- NAME: imgaug
ROT_LIM: [-5.4, 5.4]
RAND_FLIP: True
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': True
}
- NAME: transform_points_to_voxels
VOXEL_SIZE: [0.075, 0.075, 0.2]
MAX_POINTS_PER_VOXEL: 10
MAX_NUMBER_OF_VOXELS: {
'train': 120000,
'test': 160000
}
- NAME: image_calibrate
- NAME: image_normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
MODEL:
NAME: BevFusion
VFE:
NAME: MeanVFE
BACKBONE_3D:
NAME: VoxelResBackBone8x
USE_BIAS: False
MAP_TO_BEV:
NAME: HeightCompression
NUM_BEV_FEATURES: 256
IMAGE_BACKBONE:
NAME: SwinTransformer
EMBED_DIMS: 96
DEPTHS: [2, 2, 6, 2]
NUM_HEADS: [3, 6, 12, 24]
WINDOW_SIZE: 7
MLP_RATIO: 4
DROP_RATE: 0.
ATTN_DROP_RATE: 0.
DROP_PATH_RATE: 0.2
PATCH_NORM: True
OUT_INDICES: [1, 2, 3]
WITH_CP: False
CONVERT_WEIGHTS: True
INIT_CFG:
type: Pretrained
checkpoint: swint-nuimages-pretrained.pth
NECK:
NAME: GeneralizedLSSFPN
IN_CHANNELS: [192, 384, 768]
OUT_CHANNELS: 256
START_LEVEL: 0
END_LEVEL: -1
NUM_OUTS: 3
VTRANSFORM:
NAME: DepthLSSTransform
IMAGE_SIZE: [256, 704]
IN_CHANNEL: 256
OUT_CHANNEL: 80
FEATURE_SIZE: [32, 88]
XBOUND: [-54.0, 54.0, 0.3]
YBOUND: [-54.0, 54.0, 0.3]
ZBOUND: [-10.0, 10.0, 20.0]
DBOUND: [1.0, 60.0, 0.5]
DOWNSAMPLE: 2
FUSER:
NAME: ConvFuser
IN_CHANNEL: 336
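        # 336 = 256 LiDAR BEV features (MAP_TO_BEV) + 80 camera BEV channels (VTRANSFORM OUT_CHANNEL)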
OUT_CHANNEL: 256
BACKBONE_2D:
NAME: BaseBEVBackbone
LAYER_NUMS: [5, 5]
LAYER_STRIDES: [1, 2]
NUM_FILTERS: [128, 256]
UPSAMPLE_STRIDES: [1, 2]
NUM_UPSAMPLE_FILTERS: [256, 256]
USE_CONV_FOR_NO_STRIDE: True
DENSE_HEAD:
CLASS_AGNOSTIC: False
NAME: TransFusionHead
USE_BIAS_BEFORE_NORM: False
NUM_PROPOSALS: 200
HIDDEN_CHANNEL: 128
NUM_CLASSES: 10
NUM_HEADS: 8
NMS_KERNEL_SIZE: 3
FFN_CHANNEL: 256
DROPOUT: 0.1
BN_MOMENTUM: 0.1
ACTIVATION: relu
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: ['center', 'height', 'dim', 'rot', 'vel']
HEAD_DICT: {
'center': {'out_channels': 2, 'num_conv': 2},
'height': {'out_channels': 1, 'num_conv': 2},
'dim': {'out_channels': 3, 'num_conv': 2},
'rot': {'out_channels': 2, 'num_conv': 2},
'vel': {'out_channels': 2, 'num_conv': 2},
}
TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 8
DATASET: nuScenes
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2
HUNGARIAN_ASSIGNER:
cls_cost: {'gamma': 2.0, 'alpha': 0.25, 'weight': 0.15}
reg_cost: {'weight': 0.25}
iou_cost: {'weight': 0.25}
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'bbox_weight': 0.25,
'hm_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
LOSS_CLS:
use_sigmoid: True
gamma: 2.0
alpha: 0.25
POST_PROCESSING:
SCORE_THRESH: 0.0
POST_CENTER_RANGE: [-61.2, -61.2, -10.0, 61.2, 61.2, 10.0]
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
SCORE_THRESH: 0.1
OUTPUT_RAW_SCORE: False
EVAL_METRIC: kitti
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 3
NUM_EPOCHS: 6
OPTIMIZER: adam_cosineanneal
LR: 0.0001
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
BETAS: [0.9, 0.999]
MOMS: [0.9, 0.8052631]
PCT_START: 0.4
WARMUP_ITER: 500
DECAY_STEP_LIST: [35, 45]
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 35
LOSS_SCALE_FP16: 32
\ No newline at end of file
CLASS_NAMES: ['car','truck', 'construction_vehicle', 'bus', 'trailer',
'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/nuscenes_dataset.yaml
POINT_CLOUD_RANGE: [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
- NAME: gt_sampling
DB_INFO_PATH:
- nuscenes_dbinfos_10sweeps_withvelo.pkl
PREPARE: {
filter_by_min_points: [
'car:5','truck:5', 'construction_vehicle:5', 'bus:5', 'trailer:5',
'barrier:5', 'motorcycle:5', 'bicycle:5', 'pedestrian:5', 'traffic_cone:5'
],
}
SAMPLE_GROUPS: [
'car:2','truck:3', 'construction_vehicle:7', 'bus:4', 'trailer:6',
'barrier:2', 'motorcycle:6', 'bicycle:6', 'pedestrian:2', 'traffic_cone:2'
]
NUM_POINT_FEATURES: 5
DATABASE_WITH_FAKELIDAR: False
REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
LIMIT_WHOLE_SCENE: True
- NAME: random_world_flip
ALONG_AXIS_LIST: ['x', 'y']
- NAME: random_world_rotation
WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
- NAME: random_world_scaling
WORLD_SCALE_RANGE: [0.9, 1.1]
- NAME: random_world_translation
NOISE_TRANSLATE_STD: [0.5, 0.5, 0.5]
DATA_PROCESSOR:
- NAME: mask_points_and_boxes_outside_range
REMOVE_OUTSIDE_BOXES: True
- NAME: shuffle_points
SHUFFLE_ENABLED: {
'train': True,
'test': True
}
- NAME: transform_points_to_voxels
VOXEL_SIZE: [0.075, 0.075, 0.2]
MAX_POINTS_PER_VOXEL: 10
MAX_NUMBER_OF_VOXELS: {
'train': 120000,
'test': 160000
}
MODEL:
NAME: TransFusion
VFE:
NAME: MeanVFE
BACKBONE_3D:
NAME: VoxelResBackBone8x
USE_BIAS: False
MAP_TO_BEV:
NAME: HeightCompression
NUM_BEV_FEATURES: 256
BACKBONE_2D:
NAME: BaseBEVBackbone
LAYER_NUMS: [5, 5]
LAYER_STRIDES: [1, 2]
NUM_FILTERS: [128, 256]
UPSAMPLE_STRIDES: [1, 2]
NUM_UPSAMPLE_FILTERS: [256, 256]
USE_CONV_FOR_NO_STRIDE: True
DENSE_HEAD:
CLASS_AGNOSTIC: False
NAME: TransFusionHead
USE_BIAS_BEFORE_NORM: False
NUM_PROPOSALS: 200
HIDDEN_CHANNEL: 128
NUM_CLASSES: 10
NUM_HEADS: 8
NMS_KERNEL_SIZE: 3
FFN_CHANNEL: 256
DROPOUT: 0.1
BN_MOMENTUM: 0.1
ACTIVATION: relu
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: ['center', 'height', 'dim', 'rot', 'vel']
HEAD_DICT: {
'center': {'out_channels': 2, 'num_conv': 2},
'height': {'out_channels': 1, 'num_conv': 2},
'dim': {'out_channels': 3, 'num_conv': 2},
'rot': {'out_channels': 2, 'num_conv': 2},
'vel': {'out_channels': 2, 'num_conv': 2},
}
TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 8
DATASET: nuScenes
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2
HUNGARIAN_ASSIGNER:
cls_cost: {'gamma': 2.0, 'alpha': 0.25, 'weight': 0.15}
reg_cost: {'weight': 0.25}
iou_cost: {'weight': 0.25}
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'bbox_weight': 0.25,
'hm_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
}
LOSS_CLS:
use_sigmoid: True
gamma: 2.0
alpha: 0.25
POST_PROCESSING:
SCORE_THRESH: 0.0
POST_CENTER_RANGE: [-61.2, -61.2, -10.0, 61.2, 61.2, 10.0]
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
SCORE_THRESH: 0.1
OUTPUT_RAW_SCORE: False
EVAL_METRIC: kitti
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 4
NUM_EPOCHS: 20
OPTIMIZER: adam_onecycle
LR: 0.001
WEIGHT_DECAY: 0.01
MOMENTUM: 0.9
BETAS: [0.9, 0.999]
MOMS: [0.9, 0.8052631]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 35
HOOK:
DisableAugmentationHook:
DISABLE_AUG_LIST: ['gt_sampling']
NUM_LAST_EPOCHS: 5
\ No newline at end of file
......@@ -195,7 +195,8 @@ def main():
ckpt_save_time_interval=args.ckpt_save_time_interval,
use_logger_to_record=not args.use_tqdm_to_record,
show_gpu_stat=not args.wo_gpu_stat,
use_amp=args.use_amp
use_amp=args.use_amp,
cfg=cfg
)
if hasattr(train_set, 'use_shared_memory') and train_set.use_shared_memory:
......
......@@ -5,7 +5,7 @@ import torch.optim as optim
import torch.optim.lr_scheduler as lr_sched
from .fastai_optim import OptimWrapper
from .learning_schedules_fastai import CosineWarmupLR, OneCycle
from .learning_schedules_fastai import CosineWarmupLR, OneCycle, CosineAnnealing
def build_optimizer(model, optim_cfg):
......@@ -16,7 +16,7 @@ def build_optimizer(model, optim_cfg):
model.parameters(), lr=optim_cfg.LR, weight_decay=optim_cfg.WEIGHT_DECAY,
momentum=optim_cfg.MOMENTUM
)
elif optim_cfg.OPTIMIZER == 'adam_onecycle':
elif optim_cfg.OPTIMIZER in ['adam_onecycle','adam_cosineanneal']:
def children(m: nn.Module):
return list(m.children())
......@@ -25,8 +25,9 @@ def build_optimizer(model, optim_cfg):
flatten_model = lambda m: sum(map(flatten_model, m.children()), []) if num_children(m) else [m]
get_layer_groups = lambda m: [nn.Sequential(*flatten_model(m))]
optimizer_func = partial(optim.Adam, betas=(0.9, 0.99))
betas = optim_cfg.get('BETAS', (0.9, 0.99))
betas = tuple(betas)
optimizer_func = partial(optim.Adam, betas=betas)
optimizer = OptimWrapper.create(
optimizer_func, 3e-3, get_layer_groups(model), wd=optim_cfg.WEIGHT_DECAY, true_wd=True, bn_wd=True
)
......@@ -51,6 +52,10 @@ def build_scheduler(optimizer, total_iters_each_epoch, total_epochs, last_epoch,
lr_scheduler = OneCycle(
optimizer, total_steps, optim_cfg.LR, list(optim_cfg.MOMS), optim_cfg.DIV_FACTOR, optim_cfg.PCT_START
)
elif optim_cfg.OPTIMIZER == 'adam_cosineanneal':
lr_scheduler = CosineAnnealing(
optimizer, total_steps, total_epochs, optim_cfg.LR, list(optim_cfg.MOMS), optim_cfg.PCT_START, optim_cfg.WARMUP_ITER
)
else:
lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd, last_epoch=last_epoch)
......
......@@ -41,7 +41,7 @@ class LRSchedulerStep(object):
self.mom_phases.append((int(start * total_step), total_step, lambda_func))
assert self.mom_phases[0][0] == 0
def step(self, step):
def step(self, step, epoch=None):
for start, end, func in self.lr_phases:
if step >= start:
self.optimizer.lr = func((step - start) / (end - start))
......@@ -83,12 +83,60 @@ class CosineWarmupLR(lr_sched._LRScheduler):
self.eta_min = eta_min
super(CosineWarmupLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
def get_lr(self, epoch=None):
return [self.eta_min + (base_lr - self.eta_min) *
(1 - math.cos(math.pi * self.last_epoch / self.T_max)) / 2
for base_lr in self.base_lrs]
def linear_warmup(end, lr_max, pct):
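    # k decays linearly from 2/3 (pct = 0) to 0 (pct = end), so the returned lr
    # ramps from lr_max / 3 up to lr_max over the warmup interval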
k = (1 - pct / end) * (1 - 0.33333333)
warmup_lr = lr_max * (1 - k)
return warmup_lr
class CosineAnnealing(LRSchedulerStep):
def __init__(self, fai_optimizer, total_step, total_epoch, lr_max, moms, pct_start, warmup_iter):
self.lr_max = lr_max
self.moms = moms
self.pct_start = pct_start
mom_phases = ((0, partial(annealing_cos, *self.moms)),
(self.pct_start, partial(annealing_cos,
*self.moms[::-1])))
fai_optimizer.lr, fai_optimizer.mom = lr_max, self.moms[0]
self.optimizer = fai_optimizer
self.total_step = total_step
self.warmup_iter = warmup_iter
self.total_epoch = total_epoch
self.mom_phases = []
for i, (start, lambda_func) in enumerate(mom_phases):
if len(self.mom_phases) != 0:
assert self.mom_phases[-1][0] < start
if isinstance(lambda_func, str):
lambda_func = eval(lambda_func)
if i < len(mom_phases) - 1:
self.mom_phases.append((int(start * total_step), int(mom_phases[i + 1][0] * total_step), lambda_func))
else:
self.mom_phases.append((int(start * total_step), total_step, lambda_func))
assert self.mom_phases[0][0] == 0
def step(self, step, epoch):
# update lr
if step < self.warmup_iter:
self.optimizer.lr = linear_warmup(self.warmup_iter, self.lr_max, step)
else:
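            # past warmup: cosine-anneal once per epoch from lr_max down to lr_max / 1000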
target_lr = self.lr_max * 0.001
cos_lr = annealing_cos(self.lr_max, target_lr, epoch / self.total_epoch)
self.optimizer.lr = cos_lr
# update mom
for start, end, func in self.mom_phases:
if step >= start:
self.optimizer.mom = func((step - start) / (end - start))
class FakeOptim:
def __init__(self):
self.lr = 0
......