Commit 41b18fd8 authored by zhe chen

Use pre-commit to reformat code


parent ff20ea39
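
For reference, a reformat-only commit like this one is typically generated by running pre-commit across the whole tree. A minimal sketch of that invocation, assuming pre-commit is installed and a .pre-commit-config.yaml exists at the repo root (the config file itself is not shown in this commit):

    # Sketch: reproduce a whole-tree reformat with pre-commit.
    # Assumes `pip install pre-commit` and a .pre-commit-config.yaml at the repo root.
    import subprocess

    subprocess.run(['pre-commit', 'install'], check=True)  # register the git hook
    # `run --all-files` exits non-zero when hooks modify files, so no check=True here
    subprocess.run(['pre-commit', 'run', '--all-files'])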
@@ -4,15 +4,15 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import division
+from __future__ import absolute_import, division, print_function
 import warnings
 import torch
-from torch import nn
 import torch.nn.functional as F
-from torch.nn.init import xavier_uniform_, constant_
+from torch import nn
+from torch.nn.init import constant_, xavier_uniform_
 from ..functions import DCNv3Function, dcnv3_core_pytorch
@@ -72,7 +72,7 @@ def build_act_layer(act_layer):
 def _is_power_of_2(n):
     if (not isinstance(n, int)) or (n < 0):
         raise ValueError(
-            "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+            'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))
     return (n & (n - 1) == 0) and n != 0
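
The reformatted predicate relies on a standard bit trick: for a positive integer n, n & (n - 1) clears the lowest set bit, so the expression is zero exactly when n has a single bit set. A quick standalone check of that behavior:

    # Sanity check of the power-of-two test used in _is_power_of_2 above.
    def is_power_of_2(n: int) -> bool:
        return n > 0 and (n & (n - 1)) == 0

    assert [n for n in range(1, 70) if is_power_of_2(n)] == [1, 2, 4, 8, 16, 32, 64]
    assert not is_power_of_2(0)  # 0 & -1 == 0, which is why the original also guards n != 0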
@@ -124,7 +124,7 @@ class DCNv3_pytorch(nn.Module):
         if not _is_power_of_2(_d_per_group):
             warnings.warn(
                 "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
-                "which is more efficient in our CUDA implementation.")
+                'which is more efficient in our CUDA implementation.')
         self.offset_scale = offset_scale
         self.channels = channels
@@ -161,7 +161,7 @@ class DCNv3_pytorch(nn.Module):
         self.input_proj = nn.Linear(channels, channels)
         self.output_proj = nn.Linear(channels, channels)
         self._reset_parameters()
         if center_feature_scale:
             self.center_feature_scale_proj_weight = nn.Parameter(
                 torch.zeros((group, channels), dtype=torch.float))
@@ -251,7 +251,7 @@ class DCNv3(nn.Module):
         if not _is_power_of_2(_d_per_group):
             warnings.warn(
                 "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
-                "which is more efficient in our CUDA implementation.")
+                'which is more efficient in our CUDA implementation.')
         self.offset_scale = offset_scale
         self.channels = channels
@@ -264,7 +264,7 @@ class DCNv3(nn.Module):
         self.group_channels = channels // group
         self.offset_scale = offset_scale
         self.center_feature_scale = center_feature_scale
         self.dw_conv = nn.Sequential(
             nn.Conv2d(
                 channels,
@@ -288,7 +288,7 @@ class DCNv3(nn.Module):
         self.input_proj = nn.Linear(channels, channels)
         self.output_proj = nn.Linear(channels, channels)
         self._reset_parameters()
         if center_feature_scale:
             self.center_feature_scale_proj_weight = nn.Parameter(
                 torch.zeros((group, channels), dtype=torch.float))
@@ -332,7 +332,7 @@ class DCNv3(nn.Module):
             self.group, self.group_channels,
             self.offset_scale,
             256)
         if self.center_feature_scale:
             center_feature_scale = self.center_feature_scale_module(
                 x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
...
@@ -4,39 +4,34 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-import os
 import glob
+import os
 import torch
-from torch.utils.cpp_extension import CUDA_HOME
-from torch.utils.cpp_extension import CppExtension
-from torch.utils.cpp_extension import CUDAExtension
-from setuptools import find_packages
-from setuptools import setup
-requirements = ["torch", "torchvision"]
+from setuptools import find_packages, setup
+from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
+requirements = ['torch', 'torchvision']
 def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
-    extensions_dir = os.path.join(this_dir, "src")
+    extensions_dir = os.path.join(this_dir, 'src')
-    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
-    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
-    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
+    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
+    source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
     sources = main_file + source_cpu
     extension = CppExtension
-    extra_compile_args = {"cxx": []}
+    extra_compile_args = {'cxx': []}
     define_macros = []
     if torch.cuda.is_available() and CUDA_HOME is not None:
         extension = CUDAExtension
         sources += source_cuda
-        define_macros += [("WITH_CUDA", None)]
-        extra_compile_args["nvcc"] = [
+        define_macros += [('WITH_CUDA', None)]
+        extra_compile_args['nvcc'] = [
             # "-DCUDA_HAS_FP16=1",
             # "-D__CUDA_NO_HALF_OPERATORS__",
             # "-D__CUDA_NO_HALF_CONVERSIONS__",
@@ -49,7 +44,7 @@ def get_extensions():
     include_dirs = [extensions_dir]
     ext_modules = [
         extension(
-            "DCNv3",
+            'DCNv3',
             sources,
             include_dirs=include_dirs,
             define_macros=define_macros,
@@ -60,16 +55,16 @@ def get_extensions():
 setup(
-    name="DCNv3",
-    version="1.0",
-    author="InternImage",
-    url="https://github.com/OpenGVLab/InternImage",
+    name='DCNv3',
+    version='1.0',
+    author='InternImage',
+    url='https://github.com/OpenGVLab/InternImage',
     description=
-    "PyTorch Wrapper for CUDA Functions of DCNv3",
+    'PyTorch Wrapper for CUDA Functions of DCNv3',
     packages=find_packages(exclude=(
-        "configs",
-        "tests",
+        'configs',
+        'tests',
     )),
     ext_modules=get_extensions(),
-    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
 )
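
The setup script above switches between a CPU-only and a CUDA build depending on the environment; a minimal standalone sketch of just that probe (not the full build):

    # Sketch of the CPU/CUDA selection probe used by get_extensions() above;
    # it only inspects the local environment and prints the chosen backend.
    import torch
    from torch.utils.cpp_extension import CUDA_HOME

    if torch.cuda.is_available() and CUDA_HOME is not None:
        print('CUDA toolchain found at', CUDA_HOME, '-> CUDAExtension')
    else:
        print('no usable CUDA toolchain -> CppExtension (CPU only)')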
@@ -171,4 +171,4 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
     } else {
         return {grad_input, grad_offset, grad_mask};
     }
-}
\ No newline at end of file
+}
@@ -1042,4 +1042,4 @@ void dcnv3_col2im_cuda(
     if (err != cudaSuccess) {
         printf("error in dcnv3_col2im_cuda: %s\n", cudaGetErrorString(err));
     }
-}
\ No newline at end of file
+}
@@ -4,16 +4,11 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import division
+from __future__ import absolute_import, division, print_function
 import time
-import torch
-import torch.nn as nn
-import math
-from torch.autograd import gradcheck
+import torch
 from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
 H_in, W_in = 8, 8
@@ -32,11 +27,11 @@ torch.manual_seed(3)
 @torch.no_grad()
 def check_forward_equal_with_pytorch_double():
-    input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
-    offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
+    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
+    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
     mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
     mask /= mask.sum(-1, keepdim=True)
-    mask = mask.reshape(N, H_out, W_out, M*P)
+    mask = mask.reshape(N, H_out, W_out, M * P)
     output_pytorch = dcnv3_core_pytorch(
         input.double(),
@@ -57,16 +52,17 @@ def check_forward_equal_with_pytorch_double():
     max_rel_err = ((output_cuda - output_pytorch).abs() /
                    output_pytorch.abs()).max()
     print('>>> forward double')
-    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+    print(
+        f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
 @torch.no_grad()
 def check_forward_equal_with_pytorch_float():
-    input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
-    offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
+    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
+    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
     mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
     mask /= mask.sum(-1, keepdim=True)
-    mask = mask.reshape(N, H_out, W_out, M*P)
+    mask = mask.reshape(N, H_out, W_out, M * P)
     output_pytorch = dcnv3_core_pytorch(
         input,
@@ -87,7 +83,8 @@ def check_forward_equal_with_pytorch_float():
     max_rel_err = ((output_cuda - output_pytorch).abs() /
                    output_pytorch.abs()).max()
     print('>>> forward float')
-    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+    print(
+        f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
 def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
@@ -98,11 +95,11 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o
     W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
     D = channels
-    input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
-    offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
+    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
+    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
     mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
     mask0 /= mask0.sum(-1, keepdim=True)
-    mask0 = mask0.reshape(N, H_out, W_out, M*P)
+    mask0 = mask0.reshape(N, H_out, W_out, M * P)
     input0.requires_grad = grad_input
     offset0.requires_grad = grad_offset
     mask0.requires_grad = grad_mask
@@ -161,11 +158,11 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of
     W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
     D = channels
-    input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
-    offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
+    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
+    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
     mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
     mask0 /= mask0.sum(-1, keepdim=True)
-    mask0 = mask0.reshape(N, H_out, W_out, M*P)
+    mask0 = mask0.reshape(N, H_out, W_out, M * P)
     input0.requires_grad = grad_input
     offset0.requires_grad = grad_offset
     mask0.requires_grad = grad_mask
@@ -223,11 +220,11 @@ def check_time_cost(im2col_step=128):
     H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
     W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
-    input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
-    offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
+    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
+    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
     mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
     mask /= mask.sum(-1, keepdim=True)
-    mask = mask.reshape(N, H_out, W_out, M*P)
+    mask = mask.reshape(N, H_out, W_out, M * P)
     print(
         f'>>> time cost: im2col_step {im2col_step}; input {input.shape}; points {P} ')
     repeat = 100
...
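
The check_time_cost benchmark above times the CUDA kernel; because CUDA launches are asynchronous, a harness like this has to synchronize around the timed region or the wall-clock numbers mean little. A generic sketch of that pattern (illustrative matmul, not the repo's exact harness):

    # Generic CUDA timing pattern: synchronize before and after the loop so
    # time.time() brackets real GPU work rather than just kernel launches.
    import time

    import torch

    x = torch.rand(1024, 1024, device='cuda')
    torch.cuda.synchronize()
    start = time.time()
    repeat = 100
    for _ in range(repeat):
        y = x @ x
    torch.cuda.synchronize()
    print(f'avg per call: {(time.time() - start) / repeat * 1e3:.3f} ms')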
 # ==============================================================================
 # Binaries and/or source for the following packages or projects
 # are presented under one or more of the following open source licenses:
 # baseline.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
 #
@@ -21,8 +21,7 @@
 # ==============================================================================
 import torch
-from mmdet3d.models import DETECTORS, build_neck, build_head
+from mmdet3d.models import DETECTORS, build_head, build_neck
 from mmdet3d.models.detectors import MVXTwoStageDetector
@@ -59,13 +58,14 @@ class Baseline(MVXTwoStageDetector):
         if type(x) in [list, tuple]:
             x = x[0]
         # view transformation
         bev_feat = self.img_view_transformer(
             x,
             torch.cat([
-                torch.cat([torch.tensor(l, device=x.device, dtype=torch.float32).unsqueeze(0) for l in img_metas[b]['lidar2img']], dim=0).unsqueeze(0)
-                for b in range(B)], dim=0),
+                torch.cat([torch.tensor(l, device=x.device, dtype=torch.float32).unsqueeze(0) for l in
+                           img_metas[b]['lidar2img']], dim=0).unsqueeze(0)
+                for b in range(B)], dim=0),
             (img_metas[0]['img_shape'][0][0], img_metas[0]['img_shape'][0][1]),
         )
         _, output_dim, ouput_H, output_W = x.shape
@@ -76,10 +76,10 @@ class Baseline(MVXTwoStageDetector):
         lc_img_metas = [{
             'batch_input_shape': (bev_feat.shape[-2], bev_feat.shape[-1]),
             'img_shape': (bev_feat.shape[-2], bev_feat.shape[-1], None),
-            'scale_factor': None, # dummy
+            'scale_factor': None,  # dummy
         } for _ in range(B)]
         all_lc_cls_scores_list, all_lc_preds_list, lc_outs_dec_list = self.lc_head(
             [bev_feat],
             lc_img_metas,
         )
@@ -91,7 +91,7 @@ class Baseline(MVXTwoStageDetector):
             'scale_factor': img_metas[b]['scale_factor'],
         } for b in range(B)]
         all_te_cls_scores_list, all_te_preds_list, te_outs_dec_list = self.te_head(
             [pv_feat],
             te_img_metas,
         )
@@ -149,7 +149,7 @@ class Baseline(MVXTwoStageDetector):
         })
         # te
         te_loss_dict, te_assign_results = self.te_head.loss(
             outs['all_te_cls_scores_list'],
             outs['all_te_preds_list'],
@@ -186,20 +186,20 @@ class Baseline(MVXTwoStageDetector):
         })
         return losses
     def forward_test(self, img, img_metas, **kwargs):
         outs = self.simple_forward(img, img_metas)
         pred_lc = self.lc_head.get_bboxes(
             outs['all_lc_cls_scores_list'],
             outs['all_lc_preds_list'],
             outs['lc_img_metas'],
         )
         pred_te = self.te_head.get_bboxes(
             outs['all_te_cls_scores_list'],
             outs['all_te_preds_list'],
             outs['te_img_metas'],
             rescale=True,
         )
...
@@ -3,17 +3,13 @@
 # ---------------------------------------------
 # Modified by Tianyu Li
 # ---------------------------------------------
-import time
-import copy
-import numpy as np
 import torch
-from mmcv.runner import force_fp32, auto_fp16
-from mmdet.core import bbox2result
-from mmdet.models import DETECTORS
-from mmdet.models.builder import build_head
+from mmcv.runner import auto_fp16
 from mmdet3d.models.builder import build_neck
 from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
+from mmdet.models import DETECTORS
+from mmdet.models.builder import build_head
 @DETECTORS.register_module()
@@ -79,7 +75,6 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
             'prev_angle': 0,
         }
     def extract_img_feat(self, img, img_metas, len_queue=None):
         """Extract features of images."""
         B = img.size(0)
@@ -108,7 +103,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
         for img_feat in img_feats:
             BN, C, H, W = img_feat.size()
             if len_queue is not None:
-                img_feats_reshaped.append(img_feat.view(int(B/len_queue), len_queue, int(BN / B), C, H, W))
+                img_feats_reshaped.append(img_feat.view(int(B / len_queue), len_queue, int(BN / B), C, H, W))
             else:
                 img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
         return img_feats_reshaped
@@ -118,7 +113,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
         """Extract features from images and points."""
         img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
         return img_feats
     def forward_dummy(self, img):
@@ -139,7 +134,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
             return self.forward_train(**kwargs)
         else:
             return self.forward_test(**kwargs)
     def obtain_history_bev(self, imgs_queue, img_metas_list):
         """Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
         """
@@ -148,7 +143,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
         with torch.no_grad():
             prev_bev = None
             bs, len_queue, num_cams, C, H, W = imgs_queue.shape
-            imgs_queue = imgs_queue.reshape(bs*len_queue, num_cams, C, H, W)
+            imgs_queue = imgs_queue.reshape(bs * len_queue, num_cams, C, H, W)
             img_feats_list = self.extract_feat(img=imgs_queue, len_queue=len_queue)
             for i in range(len_queue):
                 img_metas = [each[i] for each in img_metas_list]
@@ -183,7 +178,8 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
             lane_feats = outs['history_states']
             if self.lclc_head is not None:
-                lclc_losses = self.lclc_head.forward_train(lane_feats, lane_assign_result, lane_feats, lane_assign_result, gt_topology_lclc)
+                lclc_losses = self.lclc_head.forward_train(lane_feats, lane_assign_result, lane_feats, lane_assign_result,
+                                                           gt_topology_lclc)
                 for loss in lclc_losses:
                     losses['lclc_head.' + loss] = lclc_losses[loss]
@@ -201,13 +197,15 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
             te_losses = {}
             bbox_outs = self.bbox_head(front_view_img_feats, bbox_img_metas)
-            bbox_losses, te_assign_result = self.bbox_head.loss(bbox_outs, gt_te, gt_te_labels, bbox_img_metas, gt_bboxes_ignore)
+            bbox_losses, te_assign_result = self.bbox_head.loss(bbox_outs, gt_te, gt_te_labels, bbox_img_metas,
+                                                                gt_bboxes_ignore)
             for loss in bbox_losses:
                 te_losses['bbox_head.' + loss] = bbox_losses[loss]
             if self.lcte_head is not None:
                 te_feats = bbox_outs['history_states']
-                lcte_losses = self.lcte_head.forward_train(lane_feats, lane_assign_result, te_feats, te_assign_result, gt_topology_lcte)
+                lcte_losses = self.lcte_head.forward_train(lane_feats, lane_assign_result, te_feats, te_assign_result,
+                                                           gt_topology_lcte)
                 for loss in lcte_losses:
                     te_losses['lcte_head.' + loss] = lcte_losses[loss]
@@ -263,7 +261,7 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
             bbox_results = self.bbox_head.get_bboxes(bbox_outs, bbox_img_metas, rescale=rescale)
         else:
             bbox_results = [None for _ in range(batchsize)]
         if self.bbox_head is not None and self.lcte_head is not None:
             te_feats = bbox_outs['history_states']
             lcte_results = self.lcte_head.get_relationship(lane_feats, te_feats)
@@ -280,7 +278,8 @@ class ROAD_BEVFormer(MVXTwoStageDetector):
         results_list = [dict() for i in range(len(img_metas))]
         new_prev_bev, bbox_results, lane_results, lclc_results, lcte_results = self.simple_test_pts(
             img_feats, img_metas, img, prev_bev, rescale=rescale)
-        for result_dict, bbox, lane, lclc, lcte in zip(results_list, bbox_results, lane_results, lclc_results, lcte_results):
+        for result_dict, bbox, lane, lclc, lcte in zip(results_list, bbox_results, lane_results, lclc_results,
+                                                       lcte_results):
             result_dict['pred_te'] = bbox
             result_dict['pred_lc'] = lane
             result_dict['pred_topology_lclc'] = lclc
...
 from .custom_detr_head import *
 from .topology_head import *
 from .lc_deformable_detr_head import LCDeformableDETRHead
 from .te_deformable_detr_head import TEDeformableDETRHead
-from .relationship_head import RelationshipHead
\ No newline at end of file
+from .relationship_head import RelationshipHead
 # ==============================================================================
 # Binaries and/or source for the following packages or projects
 # are presented under one or more of the following open source licenses:
 # custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
 #
@@ -23,17 +23,17 @@
 import torch
 import torch.nn.functional as F
 from mmcv.cnn import Linear
-from mmdet.core import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, multi_apply, reduce_mean
+from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, multi_apply,
+                        reduce_mean)
 from mmdet.models import HEADS, DETRHead
 @HEADS.register_module()
 class CustomDETRHead(DETRHead):
     def __init__(self,
                  num_classes,
                  in_channels,
                  num_query,
                  object_type,
@@ -46,7 +46,7 @@ class CustomDETRHead(DETRHead):
                  ffn_dropout=0.1,
                  **kwargs):
         self.object_type = object_type
         if self.object_type == 'lane':
             self.num_reg_dim = num_reg_dim
             assert self.num_reg_dim % 3 == 0
@@ -57,7 +57,7 @@ class CustomDETRHead(DETRHead):
         else:
             raise NotImplementedError
-        transformer=dict(
+        transformer = dict(
             type='Transformer',
             encoder=dict(
                 type='DetrTransformerEncoder',
@@ -106,13 +106,13 @@ class CustomDETRHead(DETRHead):
                 operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                  'ffn', 'norm')),
             ))
-        positional_encoding=dict(
-            type='SinePositionalEncoding', num_feats=embed_dims//2, normalize=True)
+        positional_encoding = dict(
+            type='SinePositionalEncoding', num_feats=embed_dims // 2, normalize=True)
         super().__init__(
             num_classes=num_classes,
             in_channels=in_channels,
             num_query=num_query,
             transformer=transformer,
             positional_encoding=positional_encoding,
             **kwargs,
         )
@@ -135,7 +135,7 @@ class CustomDETRHead(DETRHead):
         for img_id in range(batch_size):
             img_h, img_w, _ = img_metas[img_id]['img_shape']
             masks[img_id, :img_h, :img_w] = 0
         x = self.input_proj(x)
         # interpolate masks to have the same spatial shape with x
         masks = F.interpolate(
@@ -221,7 +221,7 @@ class CustomDETRHead(DETRHead):
         cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
         # construct weighted avg_factor to match with the official DETR repo
         cls_avg_factor = num_total_pos * 1.0 + \
             num_total_neg * self.bg_cls_weight
         if self.sync_cls_avg_factor:
             cls_avg_factor = reduce_mean(
                 cls_scores.new_tensor([cls_avg_factor]))
@@ -244,8 +244,8 @@ class CustomDETRHead(DETRHead):
         for img_meta, bbox_pred in zip(img_metas, bbox_preds):
             img_h, img_w, _ = img_meta['img_shape']
             factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                            img_h]).unsqueeze(0).repeat(
                                                bbox_pred.size(0), 1)
             factors.append(factor)
         factors = torch.cat(factors, 0)
@@ -282,8 +282,8 @@ class CustomDETRHead(DETRHead):
         (labels_list, label_weights_list, bbox_targets_list,
          bbox_weights_list, pos_inds_list, neg_inds_list, pos_assigned_gt_inds_list) = multi_apply(
             self._get_target_single, cls_scores_list, bbox_preds_list,
             gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list)
         num_total_pos = sum((inds.numel() for inds in pos_inds_list))
         num_total_neg = sum((inds.numel() for inds in neg_inds_list))
         assign_result = dict(
@@ -312,7 +312,7 @@ class CustomDETRHead(DETRHead):
         pos_assigned_gt_inds = sampling_result.pos_assigned_gt_inds
         # label targets
-        labels = gt_bboxes.new_full((num_bboxes, ),
+        labels = gt_bboxes.new_full((num_bboxes,),
                                     self.num_classes,
                                     dtype=torch.long)
         labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
@@ -327,9 +327,12 @@ class CustomDETRHead(DETRHead):
             pos_gt_bboxes = sampling_result.pos_gt_bboxes
             pos_gt_bboxes_normalized = torch.zeros_like(pos_gt_bboxes)
             for p in range(self.num_reg_dim // 3):
-                pos_gt_bboxes_normalized[..., 3*p] = (pos_gt_bboxes[..., 3*p] - self.bev_range[0]) / (self.bev_range[3] - self.bev_range[0])
-                pos_gt_bboxes_normalized[..., 3*p+1] = (pos_gt_bboxes[..., 3*p+1] - self.bev_range[1]) / (self.bev_range[4] - self.bev_range[1])
-                pos_gt_bboxes_normalized[..., 3*p+2] = (pos_gt_bboxes[..., 3*p+2] - self.bev_range[2]) / (self.bev_range[5] - self.bev_range[2])
+                pos_gt_bboxes_normalized[..., 3 * p] = (pos_gt_bboxes[..., 3 * p] - self.bev_range[0]) / (
+                    self.bev_range[3] - self.bev_range[0])
+                pos_gt_bboxes_normalized[..., 3 * p + 1] = (pos_gt_bboxes[..., 3 * p + 1] - self.bev_range[1]) / (
+                    self.bev_range[4] - self.bev_range[1])
+                pos_gt_bboxes_normalized[..., 3 * p + 2] = (pos_gt_bboxes[..., 3 * p + 2] - self.bev_range[2]) / (
+                    self.bev_range[5] - self.bev_range[2])
             pos_gt_bboxes_targets = pos_gt_bboxes_normalized
         else:
             img_h, img_w, _ = img_meta['img_shape']
@@ -338,10 +341,10 @@ class CustomDETRHead(DETRHead):
             # Thus the learning target should be normalized by the image size, also
             # the box format should be converted from defaultly x1y1x2y2 to cxcywh.
             factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                            img_h]).unsqueeze(0)
             pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
             pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
         bbox_targets[pos_inds] = pos_gt_bboxes_targets
         return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                 neg_inds, pos_assigned_gt_inds)
@@ -360,12 +363,15 @@ class CustomDETRHead(DETRHead):
             det_bboxes = bbox_pred
             for p in range(self.num_reg_dim // 3):
-                det_bboxes[..., 3*p] = det_bboxes[..., 3*p] * (self.bev_range[3] - self.bev_range[0]) + self.bev_range[0]
-                det_bboxes[..., 3*p+1] = det_bboxes[..., 3*p+1] * (self.bev_range[4] - self.bev_range[1]) + self.bev_range[1]
-                det_bboxes[..., 3*p+2] = det_bboxes[..., 3*p+2] * (self.bev_range[5] - self.bev_range[2]) + self.bev_range[2]
-                det_bboxes[..., 3*p].clamp_(min=self.bev_range[0], max=self.bev_range[3])
-                det_bboxes[..., 3*p+1].clamp_(min=self.bev_range[1], max=self.bev_range[4])
-                det_bboxes[..., 3*p+2].clamp_(min=self.bev_range[2], max=self.bev_range[5])
+                det_bboxes[..., 3 * p] = det_bboxes[..., 3 * p] * (self.bev_range[3] - self.bev_range[0]) + \
+                    self.bev_range[0]
+                det_bboxes[..., 3 * p + 1] = det_bboxes[..., 3 * p + 1] * (self.bev_range[4] - self.bev_range[1]) + \
+                    self.bev_range[1]
+                det_bboxes[..., 3 * p + 2] = det_bboxes[..., 3 * p + 2] * (self.bev_range[5] - self.bev_range[2]) + \
+                    self.bev_range[2]
+                det_bboxes[..., 3 * p].clamp_(min=self.bev_range[0], max=self.bev_range[3])
+                det_bboxes[..., 3 * p + 1].clamp_(min=self.bev_range[1], max=self.bev_range[4])
+                det_bboxes[..., 3 * p + 2].clamp_(min=self.bev_range[2], max=self.bev_range[5])
             else:
                 # exclude background
                 if self.loss_cls.use_sigmoid:
...
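
The rewrapped lines above are a per-axis affine map between the normalized [0, 1] values the head regresses and metric BEV coordinates, with bev_range laid out as (x_min, y_min, z_min, x_max, y_max, z_max). A small round-trip check of that mapping, using an illustrative range (the real values come from the model config, not this diff):

    # Round trip of the normalize/denormalize maps reformatted above.
    import torch

    bev_range = (-50.0, -25.0, -3.0, 50.0, 25.0, 2.0)  # illustrative only
    lo, hi = torch.tensor(bev_range[:3]), torch.tensor(bev_range[3:])

    pts = torch.tensor([[10.0, -5.0, 0.5]])   # metric (x, y, z)
    normalized = (pts - lo) / (hi - lo)       # regression target in [0, 1]
    restored = normalized * (hi - lo) + lo    # prediction mapped back to metres
    assert torch.allclose(restored, pts)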
 # ==============================================================================
 # Binaries and/or source for the following packages or projects
 # are presented under one or more of the following open source licenses:
 # custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
 #
@@ -22,12 +22,8 @@
 import copy
-import numpy as np
-import cv2
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-import mmcv
 from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
 from mmcv.runner import auto_fp16, force_fp32
 from mmcv.utils import TORCH_VERSION, digit_version
@@ -79,15 +75,15 @@ class LCDeformableDETRHead(AnchorFreeHead):
         self.bg_cls_weight = 0
         self.sync_cls_avg_factor = sync_cls_avg_factor
         if train_cfg:
-            assert 'assigner' in train_cfg, 'assigner should be provided '\
+            assert 'assigner' in train_cfg, 'assigner should be provided ' \
                 'when train_cfg is set.'
             assigner = train_cfg['assigner']
             assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \
                 'The classification weight for loss and matcher should be' \
                 'exactly the same.'
             assert loss_bbox['loss_weight'] == assigner['reg_cost'][
                 'weight'], 'The regression L1 weight for loss and matcher ' \
                 'should be exactly the same.'
             assert loss_iou['loss_weight'] == assigner['iou_cost']['weight'], \
                 'The regression iou weight for loss and matcher should be' \
                 'exactly the same.'
@@ -195,7 +191,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
                 network, each is a 5D-tensor with shape
                 (B, N, C, H, W).
             prev_bev: previous bev featues
             only_bev: only compute BEV features with encoder.
         Returns:
             all_cls_scores (Tensor): Outputs from the classification head, \
                 shape [nb_dec, bs, num_query, cls_out_channels]. Note \
@@ -232,12 +228,12 @@ class LCDeformableDETRHead(AnchorFreeHead):
             assert reference.shape[-1] == 3
             for p in range(self.code_size // 3):
-                tmp[..., 3*p:3*p+3] = tmp[..., 3*p:3*p+3] + reference
-                tmp[..., 3*p:3*p+3] = tmp[..., 3*p:3*p+3].sigmoid()
-                tmp[..., 3*p] = tmp[..., 3*p] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
-                tmp[..., 3*p+1] = tmp[..., 3*p+1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
-                tmp[..., 3*p+2] = tmp[..., 3*p+2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]
+                tmp[..., 3 * p:3 * p + 3] = tmp[..., 3 * p:3 * p + 3] + reference
+                tmp[..., 3 * p:3 * p + 3] = tmp[..., 3 * p:3 * p + 3].sigmoid()
+                tmp[..., 3 * p] = tmp[..., 3 * p] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
+                tmp[..., 3 * p + 1] = tmp[..., 3 * p + 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
+                tmp[..., 3 * p + 2] = tmp[..., 3 * p + 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]
             outputs_coord = tmp
             outputs_classes.append(outputs_class)
             outputs_coords.append(outputs_coord)
@@ -293,7 +289,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
         sampling_result = self.sampler.sample(assign_result, lanes_pred,
                                               gt_lanes)
         pos_inds = sampling_result.pos_inds
         neg_inds = sampling_result.neg_inds
         pos_assigned_gt_inds = sampling_result.pos_assigned_gt_inds
@@ -415,7 +411,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
         cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
         # construct weighted avg_factor to match with the official DETR repo
         cls_avg_factor = num_total_pos * 1.0 + \
             num_total_neg * self.bg_cls_weight
         if self.sync_cls_avg_factor:
             cls_avg_factor = reduce_mean(
                 cls_scores.new_tensor([cls_avg_factor]))
@@ -436,7 +432,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
         bbox_weights = bbox_weights * self.code_weights
         loss_bbox = self.loss_bbox(
             lanes_preds[isnotnan, :self.code_size],
             bbox_targets[isnotnan, :self.code_size],
             bbox_weights[isnotnan, :self.code_size],
             avg_factor=num_total_pos)
@@ -544,7 +540,7 @@ class LCDeformableDETRHead(AnchorFreeHead):
             cls_scores = all_cls_scores[i].sigmoid()
             predictions_list.append([
                 all_lanes_preds[i].detach().cpu().numpy(),
                 cls_scores.detach().cpu().numpy()])
         return predictions_list
 import copy
-import mmcv
 import numpy as np
-import cv2
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
+import mmcv
 from mmcv.cnn.bricks.transformer import build_feedforward_network
-from mmcv.runner import auto_fp16, force_fp32
 from mmcv.utils import TORCH_VERSION, digit_version
 from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
 from mmdet.models.builder import HEADS, build_loss
 from mmdet.models.dense_heads import AnchorFreeHead
 from mmdet.models.utils import build_transformer
-from mmdet.models.utils.transformer import inverse_sigmoid
 class MLP(nn.Module):
@@ -38,10 +34,10 @@ class RelationshipHead(nn.Module):
                  in_channels_o2=None,
                  shared_param=True,
                  loss_rel=dict(
                      type='FocalLoss',
                      use_sigmoid=True,
                      gamma=2.0,
                      alpha=0.25)):
         super().__init__()
         self.MLP_o1 = MLP(in_channels_o1, in_channels_o1, 128, 3)
...
 # ==============================================================================
 # Binaries and/or source for the following packages or projects
 # are presented under one or more of the following open source licenses:
 # custom_detr_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
 #
@@ -27,12 +27,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import Linear, bias_init_with_prob, constant_init
 from mmcv.runner import force_fp32
-from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
-                        multi_apply, reduce_mean)
-from mmdet.models.utils.transformer import inverse_sigmoid
-from mmdet.models import HEADS, build_loss
+from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, multi_apply,
+                        reduce_mean)
+from mmdet.models import HEADS
 from mmdet.models.dense_heads import DETRHead
+from mmdet.models.utils.transformer import inverse_sigmoid
 @HEADS.register_module()
@@ -163,14 +162,14 @@ class TEDeformableDETRHead(DETRHead):
         if not self.as_two_stage:
             query_embeds = self.query_embedding.weight
         hs, init_reference, inter_references, \
             enc_outputs_class, enc_outputs_coord = self.transformer(
                 mlvl_feats,
                 mlvl_masks,
                 query_embeds,
                 mlvl_positional_encodings,
                 reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
                 cls_branches=self.cls_branches if self.as_two_stage else None  # noqa:E501
             )
         hs = hs.permute(0, 2, 1, 3)
         outputs_classes = []
         outputs_coords = []
@@ -199,7 +198,7 @@ class TEDeformableDETRHead(DETRHead):
             'all_cls_scores': outputs_classes,
             'all_bbox_preds': outputs_coords,
             'enc_cls_scores': enc_outputs_class if self.as_two_stage else None,
             'enc_bbox_preds': enc_outputs_coord.sigmoid() if self.as_two_stage else None,
             'history_states': hs
         }
@@ -336,7 +335,7 @@ class TEDeformableDETRHead(DETRHead):
         cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
         # construct weighted avg_factor to match with the official DETR repo
         cls_avg_factor = num_total_pos * 1.0 + \
             num_total_neg * self.bg_cls_weight
         if self.sync_cls_avg_factor:
             cls_avg_factor = reduce_mean(
                 cls_scores.new_tensor([cls_avg_factor]))
@@ -356,7 +355,7 @@ class TEDeformableDETRHead(DETRHead):
             img_h, img_w, _ = img_meta['img_shape']
             factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                            img_h]).unsqueeze(0).repeat(
                                                bbox_pred.size(0), 1)
             factors.append(factor)
         factors = torch.cat(factors, 0)
@@ -426,8 +425,8 @@ class TEDeformableDETRHead(DETRHead):
         (labels_list, label_weights_list, bbox_targets_list,
          bbox_weights_list, pos_inds_list, neg_inds_list, pos_assigned_gt_inds_list) = multi_apply(
             self._get_target_single, cls_scores_list, bbox_preds_list,
             gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list)
         num_total_pos = sum((inds.numel() for inds in pos_inds_list))
         num_total_neg = sum((inds.numel() for inds in neg_inds_list))
         assign_result = dict(
@@ -484,7 +483,7 @@ class TEDeformableDETRHead(DETRHead):
         pos_assigned_gt_inds = sampling_result.pos_assigned_gt_inds
         # label targets
-        labels = gt_bboxes.new_full((num_bboxes, ),
+        labels = gt_bboxes.new_full((num_bboxes,),
                                     self.num_classes,
                                     dtype=torch.long)
         labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
...
 # ==============================================================================
 # Binaries and/or source for the following packages or projects
 # are presented under one or more of the following open source licenses:
 # topology_head.py The OpenLane-V2 Dataset Authors Apache License, Version 2.0
 #
@@ -23,7 +23,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.runner import BaseModule
 from mmdet.models import HEADS, build_loss
@@ -41,13 +40,14 @@ class MLP(nn.Module):
             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
         return x
 @HEADS.register_module()
 class TopologyHead(BaseModule):
     def __init__(self,
                  in_channels,
                  hidden_channels,
                  out_channels,
                  num_layers,
                  loss_cls):
@@ -94,10 +94,11 @@ class TopologyHead(BaseModule):
             target = pred_adj.new_zeros(pred_adj[b].shape[:-1])
             rs = row_assign_result['pos_inds'][b].unsqueeze(-1).repeat(1, column_assign_result['pos_inds'][b].shape[0])
             cs = column_assign_result['pos_inds'][b].unsqueeze(0).repeat(row_assign_result['pos_inds'][b].shape[0], 1)
-            target[rs, cs] = gt_adj[b][row_assign_result['pos_assigned_gt_inds'][b]][:, column_assign_result['pos_assigned_gt_inds'][b]].float()
+            target[rs, cs] = gt_adj[b][row_assign_result['pos_assigned_gt_inds'][b]][:,
+                             column_assign_result['pos_assigned_gt_inds'][b]].float()
             targets.append(target)
-        targets = 1 - torch.stack(targets, dim=0) # 0 as positive
+        targets = 1 - torch.stack(targets, dim=0)  # 0 as positive
         loss_dict = dict()
         pred_adj = pred_adj.reshape(-1, self.out_channels)
...
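
The target construction above scatters a ground-truth adjacency sub-matrix into the (row, column) prediction slots matched by the assigner, via two broadcast index grids. A toy version of that index arithmetic, with invented indices:

    # Toy version of the adjacency-target scatter reformatted above;
    # only the meshgrid-style indexing is the point, the numbers are made up.
    import torch

    num_pred = 5
    pos_rows = torch.tensor([0, 3])           # predictions matched as rows
    pos_cols = torch.tensor([1, 4])           # predictions matched as columns
    gt_adj = torch.tensor([[1, 0], [0, 1]])   # adjacency among the matched GT pairs

    target = torch.zeros(num_pred, num_pred)
    rs = pos_rows.unsqueeze(-1).repeat(1, pos_cols.shape[0])  # (2, 2) row grid
    cs = pos_cols.unsqueeze(0).repeat(pos_rows.shape[0], 1)   # (2, 2) column grid
    target[rs, cs] = gt_adj.float()
    print(target)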
from .spatial_cross_attention import SpatialCrossAttention, MSDeformableAttention3D
from .temporal_self_attention import TemporalSelfAttention
from .encoder import BEVFormerEncoder, BEVFormerLayer
from .decoder import LaneDetectionTransformerDecoder
from .bevformer_constructer import BEVFormerConstructer
from .transformer import PerceptionTransformer
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn.init import normal_
+from torchvision.transforms.functional import rotate
 from mmcv.cnn import xavier_init
-from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence, build_positional_encoding
+from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+                                         build_transformer_layer_sequence)
 from mmcv.runner.base_module import BaseModule
-from mmcv.runner import force_fp32, auto_fp16
-from mmdet.models.utils.builder import TRANSFORMER
 from mmdet3d.models import NECKS
-from torch.nn.init import normal_
-from torchvision.transforms.functional import rotate
-from .temporal_self_attention import TemporalSelfAttention
-from .spatial_cross_attention import MSDeformableAttention3D
 from .decoder import CustomMSDeformableAttention
+from .spatial_cross_attention import MSDeformableAttention3D
+from .temporal_self_attention import TemporalSelfAttention
 @NECKS.register_module()
@@ -69,7 +67,7 @@ class BEVFormerConstructer(BaseModule):
     def init_layers(self):
         self.bev_embedding = nn.Embedding(
             self.bev_h * self.bev_w, self.embed_dims)
         self.level_embeds = nn.Parameter(torch.Tensor(
             self.num_feature_levels, self.embed_dims))
         self.cams_embeds = nn.Parameter(
...@@ -82,7 +80,7 @@ class BEVFormerConstructer(BaseModule): ...@@ -82,7 +80,7 @@ class BEVFormerConstructer(BaseModule):
) )
if self.can_bus_norm: if self.can_bus_norm:
self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims)) self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))
def init_weights(self): def init_weights(self):
"""Initialize the transformer weights.""" """Initialize the transformer weights."""
for p in self.parameters(): for p in self.parameters():
...@@ -117,9 +115,9 @@ class BEVFormerConstructer(BaseModule): ...@@ -117,9 +115,9 @@ class BEVFormerConstructer(BaseModule):
# obtain rotation angle and shift with ego motion # obtain rotation angle and shift with ego motion
delta_x = np.array([each['can_bus'][0] delta_x = np.array([each['can_bus'][0]
for each in img_metas]) for each in img_metas])
delta_y = np.array([each['can_bus'][1] delta_y = np.array([each['can_bus'][1]
for each in img_metas]) for each in img_metas])
ego_angle = np.array( ego_angle = np.array(
[each['can_bus'][-2] / np.pi * 180 for each in img_metas]) [each['can_bus'][-2] / np.pi * 180 for each in img_metas])
...@@ -129,9 +127,9 @@ class BEVFormerConstructer(BaseModule): ...@@ -129,9 +127,9 @@ class BEVFormerConstructer(BaseModule):
translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180 translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
bev_angle = ego_angle - translation_angle bev_angle = ego_angle - translation_angle
shift_y = translation_length * \ shift_y = translation_length * \
np.cos(bev_angle / 180 * np.pi) / grid_length_y / self.bev_h np.cos(bev_angle / 180 * np.pi) / grid_length_y / self.bev_h
shift_x = translation_length * \ shift_x = translation_length * \
np.sin(bev_angle / 180 * np.pi) / grid_length_x / self.bev_w np.sin(bev_angle / 180 * np.pi) / grid_length_x / self.bev_w
shift_y = shift_y * self.use_shift shift_y = shift_y * self.use_shift
shift_x = shift_x * self.use_shift shift_x = shift_x * self.use_shift
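# A small numeric sketch (made-up CAN-bus values and grid settings) of the
# ego-motion shift computed above:
import numpy as np
delta_x, delta_y = np.array([2.0]), np.array([1.0])   # ego translation (m)
ego_angle = np.array([30.0])                          # ego heading (deg)
grid_length_y = grid_length_x = 0.5                   # meters per BEV cell
bev_h = bev_w = 200
translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)   # ~2.236 m
translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
bev_angle = ego_angle - translation_angle                   # ~3.43 deg
shift_y = translation_length * np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
shift_x = translation_length * np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
# shift_y ~= 0.0223, shift_x ~= 0.00134 (fractions of the BEV grid)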
shift = bev_queries.new_tensor( shift = bev_queries.new_tensor(
...@@ -167,7 +165,7 @@ class BEVFormerConstructer(BaseModule): ...@@ -167,7 +165,7 @@ class BEVFormerConstructer(BaseModule):
if self.use_cams_embeds: if self.use_cams_embeds:
feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype) feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
feat = feat + self.level_embeds[None, feat = feat + self.level_embeds[None,
None, lvl:lvl + 1, :].to(feat.dtype) None, lvl:lvl + 1, :].to(feat.dtype)
spatial_shapes.append(spatial_shape) spatial_shapes.append(spatial_shape)
feat_flatten.append(feat) feat_flatten.append(feat)
...@@ -196,4 +194,3 @@ class BEVFormerConstructer(BaseModule): ...@@ -196,4 +194,3 @@ class BEVFormerConstructer(BaseModule):
) )
return bev_embed return bev_embed
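# A shape-only sketch (toy sizes) of the per-camera / per-level embedding
# broadcast applied in the feature-flattening loop above:
import torch
num_cam, bs, hw, dims, lvl = 6, 1, 100, 256, 0
feat = torch.rand(num_cam, bs, hw, dims)
cams_embeds = torch.rand(num_cam, dims)
level_embeds = torch.rand(4, dims)
feat = feat + cams_embeds[:, None, None, :].to(feat.dtype)             # per-camera offset
feat = feat + level_embeds[None, None, lvl:lvl + 1, :].to(feat.dtype)  # per-level offset
print(feat.shape)  # torch.Size([6, 1, 100, 256])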
# --------------------------------------------- # ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# --------------------------------------------- # ---------------------------------------------
# Modified by Zhiqi Li # Modified by Zhiqi Li
# --------------------------------------------- # ---------------------------------------------
import copy import copy
import warnings import warnings
import torch import torch
import torch.nn as nn from mmcv import ConfigDict
from mmcv.cnn import build_norm_layer
from mmcv import ConfigDict, deprecated_api_warning from mmcv.cnn.bricks.registry import TRANSFORMER_LAYER
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
from mmcv.cnn.bricks.registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING, try:
TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE) from mmcv.ops.multi_scale_deform_attn import \
MultiScaleDeformableAttention # noqa F401
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try: warnings.warn(
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401 ImportWarning(
warnings.warn( '``MultiScaleDeformableAttention`` has been moved to '
ImportWarning( '``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501
'``MultiScaleDeformableAttention`` has been moved to ' '``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501
'``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501 'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501
'``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501 ))
'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501 except ImportError:
)) warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
except ImportError: '``mmcv.ops.multi_scale_deform_attn``, '
warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from ' 'You should install ``mmcv-full`` if you need this module. ')
'``mmcv.ops.multi_scale_deform_attn``, ' from mmcv.cnn.bricks.transformer import (build_attention,
'You should install ``mmcv-full`` if you need this module. ') build_feedforward_network)
from mmcv.cnn.bricks.transformer import build_feedforward_network, build_attention
@TRANSFORMER_LAYER.register_module()
@TRANSFORMER_LAYER.register_module() class MyCustomBaseTransformerLayer(BaseModule):
class MyCustomBaseTransformerLayer(BaseModule): """Base `TransformerLayer` for vision transformer.
"""Base `TransformerLayer` for vision transformer. It can be built from `mmcv.ConfigDict` and support more flexible
It can be built from `mmcv.ConfigDict` and support more flexible customization, for example, using any number of `FFN or LN ` and
customization, for example, using any number of `FFN or LN ` and use different kinds of `attention` by specifying a list of `ConfigDict`
use different kinds of `attention` by specifying a list of `ConfigDict` named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
named `attn_cfgs`. It is worth mentioning that it supports `prenorm` when you specify `norm` as the first element of `operation_order`.
when you specify `norm` as the first element of `operation_order`. More details about the `prenorm`: `On Layer Normalization in the
More details about the `prenorm`: `On Layer Normalization in the Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ . Args:
Args: attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None): Configs for `self_attention` or `cross_attention` modules,
Configs for `self_attention` or `cross_attention` modules, The order of the configs in the list should be consistent with
The order of the configs in the list should be consistent with corresponding attentions in operation_order.
corresponding attentions in operation_order. If it is a dict, all of the attention modules in operation_order
If it is a dict, all of the attention modules in operation_order will be built with this config. Default: None.
will be built with this config. Default: None. ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None): Configs for FFN. The order of the configs in the list should be
Configs for FFN. The order of the configs in the list should be consistent with corresponding ffn in operation_order.
consistent with corresponding ffn in operation_order. If it is a dict, all of the attention modules in operation_order
If it is a dict, all of the attention modules in operation_order will be built with this config.
will be built with this config. operation_order (tuple[str]): The execution order of operation
operation_order (tuple[str]): The execution order of operation in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). Support `prenorm` when you specify the first element as `norm`.
Support `prenorm` when you specify the first element as `norm`. Default:None.
Default:None. norm_cfg (dict): Config dict for normalization layer.
norm_cfg (dict): Config dict for normalization layer. Default: dict(type='LN').
Default: dict(type='LN'). init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Default: None.
Default: None. batch_first (bool): Key, Query and Value are shape
batch_first (bool): Key, Query and Value are shape of (batch, n, embed_dim)
of (batch, n, embed_dim) or (n, batch, embed_dim). Default: True.
or (n, batch, embed_dim). Default: True. """
"""
def __init__(self,
def __init__(self, attn_cfgs=None,
attn_cfgs=None, ffn_cfgs=dict(
ffn_cfgs=dict( type='FFN',
type='FFN', embed_dims=256,
embed_dims=256, feedforward_channels=1024,
feedforward_channels=1024, num_fcs=2,
num_fcs=2, ffn_drop=0.,
ffn_drop=0., act_cfg=dict(type='ReLU', inplace=True),
act_cfg=dict(type='ReLU', inplace=True), ),
), operation_order=None,
operation_order=None, norm_cfg=dict(type='LN'),
norm_cfg=dict(type='LN'), init_cfg=None,
init_cfg=None, batch_first=True,
batch_first=True, **kwargs):
**kwargs):
deprecated_args = dict(
deprecated_args = dict( feedforward_channels='feedforward_channels',
feedforward_channels='feedforward_channels', ffn_dropout='ffn_drop',
ffn_dropout='ffn_drop', ffn_num_fcs='num_fcs')
ffn_num_fcs='num_fcs') for ori_name, new_name in deprecated_args.items():
for ori_name, new_name in deprecated_args.items(): if ori_name in kwargs:
if ori_name in kwargs: warnings.warn(
warnings.warn( f'The arguments `{ori_name}` in BaseTransformerLayer '
f'The arguments `{ori_name}` in BaseTransformerLayer ' f'has been deprecated, now you should set `{new_name}` '
f'has been deprecated, now you should set `{new_name}` ' f'and other FFN related arguments '
f'and other FFN related arguments ' f'to a dict named `ffn_cfgs`. ')
f'to a dict named `ffn_cfgs`. ') ffn_cfgs[new_name] = kwargs[ori_name]
ffn_cfgs[new_name] = kwargs[ori_name]
super(MyCustomBaseTransformerLayer, self).__init__(init_cfg)
super(MyCustomBaseTransformerLayer, self).__init__(init_cfg)
self.batch_first = batch_first
self.batch_first = batch_first
assert set(operation_order) & set(
assert set(operation_order) & set( ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
['self_attn', 'norm', 'ffn', 'cross_attn']) == \ set(operation_order), f'The operation_order of' \
set(operation_order), f'The operation_order of' \ f' {self.__class__.__name__} should ' \
f' {self.__class__.__name__} should ' \ f'contain all four operation types ' \
f'contain all four operation types ' \ f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
num_attn = operation_order.count('self_attn') + operation_order.count(
num_attn = operation_order.count('self_attn') + operation_order.count( 'cross_attn')
'cross_attn') if isinstance(attn_cfgs, dict):
if isinstance(attn_cfgs, dict): attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)] else:
else: assert num_attn == len(attn_cfgs), f'The length ' \
assert num_attn == len(attn_cfgs), f'The length ' \ f'of attn_cfg {num_attn} is ' \
f'of attn_cfg {num_attn} is ' \ f'not consistent with the number of attentions ' \
f'not consistent with the number of attentions ' \ f'in operation_order {operation_order}.'
f'in operation_order {operation_order}.'
self.num_attn = num_attn
self.num_attn = num_attn self.operation_order = operation_order
self.operation_order = operation_order self.norm_cfg = norm_cfg
self.norm_cfg = norm_cfg self.pre_norm = operation_order[0] == 'norm'
self.pre_norm = operation_order[0] == 'norm' self.attentions = ModuleList()
self.attentions = ModuleList()
index = 0
index = 0 for operation_name in operation_order:
for operation_name in operation_order: if operation_name in ['self_attn', 'cross_attn']:
if operation_name in ['self_attn', 'cross_attn']: if 'batch_first' in attn_cfgs[index]:
if 'batch_first' in attn_cfgs[index]: assert self.batch_first == attn_cfgs[index]['batch_first']
assert self.batch_first == attn_cfgs[index]['batch_first'] else:
else: attn_cfgs[index]['batch_first'] = self.batch_first
attn_cfgs[index]['batch_first'] = self.batch_first attention = build_attention(attn_cfgs[index])
attention = build_attention(attn_cfgs[index]) # Some custom attentions used as `self_attn`
# Some custom attentions used as `self_attn` # or `cross_attn` can have different behavior.
# or `cross_attn` can have different behavior. attention.operation_name = operation_name
attention.operation_name = operation_name self.attentions.append(attention)
self.attentions.append(attention) index += 1
index += 1
self.embed_dims = self.attentions[0].embed_dims
self.embed_dims = self.attentions[0].embed_dims
self.ffns = ModuleList()
self.ffns = ModuleList() num_ffns = operation_order.count('ffn')
num_ffns = operation_order.count('ffn') if isinstance(ffn_cfgs, dict):
if isinstance(ffn_cfgs, dict): ffn_cfgs = ConfigDict(ffn_cfgs)
ffn_cfgs = ConfigDict(ffn_cfgs) if isinstance(ffn_cfgs, dict):
if isinstance(ffn_cfgs, dict): ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)] assert len(ffn_cfgs) == num_ffns
assert len(ffn_cfgs) == num_ffns for ffn_index in range(num_ffns):
for ffn_index in range(num_ffns): if 'embed_dims' not in ffn_cfgs[ffn_index]:
if 'embed_dims' not in ffn_cfgs[ffn_index]: ffn_cfgs['embed_dims'] = self.embed_dims
ffn_cfgs['embed_dims'] = self.embed_dims else:
else: assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
self.ffns.append(
self.ffns.append( build_feedforward_network(ffn_cfgs[ffn_index]))
build_feedforward_network(ffn_cfgs[ffn_index]))
self.norms = ModuleList()
self.norms = ModuleList() num_norms = operation_order.count('norm')
num_norms = operation_order.count('norm') for _ in range(num_norms):
for _ in range(num_norms): self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
def forward(self,
def forward(self, query,
query, key=None,
key=None, value=None,
value=None, query_pos=None,
query_pos=None, key_pos=None,
key_pos=None, attn_masks=None,
attn_masks=None, query_key_padding_mask=None,
query_key_padding_mask=None, key_padding_mask=None,
key_padding_mask=None, **kwargs):
**kwargs): """Forward function for `TransformerDecoderLayer`.
"""Forward function for `TransformerDecoderLayer`. **kwargs contains some specific arguments of attentions.
**kwargs contains some specific arguments of attentions. Args:
Args: query (Tensor): The input query with shape
query (Tensor): The input query with shape [num_queries, bs, embed_dims] if
[num_queries, bs, embed_dims] if self.batch_first is False, else
self.batch_first is False, else [bs, num_queries, embed_dims].
[bs, num_queries, embed_dims]. key (Tensor): The key tensor with shape
key (Tensor): The key tensor with shape [num_keys, bs, embed_dims] if self.batch_first is False, else
embed_dims] if self.batch_first is False, else [bs, num_keys, embed_dims] .
[bs, num_keys, embed_dims] . value (Tensor): The value tensor with same shape as `key`.
value (Tensor): The value tensor with same shape as `key`. query_pos (Tensor): The positional encoding for `query`.
query_pos (Tensor): The positional encoding for `query`. Default: None.
Default: None. key_pos (Tensor): The positional encoding for `key`.
key_pos (Tensor): The positional encoding for `key`. Default: None.
Default: None. attn_masks (List[Tensor] | None): 2D Tensor used in
attn_masks (List[Tensor] | None): 2D Tensor used in calculation of corresponding attention. The length of
calculation of corresponding attention. The length of it should equal to the number of `attention` in
it should equal to the number of `attention` in `operation_order`. Default: None.
`operation_order`. Default: None. query_key_padding_mask (Tensor): ByteTensor for `query`, with
query_key_padding_mask (Tensor): ByteTensor for `query`, with shape [bs, num_queries]. Only used in `self_attn` layer.
shape [bs, num_queries]. Only used in `self_attn` layer. Defaults to None.
Defaults to None. key_padding_mask (Tensor): ByteTensor for `query`, with
key_padding_mask (Tensor): ByteTensor for `query`, with shape [bs, num_keys]. Default: None.
shape [bs, num_keys]. Default: None. Returns:
Returns: Tensor: forwarded results with shape [num_queries, bs, embed_dims].
Tensor: forwarded results with shape [num_queries, bs, embed_dims]. """
"""
norm_index = 0
norm_index = 0 attn_index = 0
attn_index = 0 ffn_index = 0
ffn_index = 0 identity = query
identity = query if attn_masks is None:
if attn_masks is None: attn_masks = [None for _ in range(self.num_attn)]
attn_masks = [None for _ in range(self.num_attn)] elif isinstance(attn_masks, torch.Tensor):
elif isinstance(attn_masks, torch.Tensor): attn_masks = [
attn_masks = [ copy.deepcopy(attn_masks) for _ in range(self.num_attn)
copy.deepcopy(attn_masks) for _ in range(self.num_attn) ]
] warnings.warn(f'Use same attn_mask in all attentions in '
warnings.warn(f'Use same attn_mask in all attentions in ' f'{self.__class__.__name__} ')
f'{self.__class__.__name__} ') else:
else: assert len(attn_masks) == self.num_attn, f'The length of ' \
assert len(attn_masks) == self.num_attn, f'The length of ' \ f'attn_masks {len(attn_masks)} must be equal ' \
f'attn_masks {len(attn_masks)} must be equal ' \ f'to the number of attention in ' \
f'to the number of attention in ' \ f'operation_order {self.num_attn}'
f'operation_order {self.num_attn}'
for layer in self.operation_order:
for layer in self.operation_order: if layer == 'self_attn':
if layer == 'self_attn': temp_key = temp_value = query
temp_key = temp_value = query query = self.attentions[attn_index](
query = self.attentions[attn_index]( query,
query, temp_key,
temp_key, temp_value,
temp_value, identity if self.pre_norm else None,
identity if self.pre_norm else None, query_pos=query_pos,
query_pos=query_pos, key_pos=query_pos,
key_pos=query_pos, attn_mask=attn_masks[attn_index],
attn_mask=attn_masks[attn_index], key_padding_mask=query_key_padding_mask,
key_padding_mask=query_key_padding_mask, **kwargs)
**kwargs) attn_index += 1
attn_index += 1 identity = query
identity = query
elif layer == 'norm':
elif layer == 'norm': query = self.norms[norm_index](query)
query = self.norms[norm_index](query) norm_index += 1
norm_index += 1
elif layer == 'cross_attn':
elif layer == 'cross_attn': query = self.attentions[attn_index](
query = self.attentions[attn_index]( query,
query, key,
key, value,
value, identity if self.pre_norm else None,
identity if self.pre_norm else None, query_pos=query_pos,
query_pos=query_pos, key_pos=key_pos,
key_pos=key_pos, attn_mask=attn_masks[attn_index],
attn_mask=attn_masks[attn_index], key_padding_mask=key_padding_mask,
key_padding_mask=key_padding_mask, **kwargs)
**kwargs) attn_index += 1
attn_index += 1 identity = query
identity = query
elif layer == 'ffn':
elif layer == 'ffn': query = self.ffns[ffn_index](
query = self.ffns[ffn_index]( query, identity if self.pre_norm else None)
query, identity if self.pre_norm else None) ffn_index += 1
ffn_index += 1
return query
return query
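# Usage sketch (illustrative config only; mmcv's registered 'MultiheadAttention'
# is assumed available here, but any ATTENTION-registered module works):
import torch
from mmcv import ConfigDict
layer = MyCustomBaseTransformerLayer(
    attn_cfgs=ConfigDict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=1024),
    operation_order=('norm', 'self_attn', 'norm', 'cross_attn', 'norm', 'ffn'))
assert layer.pre_norm                      # 'norm' comes first => prenorm
query = torch.rand(2, 100, 256)            # (bs, num_queries, embed_dims), batch_first=True
key = value = torch.rand(2, 300, 256)
out = layer(query, key=key, value=value)   # -> (2, 100, 256)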
# --------------------------------------------- # ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# --------------------------------------------- # ---------------------------------------------
# Modified by Zhiqi Li # Modified by Zhiqi Li
# --------------------------------------------- # ---------------------------------------------
from cmath import pi import math
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch import warnings
import mmcv
import cv2 as cv import torch
import copy import torch.nn as nn
import warnings from mmcv.cnn import constant_init, xavier_init
from matplotlib import pyplot as plt from mmcv.cnn.bricks.registry import (ATTENTION, TRANSFORMER_LAYER,
import numpy as np TRANSFORMER_LAYER_SEQUENCE)
import torch from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
import torch.nn as nn TransformerLayerSequence)
import torch.nn.functional as F from mmcv.ops.multi_scale_deform_attn import \
from mmcv.cnn import xavier_init, constant_init multi_scale_deformable_attn_pytorch
from mmcv.cnn.bricks.registry import (ATTENTION, TRANSFORMER_LAYER, from mmcv.runner.base_module import BaseModule
TRANSFORMER_LAYER_SEQUENCE) from mmcv.utils import deprecated_api_warning, ext_loader
from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence
import math from .multi_scale_deformable_attn_function import \
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential MultiScaleDeformableAttnFunction_fp32
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
to_2tuple) ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
MultiScaleDeformableAttnFunction_fp16 def inverse_sigmoid(x, eps=1e-5):
"""Inverse function of sigmoid.
ext_module = ext_loader.load_ext( Args:
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) x (Tensor): The tensor to do the
inverse.
eps (float): EPS avoid numerical
def inverse_sigmoid(x, eps=1e-5): overflow. Defaults 1e-5.
"""Inverse function of sigmoid. Returns:
Args: Tensor: The x has passed the inverse
x (Tensor): The tensor to do the function of sigmoid, has same
inverse. shape with input.
eps (float): EPS avoid numerical """
overflow. Defaults 1e-5. x = x.clamp(min=0, max=1)
Returns: x1 = x.clamp(min=eps)
Tensor: The x has passed the inverse x2 = (1 - x).clamp(min=eps)
function of sigmoid, has same return torch.log(x1 / x2)
shape with input.
"""
x = x.clamp(min=0, max=1) @TRANSFORMER_LAYER_SEQUENCE.register_module()
x1 = x.clamp(min=eps) class LaneDetectionTransformerDecoder(TransformerLayerSequence):
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2) def __init__(self, *args, return_intermediate=False, **kwargs):
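# Quick sanity check (illustrative) for inverse_sigmoid above: it undoes
# torch.sigmoid up to the eps clamping at the boundaries.
import torch
x = torch.tensor([-4.0, 0.0, 4.0])
assert torch.allclose(x, inverse_sigmoid(torch.sigmoid(x)), atol=1e-4)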
super(LaneDetectionTransformerDecoder, self).__init__(*args, **kwargs)
self.return_intermediate = return_intermediate
@TRANSFORMER_LAYER_SEQUENCE.register_module() self.fp16_enabled = False
class LaneDetectionTransformerDecoder(TransformerLayerSequence):
def forward(self,
def __init__(self, *args, return_intermediate=False, **kwargs): query,
super(LaneDetectionTransformerDecoder, self).__init__(*args, **kwargs) *args,
self.return_intermediate = return_intermediate reference_points=None,
self.fp16_enabled = False reg_branches=None,
key_padding_mask=None,
def forward(self, **kwargs):
query, """Forward function for `Detr3DTransformerDecoder`.
*args, Args:
reference_points=None, query (Tensor): Input query with shape
reg_branches=None, `(num_query, bs, embed_dims)`.
key_padding_mask=None, reference_points (Tensor): The reference
**kwargs): points of offset. has shape
"""Forward function for `Detr3DTransformerDecoder`. (bs, num_query, 4) when as_two_stage,
Args: otherwise has shape ((bs, num_query, 2).
query (Tensor): Input query with shape reg_branch: (obj:`nn.ModuleList`): Used for
`(num_query, bs, embed_dims)`. refining the regression results. Only would
reference_points (Tensor): The reference be passed when with_box_refine is True,
points of offset. has shape otherwise would be passed a `None`.
(bs, num_query, 4) when as_two_stage, Returns:
otherwise has shape (bs, num_query, 2). Tensor: Results with shape [1, num_query, bs, embed_dims] when
reg_branch: (obj:`nn.ModuleList`): Used for return_intermediate is `False`, otherwise it has shape
refining the regression results. Only would [num_layers, num_query, bs, embed_dims].
be passed when with_box_refine is True, """
otherwise would be passed a `None`.
Returns: output = query
Tensor: Results with shape [1, num_query, bs, embed_dims] when intermediate = []
return_intermediate is `False`, otherwise it has shape intermediate_reference_points = []
[num_layers, num_query, bs, embed_dims]. for lid, layer in enumerate(self.layers):
""" reference_points_input = reference_points[..., :2].unsqueeze(
2) # BS NUM_QUERY NUM_LEVEL 2
output = query output = layer(
intermediate = [] output,
intermediate_reference_points = [] *args,
for lid, layer in enumerate(self.layers): reference_points=reference_points_input,
reference_points_input = reference_points[..., :2].unsqueeze( key_padding_mask=key_padding_mask,
2) # BS NUM_QUERY NUM_LEVEL 2 **kwargs)
output = layer( output = output.permute(1, 0, 2)
output,
*args, if reg_branches is not None:
reference_points=reference_points_input, tmp = reg_branches[lid](output)
key_padding_mask=key_padding_mask,
**kwargs) assert reference_points.shape[-1] == 3
output = output.permute(1, 0, 2)
new_reference_points = torch.zeros_like(reference_points)
if reg_branches is not None: ref_center = (tmp[..., :3] + tmp[..., -3:]) / 2
new_reference_points = ref_center + inverse_sigmoid(reference_points)
tmp = reg_branches[lid](output) new_reference_points = new_reference_points.sigmoid()
assert reference_points.shape[-1] == 3 reference_points = new_reference_points.detach()
new_reference_points = torch.zeros_like(reference_points) output = output.permute(1, 0, 2)
ref_center = (tmp[..., :3] + tmp[..., -3:]) / 2 if self.return_intermediate:
new_reference_points = ref_center + inverse_sigmoid(reference_points) intermediate.append(output)
new_reference_points = new_reference_points.sigmoid() intermediate_reference_points.append(reference_points)
reference_points = new_reference_points.detach() if self.return_intermediate:
return torch.stack(intermediate), torch.stack(
output = output.permute(1, 0, 2) intermediate_reference_points)
if self.return_intermediate:
intermediate.append(output) return output, reference_points
intermediate_reference_points.append(reference_points)
if self.return_intermediate: @TRANSFORMER_LAYER.register_module()
return torch.stack(intermediate), torch.stack( class CustomDetrTransformerDecoderLayer(BaseTransformerLayer):
intermediate_reference_points) """Implements decoder layer in DETR transformer.
return output, reference_points Args:
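# A shape sketch (made-up sizes) of the per-layer reference refinement in the
# decoder loop above: the hypothetical reg branch predicts two 3D endpoints and
# their midpoint updates the reference in inverse-sigmoid space.
import torch
reference_points = torch.rand(2, 50, 3)    # (bs, num_query, 3), normalized to [0, 1]
tmp = torch.randn(2, 50, 6)                # stand-in for reg_branches[lid](output)
ref_center = (tmp[..., :3] + tmp[..., -3:]) / 2
new_reference_points = (ref_center + inverse_sigmoid(reference_points)).sigmoid()
reference_points = new_reference_points.detach()   # stop gradients across layers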
attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
Configs for self_attention or cross_attention, the order
@TRANSFORMER_LAYER.register_module() should be consistent with it in `operation_order`. If it is
class CustomDetrTransformerDecoderLayer(BaseTransformerLayer): a dict, it would be expanded to the number of attentions in
"""Implements decoder layer in DETR transformer. `operation_order`.
feedforward_channels (int): The hidden dimension for FFNs.
Args: ffn_dropout (float): Probability of an element to be zeroed
attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict): in ffn. Default 0.0.
Configs for self_attention or cross_attention, the order operation_order (tuple[str]): The execution order of operation
should be consistent with it in `operation_order`. If it is in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
a dict, it would be expanded to the number of attentions in Default:None
`operation_order`. act_cfg (dict): The activation config for FFNs. Default: `ReLU`
feedforward_channels (int): The hidden dimension for FFNs. norm_cfg (dict): Config dict for normalization layer.
ffn_dropout (float): Probability of an element to be zeroed Default: `LN`.
in ffn. Default 0.0. ffn_num_fcs (int): The number of fully-connected layers in FFNs.
operation_order (tuple[str]): The execution order of operation Default:2.
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). """
Default:None
act_cfg (dict): The activation config for FFNs. Default: `ReLU` def __init__(self,
norm_cfg (dict): Config dict for normalization layer. attn_cfgs,
Default: `LN`. ffn_cfgs,
ffn_num_fcs (int): The number of fully-connected layers in FFNs. operation_order=None,
Default:2. norm_cfg=dict(type='LN'),
""" **kwargs):
super(CustomDetrTransformerDecoderLayer, self).__init__(
def __init__(self, attn_cfgs=attn_cfgs,
attn_cfgs, ffn_cfgs=ffn_cfgs,
ffn_cfgs, operation_order=operation_order,
operation_order=None, norm_cfg=norm_cfg,
norm_cfg=dict(type='LN'), **kwargs)
**kwargs): assert len(operation_order) == 6
super(CustomDetrTransformerDecoderLayer, self).__init__( assert set(operation_order) == set(
attn_cfgs=attn_cfgs, ['self_attn', 'norm', 'cross_attn', 'ffn'])
ffn_cfgs=ffn_cfgs,
operation_order=operation_order,
norm_cfg=norm_cfg, @ATTENTION.register_module()
**kwargs) class CustomMSDeformableAttention(BaseModule):
assert len(operation_order) == 6 """An attention module used in Deformable-Detr.
assert set(operation_order) == set(
['self_attn', 'norm', 'cross_attn', 'ffn']) `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
@ATTENTION.register_module() Args:
class CustomMSDeformableAttention(BaseModule): embed_dims (int): The embedding dimension of Attention.
"""An attention module used in Deformable-Detr. Default: 256.
num_heads (int): Parallel attention heads. Default: 8.
`Deformable DETR: Deformable Transformers for End-to-End Object Detection. num_levels (int): The number of feature map used in
<https://arxiv.org/pdf/2010.04159.pdf>`_. Attention. Default: 4.
num_points (int): The number of sampling points for
Args: each query in each head. Default: 4.
embed_dims (int): The embedding dimension of Attention. im2col_step (int): The step used in image_to_column.
Default: 256. Default: 64.
num_heads (int): Parallel attention heads. Default: 8. dropout (float): A Dropout layer on `inp_identity`.
num_levels (int): The number of feature map used in Default: 0.1.
Attention. Default: 4. batch_first (bool): Key, Query and Value are shape of
num_points (int): The number of sampling points for (batch, n, embed_dim)
each query in each head. Default: 4. or (n, batch, embed_dim). Default to False.
im2col_step (int): The step used in image_to_column. norm_cfg (dict): Config dict for normalization layer.
Default: 64. Default: None.
dropout (float): A Dropout layer on `inp_identity`. init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: 0.1. Default: None.
batch_first (bool): Key, Query and Value are shape of """
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False. def __init__(self,
norm_cfg (dict): Config dict for normalization layer. embed_dims=256,
Default: None. num_heads=8,
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. num_levels=4,
Default: None. num_points=4,
""" im2col_step=64,
dropout=0.1,
def __init__(self, batch_first=False,
embed_dims=256, norm_cfg=None,
num_heads=8, init_cfg=None):
num_levels=4, super().__init__(init_cfg)
num_points=4, if embed_dims % num_heads != 0:
im2col_step=64, raise ValueError(f'embed_dims must be divisible by num_heads, '
dropout=0.1, f'but got {embed_dims} and {num_heads}')
batch_first=False, dim_per_head = embed_dims // num_heads
norm_cfg=None, self.norm_cfg = norm_cfg
init_cfg=None): self.dropout = nn.Dropout(dropout)
super().__init__(init_cfg) self.batch_first = batch_first
if embed_dims % num_heads != 0: self.fp16_enabled = False
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}') # you'd better set dim_per_head to a power of 2
dim_per_head = embed_dims // num_heads # which is more efficient in the CUDA implementation
self.norm_cfg = norm_cfg def _is_power_of_2(n):
self.dropout = nn.Dropout(dropout) if (not isinstance(n, int)) or (n < 0):
self.batch_first = batch_first raise ValueError(
self.fp16_enabled = False 'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
# you'd better set dim_per_head to a power of 2 return (n & (n - 1) == 0) and n != 0
# which is more efficient in the CUDA implementation
def _is_power_of_2(n): if not _is_power_of_2(dim_per_head):
if (not isinstance(n, int)) or (n < 0): warnings.warn(
raise ValueError( "You'd better set embed_dims in "
'invalid input for _is_power_of_2: {} (type: {})'.format( 'MultiScaleDeformAttention to make '
n, type(n))) 'the dimension of each attention head a power of 2 '
return (n & (n - 1) == 0) and n != 0 'which is more efficient in our CUDA implementation.')
if not _is_power_of_2(dim_per_head): self.im2col_step = im2col_step
warnings.warn( self.embed_dims = embed_dims
"You'd better set embed_dims in " self.num_levels = num_levels
'MultiScaleDeformAttention to make ' self.num_heads = num_heads
'the dimension of each attention head a power of 2 ' self.num_points = num_points
'which is more efficient in our CUDA implementation.') self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2)
self.im2col_step = im2col_step self.attention_weights = nn.Linear(embed_dims,
self.embed_dims = embed_dims num_heads * num_levels * num_points)
self.num_levels = num_levels self.value_proj = nn.Linear(embed_dims, embed_dims)
self.num_heads = num_heads self.output_proj = nn.Linear(embed_dims, embed_dims)
self.num_points = num_points self.init_weights()
self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2) def init_weights(self):
self.attention_weights = nn.Linear(embed_dims, """Default initialization for Parameters of Module."""
num_heads * num_levels * num_points) constant_init(self.sampling_offsets, 0.)
self.value_proj = nn.Linear(embed_dims, embed_dims) thetas = torch.arange(
self.output_proj = nn.Linear(embed_dims, embed_dims) self.num_heads,
self.init_weights() dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
def init_weights(self): grid_init = (grid_init /
"""Default initialization for Parameters of Module.""" grid_init.abs().max(-1, keepdim=True)[0]).view(
constant_init(self.sampling_offsets, 0.) self.num_heads, 1, 1,
thetas = torch.arange( 2).repeat(1, self.num_levels, self.num_points, 1)
self.num_heads, for i in range(self.num_points):
dtype=torch.float32) * (2.0 * math.pi / self.num_heads) grid_init[:, :, i, :] *= i + 1
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / self.sampling_offsets.bias.data = grid_init.view(-1)
grid_init.abs().max(-1, keepdim=True)[0]).view( constant_init(self.attention_weights, val=0., bias=0.)
self.num_heads, 1, 1, xavier_init(self.value_proj, distribution='uniform', bias=0.)
2).repeat(1, self.num_levels, self.num_points, 1) xavier_init(self.output_proj, distribution='uniform', bias=0.)
for i in range(self.num_points): self._is_init = True
grid_init[:, :, i, :] *= i + 1
@deprecated_api_warning({'residual': 'identity'},
self.sampling_offsets.bias.data = grid_init.view(-1) cls_name='MultiScaleDeformableAttention')
constant_init(self.attention_weights, val=0., bias=0.) def forward(self,
xavier_init(self.value_proj, distribution='uniform', bias=0.) query,
xavier_init(self.output_proj, distribution='uniform', bias=0.) key=None,
self._is_init = True value=None,
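# The directional bias initialization above, replayed standalone with the same
# default constants: num_heads unit directions evenly spaced on a circle,
# scaled by 1..num_points per sampling point.
import math
import torch
num_heads, num_levels, num_points = 8, 4, 4
thetas = torch.arange(num_heads, dtype=torch.float32) * (2.0 * math.pi / num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(
    num_heads, 1, 1, 2).repeat(1, num_levels, num_points, 1)
for i in range(num_points):
    grid_init[:, :, i, :] *= i + 1
print(grid_init.shape)  # torch.Size([8, 4, 4, 2]); flattened into the offset bias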
identity=None,
@deprecated_api_warning({'residual': 'identity'}, query_pos=None,
cls_name='MultiScaleDeformableAttention') key_padding_mask=None,
def forward(self, reference_points=None,
query, spatial_shapes=None,
key=None, level_start_index=None,
value=None, flag='decoder',
identity=None, **kwargs):
query_pos=None, """Forward Function of MultiScaleDeformAttention.
key_padding_mask=None,
reference_points=None, Args:
spatial_shapes=None, query (Tensor): Query of Transformer with shape
level_start_index=None, (num_query, bs, embed_dims).
flag='decoder', key (Tensor): The key tensor with shape
**kwargs): `(num_key, bs, embed_dims)`.
"""Forward Function of MultiScaleDeformAttention. value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
Args: identity (Tensor): The tensor used for addition, with the
query (Tensor): Query of Transformer with shape same shape as `query`. Default None. If None,
(num_query, bs, embed_dims). `query` will be used.
key (Tensor): The key tensor with shape query_pos (Tensor): The positional encoding for `query`.
`(num_key, bs, embed_dims)`. Default: None.
value (Tensor): The value tensor with shape key_pos (Tensor): The positional encoding for `key`. Default
`(num_key, bs, embed_dims)`. None.
identity (Tensor): The tensor used for addition, with the reference_points (Tensor): The normalized reference
same shape as `query`. Default None. If None, points with shape (bs, num_query, num_levels, 2),
`query` will be used. all elements is range in [0, 1], top-left (0,0),
query_pos (Tensor): The positional encoding for `query`. bottom-right (1, 1), including padding area.
Default: None. or (N, Length_{query}, num_levels, 4), add
key_pos (Tensor): The positional encoding for `key`. Default additional two dimensions is (w, h) to
None. form reference boxes.
reference_points (Tensor): The normalized reference key_padding_mask (Tensor): ByteTensor for `query`, with
points with shape (bs, num_query, num_levels, 2), all elements are in range [0, 1], top-left (0,0),
all elements are in range [0, 1], top-left (0,0), bottom-right (1, 1), including padding area.
bottom-right (1, 1), including padding area. or (N, Length_{query}, num_levels, 4), with
or (N, Length_{query}, num_levels, 4), with two additional dimensions (w, h) to
two additional dimensions (w, h) to form reference boxes.
form reference boxes. A tensor has shape ``(num_levels, )`` and can be represented
key_padding_mask (Tensor): ByteTensor for `query`, with as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in Returns:
different levels. With shape (num_levels, 2), Tensor: forwarded results with shape [num_query, bs, embed_dims].
last dimension represents (h, w). """
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented if value is None:
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. value = query
Returns: if identity is None:
Tensor: forwarded results with shape [num_query, bs, embed_dims]. identity = query
""" if query_pos is not None:
query = query + query_pos
if value is None: if not self.batch_first:
value = query # change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
if identity is None: value = value.permute(1, 0, 2)
identity = query
if query_pos is not None: bs, num_query, _ = query.shape
query = query + query_pos bs, num_value, _ = value.shape
if not self.batch_first: assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2) value = self.value_proj(value)
value = value.permute(1, 0, 2) if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
bs, num_query, _ = query.shape value = value.view(bs, num_value, self.num_heads, -1)
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
value = self.value_proj(value) attention_weights = self.attention_weights(query).view(
if key_padding_mask is not None: bs, num_query, self.num_heads, self.num_levels * self.num_points)
value = value.masked_fill(key_padding_mask[..., None], 0.0) attention_weights = attention_weights.softmax(-1)
value = value.view(bs, num_value, self.num_heads, -1)
attention_weights = attention_weights.view(bs, num_query,
sampling_offsets = self.sampling_offsets(query).view( self.num_heads,
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2) self.num_levels,
attention_weights = self.attention_weights(query).view( self.num_points)
bs, num_query, self.num_heads, self.num_levels * self.num_points) if reference_points.shape[-1] == 2:
attention_weights = attention_weights.softmax(-1) offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
attention_weights = attention_weights.view(bs, num_query, sampling_locations = reference_points[:, :, None, :, None, :] \
self.num_heads, + sampling_offsets \
self.num_levels, / offset_normalizer[None, None, None, :, None, :]
self.num_points) elif reference_points.shape[-1] == 4:
if reference_points.shape[-1] == 2: sampling_locations = reference_points[:, :, None, :, None, :2] \
offset_normalizer = torch.stack( + sampling_offsets / self.num_points \
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) * reference_points[:, :, None, :, None, 2:] \
sampling_locations = reference_points[:, :, None, :, None, :] \ * 0.5
+ sampling_offsets \ else:
/ offset_normalizer[None, None, None, :, None, :] raise ValueError(
elif reference_points.shape[-1] == 4: f'Last dim of reference_points must be'
sampling_locations = reference_points[:, :, None, :, None, :2] \ f' 2 or 4, but get {reference_points.shape[-1]} instead.')
+ sampling_offsets / self.num_points \ if torch.cuda.is_available() and value.is_cuda:
* reference_points[:, :, None, :, None, 2:] \
* 0.5 # using fp16 deformable attention is unstable because it performs many sum operations
else: if value.dtype == torch.float16:
raise ValueError( MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
f'Last dim of reference_points must be' else:
f' 2 or 4, but get {reference_points.shape[-1]} instead.') MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
if torch.cuda.is_available() and value.is_cuda: output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
# using fp16 deformable attention is unstable because it performs many sum operations attention_weights, self.im2col_step)
if value.dtype == torch.float16: else:
MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 output = multi_scale_deformable_attn_pytorch(
else: value, spatial_shapes, sampling_locations, attention_weights)
MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
output = MultiScaleDeformableAttnFunction.apply( output = self.output_proj(output)
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step) if not self.batch_first:
else: # (num_query, bs ,embed_dims)
output = multi_scale_deformable_attn_pytorch( output = output.permute(1, 0, 2)
value, spatial_shapes, sampling_locations, attention_weights)
return self.dropout(output) + identity
output = self.output_proj(output)
if not self.batch_first:
# (num_query, bs ,embed_dims)
output = output.permute(1, 0, 2)
return self.dropout(output) + identity
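# Self-contained shape sketch (arbitrary sizes) of the 2D branch above, where
# reference points plus predicted offsets become normalized sampling locations:
import torch
spatial_shapes = torch.tensor([[32, 24], [16, 12]])     # (num_levels, 2) as (h, w)
bs, nq, heads, levels, points = 1, 4, 8, 2, 4
reference_points = torch.rand(bs, nq, levels, 2)
sampling_offsets = torch.randn(bs, nq, heads, levels, points, 2)
offset_normalizer = torch.stack(
    [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)  # per-level (w, h)
sampling_locations = reference_points[:, :, None, :, None, :] \
    + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
print(sampling_locations.shape)  # torch.Size([1, 4, 8, 2, 4, 2])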
# ---------------------------------------------
# --------------------------------------------- # Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved. # ---------------------------------------------
# --------------------------------------------- # Modified by Zhiqi Li
# Modified by Zhiqi Li # ---------------------------------------------
# ---------------------------------------------
import copy
from .custom_base_transformer_layer import MyCustomBaseTransformerLayer import warnings
import copy
import warnings import numpy as np
from mmcv.cnn.bricks.registry import (ATTENTION, import torch
TRANSFORMER_LAYER, from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE) TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import TransformerLayerSequence from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmcv.runner import force_fp32, auto_fp16 from mmcv.runner import auto_fp16, force_fp32
import numpy as np from mmcv.utils import TORCH_VERSION, digit_version, ext_loader
import torch
import cv2 as cv from .custom_base_transformer_layer import MyCustomBaseTransformerLayer
import mmcv
from mmcv.utils import TORCH_VERSION, digit_version ext_module = ext_loader.load_ext(
from mmcv.utils import ext_loader '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class BEVFormerEncoder(TransformerLayerSequence):
@TRANSFORMER_LAYER_SEQUENCE.register_module() """
class BEVFormerEncoder(TransformerLayerSequence): Attention with both self and cross.
Implements the encoder in BEVFormer.
""" Args:
Attention with both self and cross. return_intermediate (bool): Whether to return intermediate outputs.
Implements the encoder in BEVFormer. coder_norm_cfg (dict): Config of last normalization layer. Default:
Args: `LN`.
return_intermediate (bool): Whether to return intermediate outputs. """
coder_norm_cfg (dict): Config of last normalization layer. Default:
`LN`. def __init__(self, *args, pc_range=None, num_points_in_pillar=4, return_intermediate=False, dataset_type='nuscenes',
""" **kwargs):
def __init__(self, *args, pc_range=None, num_points_in_pillar=4, return_intermediate=False, dataset_type='nuscenes', super(BEVFormerEncoder, self).__init__(*args, **kwargs)
**kwargs): self.return_intermediate = return_intermediate
super(BEVFormerEncoder, self).__init__(*args, **kwargs) self.num_points_in_pillar = num_points_in_pillar
self.return_intermediate = return_intermediate self.pc_range = pc_range
self.fp16_enabled = False
self.num_points_in_pillar = num_points_in_pillar
self.pc_range = pc_range @staticmethod
self.fp16_enabled = False def get_reference_points(H, W, Z=8, num_points_in_pillar=4, dim='3d', bs=1, device='cuda', dtype=torch.float):
"""Get the reference points used in SCA and TSA.
@staticmethod Args:
def get_reference_points(H, W, Z=8, num_points_in_pillar=4, dim='3d', bs=1, device='cuda', dtype=torch.float): H, W: spatial shape of bev.
"""Get the reference points used in SCA and TSA. Z: hight of pillar.
Args: D: sample D points uniformly from each pillar.
H, W: spatial shape of bev. Z: height of pillar.
Z: height of pillar. D: sample D points uniformly from each pillar.
D: sample D points uniformly from each pillar. Returns:
device (obj:`device`): The device where Tensor: reference points used in decoder, has \
reference_points should be. shape (bs, num_keys, num_levels, 2).
Returns: """
Tensor: reference points used in decoder, has \
shape (bs, num_keys, num_levels, 2). # reference points in 3D space, used in spatial cross-attention (SCA)
""" if dim == '3d':
zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,
# reference points in 3D space, used in spatial cross-attention (SCA) device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Z
if dim == '3d': xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,
zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype, device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / W
device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Z ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,
xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype, device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H
device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / W ref_3d = torch.stack((xs, ys, zs), -1)
ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype, ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
ref_3d = torch.stack((xs, ys, zs), -1) return ref_3d
ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
ref_3d = ref_3d[None].repeat(bs, 1, 1, 1) # reference points on 2D bev plane, used in temporal self-attention (TSA).
return ref_3d elif dim == '2d':
ref_y, ref_x = torch.meshgrid(
# reference points on 2D bev plane, used in temporal self-attention (TSA). torch.linspace(
elif dim == '2d': 0.5, H - 0.5, H, dtype=dtype, device=device),
ref_y, ref_x = torch.meshgrid( torch.linspace(
torch.linspace( 0.5, W - 0.5, W, dtype=dtype, device=device)
0.5, H - 0.5, H, dtype=dtype, device=device), )
torch.linspace( ref_y = ref_y.reshape(-1)[None] / H
0.5, W - 0.5, W, dtype=dtype, device=device) ref_x = ref_x.reshape(-1)[None] / W
) ref_2d = torch.stack((ref_x, ref_y), -1)
ref_y = ref_y.reshape(-1)[None] / H ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)
ref_x = ref_x.reshape(-1)[None] / W return ref_2d
ref_2d = torch.stack((ref_x, ref_y), -1)
ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2) # This function must use fp32!!!
return ref_2d @force_fp32(apply_to=('reference_points', 'img_metas'))
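# Shape check (toy BEV size, CPU) for get_reference_points above:
import torch
ref_3d = BEVFormerEncoder.get_reference_points(
    4, 5, Z=8, num_points_in_pillar=4, dim='3d', bs=2, device='cpu')
ref_2d = BEVFormerEncoder.get_reference_points(4, 5, dim='2d', bs=2, device='cpu')
print(ref_3d.shape)  # torch.Size([2, 4, 20, 3])  (bs, D, H*W, xyz)
print(ref_2d.shape)  # torch.Size([2, 20, 1, 2])  (bs, H*W, 1, xy)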
def point_sampling(self, reference_points, pc_range, img_metas):
# This function must use fp32!!!
@force_fp32(apply_to=('reference_points', 'img_metas')) lidar2img = []
def point_sampling(self, reference_points, pc_range, img_metas): for img_meta in img_metas:
lidar2img.append(img_meta['lidar2img'])
lidar2img = [] lidar2img = np.asarray(lidar2img)
for img_meta in img_metas: lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4)
lidar2img.append(img_meta['lidar2img']) reference_points = reference_points.clone()
lidar2img = np.asarray(lidar2img)
lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4) reference_points[..., 0:1] = reference_points[..., 0:1] * \
reference_points = reference_points.clone() (pc_range[3] - pc_range[0]) + pc_range[0]
reference_points[..., 1:2] = reference_points[..., 1:2] * \
reference_points[..., 0:1] = reference_points[..., 0:1] * \ (pc_range[4] - pc_range[1]) + pc_range[1]
(pc_range[3] - pc_range[0]) + pc_range[0] reference_points[..., 2:3] = reference_points[..., 2:3] * \
reference_points[..., 1:2] = reference_points[..., 1:2] * \ (pc_range[5] - pc_range[2]) + pc_range[2]
(pc_range[4] - pc_range[1]) + pc_range[1]
reference_points[..., 2:3] = reference_points[..., 2:3] * \ reference_points = torch.cat(
(pc_range[5] - pc_range[2]) + pc_range[2] (reference_points, torch.ones_like(reference_points[..., :1])), -1)
reference_points = torch.cat( reference_points = reference_points.permute(1, 0, 2, 3)
(reference_points, torch.ones_like(reference_points[..., :1])), -1) D, B, num_query = reference_points.size()[:3]
num_cam = lidar2img.size(1)
reference_points = reference_points.permute(1, 0, 2, 3)
D, B, num_query = reference_points.size()[:3] reference_points = reference_points.view(
num_cam = lidar2img.size(1) D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)
reference_points = reference_points.view( lidar2img = lidar2img.view(
D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1) 1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
lidar2img = lidar2img.view( reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1) reference_points.to(torch.float32)).squeeze(-1)
eps = 1e-5
reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
reference_points.to(torch.float32)).squeeze(-1) bev_mask = (reference_points_cam[..., 2:3] > eps)
eps = 1e-5 reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)
bev_mask = (reference_points_cam[..., 2:3] > eps)
reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum( reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps) reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1] bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0] & (reference_points_cam[..., 1:2] < 1.0)
& (reference_points_cam[..., 0:1] < 1.0)
bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0) & (reference_points_cam[..., 0:1] > 0.0))
& (reference_points_cam[..., 1:2] < 1.0) if digit_version(TORCH_VERSION) >= digit_version('1.8'):
& (reference_points_cam[..., 0:1] < 1.0) bev_mask = torch.nan_to_num(bev_mask)
& (reference_points_cam[..., 0:1] > 0.0)) else:
if digit_version(TORCH_VERSION) >= digit_version('1.8'): bev_mask = bev_mask.new_tensor(
bev_mask = torch.nan_to_num(bev_mask) np.nan_to_num(bev_mask.cpu().numpy()))
else:
bev_mask = bev_mask.new_tensor( reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
np.nan_to_num(bev_mask.cpu().numpy())) bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4) return reference_points_cam, bev_mask
bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
@auto_fp16()
return reference_points_cam, bev_mask def forward(self,
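# Toy projection (identity standing in for a real lidar2img matrix) showing the
# homogeneous mapping and depth test performed in point_sampling above:
import torch
lidar2img = torch.eye(4)                    # pretend intrinsics @ extrinsics
point = torch.tensor([2.0, 1.0, 0.5, 1.0])  # homogeneous lidar-frame point
cam = lidar2img @ point
eps = 1e-5
uv = cam[:2] / torch.maximum(cam[2:3], torch.ones_like(cam[2:3]) * eps)
in_front = cam[2:3] > eps                   # the bev_mask depth criterion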
bev_query,
@auto_fp16() key,
def forward(self, value,
bev_query, *args,
key, bev_h=None,
value, bev_w=None,
*args, bev_pos=None,
bev_h=None, spatial_shapes=None,
bev_w=None, level_start_index=None,
bev_pos=None, valid_ratios=None,
spatial_shapes=None, prev_bev=None,
level_start_index=None, shift=0.,
valid_ratios=None, **kwargs):
prev_bev=None, """Forward function for `TransformerDecoder`.
shift=0., Args:
**kwargs): bev_query (Tensor): Input BEV query with shape
"""Forward function for `TransformerDecoder`. `(num_query, bs, embed_dims)`.
Args: key & value (Tensor): Input multi-cameta features with shape
bev_query (Tensor): Input BEV query with shape (num_cam, num_value, bs, embed_dims)
`(num_query, bs, embed_dims)`. reference_points (Tensor): The reference
key & value (Tensor): Input multi-cameta features with shape points of offset. has shape
(num_cam, num_value, bs, embed_dims) (bs, num_query, 4) when as_two_stage,
reference_points (Tensor): The reference otherwise has shape (bs, num_query, 2).
points of offset. has shape valid_ratios (Tensor): The ratios of valid
(bs, num_query, 4) when as_two_stage, points on the feature map, has shape
otherwise has shape (bs, num_query, 2). (bs, num_levels, 2)
valid_ratios (Tensor): The ratios of valid Returns:
points on the feature map, has shape Tensor: Results with shape [1, num_query, bs, embed_dims] when
(bs, num_levels, 2) return_intermediate is `False`, otherwise it has shape
Returns: [num_layers, num_query, bs, embed_dims].
Tensor: Results with shape [1, num_query, bs, embed_dims] when """
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims]. output = bev_query
""" intermediate = []
output = bev_query ref_3d = self.get_reference_points(
intermediate = [] bev_h, bev_w, self.pc_range[5] - self.pc_range[2], self.num_points_in_pillar, dim='3d',
bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
ref_3d = self.get_reference_points( ref_2d = self.get_reference_points(
bev_h, bev_w, self.pc_range[5]-self.pc_range[2], self.num_points_in_pillar, dim='3d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype) bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
ref_2d = self.get_reference_points(
bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype) reference_points_cam, bev_mask = self.point_sampling(
ref_3d, self.pc_range, kwargs['img_metas'])
reference_points_cam, bev_mask = self.point_sampling(
ref_3d, self.pc_range, kwargs['img_metas']) # note: original BEVFormer used 'shift_ref_2d = ref_2d' (without .clone()) to reproduce its paper results; the fix ('.clone()') is applied here.
shift_ref_2d = ref_2d.clone()
# note: original BEVFormer used 'shift_ref_2d = ref_2d' (without .clone()) to reproduce its paper results; the fix ('.clone()') is applied here. shift_ref_2d += shift[:, None, None, :]
shift_ref_2d = ref_2d.clone()
shift_ref_2d += shift[:, None, None, :] # (num_query, bs, embed_dims) -> (bs, num_query, embed_dims)
bev_query = bev_query.permute(1, 0, 2)
# (num_query, bs, embed_dims) -> (bs, num_query, embed_dims) bev_pos = bev_pos.permute(1, 0, 2)
bev_query = bev_query.permute(1, 0, 2) bs, len_bev, num_bev_level, _ = ref_2d.shape
bev_pos = bev_pos.permute(1, 0, 2) if prev_bev is not None:
bs, len_bev, num_bev_level, _ = ref_2d.shape prev_bev = prev_bev.permute(1, 0, 2)
if prev_bev is not None: prev_bev = torch.stack(
prev_bev = prev_bev.permute(1, 0, 2) [prev_bev, bev_query], 1).reshape(bs * 2, len_bev, -1)
prev_bev = torch.stack( hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
[prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1) bs * 2, len_bev, num_bev_level, 2)
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape( else:
bs*2, len_bev, num_bev_level, 2) hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
else: bs * 2, len_bev, num_bev_level, 2)
hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
bs*2, len_bev, num_bev_level, 2) for lid, layer in enumerate(self.layers):
output = layer(
for lid, layer in enumerate(self.layers): bev_query,
output = layer( key,
bev_query, value,
key, *args,
value, bev_pos=bev_pos,
*args, ref_2d=hybird_ref_2d,
bev_pos=bev_pos, ref_3d=ref_3d,
ref_2d=hybird_ref_2d, bev_h=bev_h,
ref_3d=ref_3d, bev_w=bev_w,
bev_h=bev_h, spatial_shapes=spatial_shapes,
bev_w=bev_w, level_start_index=level_start_index,
spatial_shapes=spatial_shapes, reference_points_cam=reference_points_cam,
level_start_index=level_start_index, bev_mask=bev_mask,
reference_points_cam=reference_points_cam, prev_bev=prev_bev,
bev_mask=bev_mask, **kwargs)
prev_bev=prev_bev,
**kwargs) bev_query = output
if self.return_intermediate:
bev_query = output intermediate.append(output)
if self.return_intermediate:
intermediate.append(output) if self.return_intermediate:
return torch.stack(intermediate)
if self.return_intermediate:
return torch.stack(intermediate) return output
return output
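
# A minimal illustrative sketch of the doubled-batch layout built above for
# temporal self-attention: BEV features from frame t-1 and the current BEV
# queries are interleaved along the batch axis, so a single attention call can
# sample from both frames. All names and sizes here are hypothetical.
def _demo_hybrid_bev_batch():
    bs, len_bev, embed_dims = 1, 4, 256
    prev_bev = torch.randn(bs, len_bev, embed_dims)   # BEV features at t-1
    bev_query = torch.randn(bs, len_bev, embed_dims)  # BEV queries at t
    # Stack on a new dim-1, then fold it into the batch: sample 2*i holds
    # prev_bev[i] and sample 2*i + 1 holds bev_query[i].
    hybrid = torch.stack([prev_bev, bev_query], 1).reshape(
        bs * 2, len_bev, embed_dims)
    assert hybrid.shape == (bs * 2, len_bev, embed_dims)
    return hybrid
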
@TRANSFORMER_LAYER.register_module()
class BEVFormerLayer(MyCustomBaseTransformerLayer):
    """Implements a decoder layer in DETR transformer.
    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
            Configs for self_attention or cross_attention; the order
            should be consistent with `operation_order`. If it is
            a dict, it will be expanded to the number of attentions in
            `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in FFN. Default: 0.0.
        operation_order (tuple[str]): The execution order of operations
            in the transformer, such as ('self_attn', 'norm', 'ffn', 'norm').
            Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU', inplace=True).
        norm_cfg (dict): Config dict for the normalization layer.
            Default: dict(type='LN').
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
    """

    def __init__(self,
                 attn_cfgs,
                 ffn_cfgs,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 **kwargs):
        super(BEVFormerLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            ffn_cfgs=ffn_cfgs,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            **kwargs)
        self.fp16_enabled = False
        assert len(operation_order) == 6
        assert set(operation_order) == set(
            ['self_attn', 'norm', 'cross_attn', 'ffn'])

    def forward(self,
                query,
                key=None,
                value=None,
                bev_pos=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                ref_2d=None,
                ref_3d=None,
                bev_h=None,
                bev_w=None,
                reference_points_cam=None,
                mask=None,
                spatial_shapes=None,
                level_start_index=None,
                prev_bev=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
            value (Tensor): The value tensor with the same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensors used in the
                calculation of the corresponding attention. The length
                should equal the number of `attention` entries in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in the `self_attn` layer.
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """

        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use the same attn_mask in all attentions in '
                          f'{self.__class__.__name__}')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attentions in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            # temporal self-attention
            if layer == 'self_attn':

                query = self.attentions[attn_index](
                    query,
                    prev_bev,
                    prev_bev,
                    identity if self.pre_norm else None,
                    query_pos=bev_pos,
                    key_pos=bev_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    reference_points=ref_2d,
                    spatial_shapes=torch.tensor(
                        [[bev_h, bev_w]], device=query.device),
                    level_start_index=torch.tensor([0], device=query.device),
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            # spatial cross-attention
            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    reference_points=ref_3d,
                    reference_points_cam=reference_points_cam,
                    mask=mask,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
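
# An illustrative config sketch for the layer above, consistent with its
# asserts (six operations covering both attention types). The attention and
# FFN settings below are hypothetical placeholders rather than the project's
# actual configs.
_example_bevformer_layer_cfg = dict(
    type='BEVFormerLayer',
    attn_cfgs=[
        dict(type='TemporalSelfAttention', embed_dims=256, num_levels=1),
        dict(type='SpatialCrossAttention', embed_dims=256),
    ],
    ffn_cfgs=dict(embed_dims=256, feedforward_channels=512),
    operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))
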

# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import torch
from mmcv.utils import ext_loader
from torch.autograd.function import Function, once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])


class MultiScaleDeformableAttnFunction_fp16(Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """GPU version of multi-scale deformable attention.

        Args:
            value (Tensor): The value tensor with shape
                (bs, num_keys, num_heads, embed_dims//num_heads).
            value_spatial_shapes (Tensor): Spatial shape of
                each feature map, has shape (num_levels, 2);
                the last dimension 2 represents (h, w).
            sampling_locations (Tensor): The location of sampling points,
                has shape
                (bs, num_queries, num_heads, num_levels, num_points, 2);
                the last dimension 2 represents (x, y).
            attention_weights (Tensor): The weight of sampling points used
                when calculating the attention, has shape
                (bs, num_queries, num_heads, num_levels, num_points).
            im2col_step (Tensor): The step used in image-to-column.

        Returns:
            Tensor: has shape (bs, num_queries, embed_dims)
        """
        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        """GPU version of the backward function.

        Args:
            grad_output (Tensor): Gradient of the output tensor of forward.

        Returns:
            Tuple[Tensor]: Gradients of the input tensors in forward.
        """
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)

        ext_module.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output.contiguous(),
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
            im2col_step=ctx.im2col_step)

        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
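
# A usage sketch for the fp16 Function above; tensor shapes follow the
# forward() docstring. It assumes a CUDA device and the compiled `_ext` op,
# so it is wrapped in a function rather than run at import time; all sizes
# are hypothetical.
def _demo_ms_deform_attn_fp16():
    bs, num_queries, num_heads, num_levels, num_points = 1, 8, 4, 1, 4
    h, w, head_dims = 16, 16, 32
    value = torch.randn(bs, h * w, num_heads, head_dims, device='cuda')
    spatial_shapes = torch.tensor([[h, w]], dtype=torch.long, device='cuda')
    level_start_index = torch.tensor([0], dtype=torch.long, device='cuda')
    sampling_locations = torch.rand(
        bs, num_queries, num_heads, num_levels, num_points, 2, device='cuda')
    attention_weights = torch.randn(
        bs, num_queries, num_heads, num_levels, num_points,
        device='cuda').softmax(-1)
    output = MultiScaleDeformableAttnFunction_fp16.apply(
        value, spatial_shapes, level_start_index,
        sampling_locations, attention_weights, 64)  # im2col_step, e.g. 64
    # expected output shape: (bs, num_queries, num_heads * head_dims)
    return output
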

class MultiScaleDeformableAttnFunction_fp32(Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        """GPU version of multi-scale deformable attention.

        Args:
            value (Tensor): The value tensor with shape
                (bs, num_keys, num_heads, embed_dims//num_heads).
            value_spatial_shapes (Tensor): Spatial shape of
                each feature map, has shape (num_levels, 2);
                the last dimension 2 represents (h, w).
            sampling_locations (Tensor): The location of sampling points,
                has shape
                (bs, num_queries, num_heads, num_levels, num_points, 2);
                the last dimension 2 represents (x, y).
            attention_weights (Tensor): The weight of sampling points used
                when calculating the attention, has shape
                (bs, num_queries, num_heads, num_levels, num_points).
            im2col_step (Tensor): The step used in image-to-column.

        Returns:
            Tensor: has shape (bs, num_queries, embed_dims)
        """
        ctx.im2col_step = im2col_step
        output = ext_module.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes,
                              value_level_start_index, sampling_locations,
                              attention_weights)
        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        """GPU version of the backward function.

        Args:
            grad_output (Tensor): Gradient of the output tensor of forward.

        Returns:
            Tuple[Tensor]: Gradients of the input tensors in forward.
        """
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)

        ext_module.ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output.contiguous(),
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
            im2col_step=ctx.im2col_step)

        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
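
# For CPU-only sanity checks, the CUDA op above can be mirrored in pure
# PyTorch with F.grid_sample (mmcv ships an equivalent helper named
# `multi_scale_deformable_attn_pytorch`). The sketch below restates that
# approach and is not a drop-in replacement for the compiled extension.
def _ms_deform_attn_pytorch_sketch(value, value_spatial_shapes,
                                   sampling_locations, attention_weights):
    import torch.nn.functional as F
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split(
        [int(h) * int(w) for h, w in value_spatial_shapes], dim=1)
    # map sampling locations from [0, 1] to grid_sample's [-1, 1] range
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        # (bs, h*w, num_heads, dims) -> (bs*num_heads, dims, h, w)
        value_l = value_list[level].flatten(2).transpose(1, 2).reshape(
            bs * num_heads, embed_dims, int(h), int(w))
        # (bs, num_queries, num_heads, num_points, 2)
        # -> (bs*num_heads, num_queries, num_points, 2)
        grid_l = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        sampling_value_list.append(
            F.grid_sample(value_l, grid_l, mode='bilinear',
                          padding_mode='zeros', align_corners=False))
    # weight the sampled values and sum over all levels and points
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
              attention_weights).sum(-1).view(
                  bs, num_heads * embed_dims, num_queries)
    return output.transpose(1, 2).contiguous()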