Commit 41b18fd8 authored by zhe chen
Browse files

Use pre-commit to reformat code


Use pre-commit to reformat code
parent ff20ea39
...@@ -4,15 +4,15 @@ ...@@ -4,15 +4,15 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import print_function
from __future__ import division
import warnings import warnings
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_ from torch import nn
from torch.nn.init import constant_, xavier_uniform_
from ..functions import DCNv3Function, dcnv3_core_pytorch from ..functions import DCNv3Function, dcnv3_core_pytorch
...@@ -72,7 +72,7 @@ def build_act_layer(act_layer): ...@@ -72,7 +72,7 @@ def build_act_layer(act_layer):
def _is_power_of_2(n): def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0): if (not isinstance(n, int)) or (n < 0):
raise ValueError( raise ValueError(
"invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))
return (n & (n - 1) == 0) and n != 0 return (n & (n - 1) == 0) and n != 0
...@@ -124,7 +124,7 @@ class DCNv3_pytorch(nn.Module): ...@@ -124,7 +124,7 @@ class DCNv3_pytorch(nn.Module):
if not _is_power_of_2(_d_per_group): if not _is_power_of_2(_d_per_group):
warnings.warn( warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.") 'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale self.offset_scale = offset_scale
self.channels = channels self.channels = channels
...@@ -251,7 +251,7 @@ class DCNv3(nn.Module): ...@@ -251,7 +251,7 @@ class DCNv3(nn.Module):
if not _is_power_of_2(_d_per_group): if not _is_power_of_2(_d_per_group):
warnings.warn( warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.") 'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale self.offset_scale = offset_scale
self.channels = channels self.channels = channels
......
...@@ -4,39 +4,34 @@ ...@@ -4,39 +4,34 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
import os
import glob import glob
import os
import torch import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME requirements = ['torch', 'torchvision']
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension
from setuptools import find_packages
from setuptools import setup
requirements = ["torch", "torchvision"]
def get_extensions(): def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src") extensions_dir = os.path.join(this_dir, 'src')
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
sources = main_file + source_cpu sources = main_file + source_cpu
extension = CppExtension extension = CppExtension
extra_compile_args = {"cxx": []} extra_compile_args = {'cxx': []}
define_macros = [] define_macros = []
if torch.cuda.is_available() and CUDA_HOME is not None: if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension extension = CUDAExtension
sources += source_cuda sources += source_cuda
define_macros += [("WITH_CUDA", None)] define_macros += [('WITH_CUDA', None)]
extra_compile_args["nvcc"] = [ extra_compile_args['nvcc'] = [
# "-DCUDA_HAS_FP16=1", # "-DCUDA_HAS_FP16=1",
# "-D__CUDA_NO_HALF_OPERATORS__", # "-D__CUDA_NO_HALF_OPERATORS__",
# "-D__CUDA_NO_HALF_CONVERSIONS__", # "-D__CUDA_NO_HALF_CONVERSIONS__",
...@@ -49,7 +44,7 @@ def get_extensions(): ...@@ -49,7 +44,7 @@ def get_extensions():
include_dirs = [extensions_dir] include_dirs = [extensions_dir]
ext_modules = [ ext_modules = [
extension( extension(
"DCNv3", 'DCNv3',
sources, sources,
include_dirs=include_dirs, include_dirs=include_dirs,
define_macros=define_macros, define_macros=define_macros,
...@@ -60,16 +55,16 @@ def get_extensions(): ...@@ -60,16 +55,16 @@ def get_extensions():
setup( setup(
name="DCNv3", name='DCNv3',
version="1.0", version='1.0',
author="InternImage", author='InternImage',
url="https://github.com/OpenGVLab/InternImage", url='https://github.com/OpenGVLab/InternImage',
description= description=
"PyTorch Wrapper for CUDA Functions of DCNv3", 'PyTorch Wrapper for CUDA Functions of DCNv3',
packages=find_packages(exclude=( packages=find_packages(exclude=(
"configs", 'configs',
"tests", 'tests',
)), )),
ext_modules=get_extensions(), ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
) )
...@@ -4,16 +4,11 @@ ...@@ -4,16 +4,11 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import print_function
from __future__ import division
import time import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
import torch
from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
H_in, W_in = 8, 8 H_in, W_in = 8, 8
...@@ -32,11 +27,11 @@ torch.manual_seed(3) ...@@ -32,11 +27,11 @@ torch.manual_seed(3)
@torch.no_grad() @torch.no_grad()
def check_forward_equal_with_pytorch_double(): def check_forward_equal_with_pytorch_double():
input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask /= mask.sum(-1, keepdim=True) mask /= mask.sum(-1, keepdim=True)
mask = mask.reshape(N, H_out, W_out, M*P) mask = mask.reshape(N, H_out, W_out, M * P)
output_pytorch = dcnv3_core_pytorch( output_pytorch = dcnv3_core_pytorch(
input.double(), input.double(),
...@@ -57,16 +52,17 @@ def check_forward_equal_with_pytorch_double(): ...@@ -57,16 +52,17 @@ def check_forward_equal_with_pytorch_double():
max_rel_err = ((output_cuda - output_pytorch).abs() / max_rel_err = ((output_cuda - output_pytorch).abs() /
output_pytorch.abs()).max() output_pytorch.abs()).max()
print('>>> forward double') print('>>> forward double')
print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') print(
f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
@torch.no_grad() @torch.no_grad()
def check_forward_equal_with_pytorch_float(): def check_forward_equal_with_pytorch_float():
input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask /= mask.sum(-1, keepdim=True) mask /= mask.sum(-1, keepdim=True)
mask = mask.reshape(N, H_out, W_out, M*P) mask = mask.reshape(N, H_out, W_out, M * P)
output_pytorch = dcnv3_core_pytorch( output_pytorch = dcnv3_core_pytorch(
input, input,
...@@ -87,7 +83,8 @@ def check_forward_equal_with_pytorch_float(): ...@@ -87,7 +83,8 @@ def check_forward_equal_with_pytorch_float():
max_rel_err = ((output_cuda - output_pytorch).abs() / max_rel_err = ((output_cuda - output_pytorch).abs() /
output_pytorch.abs()).max() output_pytorch.abs()).max()
print('>>> forward float') print('>>> forward float')
print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') print(
f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_offset=True, grad_mask=True): def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
...@@ -98,11 +95,11 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o ...@@ -98,11 +95,11 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
D = channels D = channels
input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask0 /= mask0.sum(-1, keepdim=True) mask0 /= mask0.sum(-1, keepdim=True)
mask0 = mask0.reshape(N, H_out, W_out, M*P) mask0 = mask0.reshape(N, H_out, W_out, M * P)
input0.requires_grad = grad_input input0.requires_grad = grad_input
offset0.requires_grad = grad_offset offset0.requires_grad = grad_offset
mask0.requires_grad = grad_mask mask0.requires_grad = grad_mask
...@@ -161,11 +158,11 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of ...@@ -161,11 +158,11 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
D = channels D = channels
input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask0 /= mask0.sum(-1, keepdim=True) mask0 /= mask0.sum(-1, keepdim=True)
mask0 = mask0.reshape(N, H_out, W_out, M*P) mask0 = mask0.reshape(N, H_out, W_out, M * P)
input0.requires_grad = grad_input input0.requires_grad = grad_input
offset0.requires_grad = grad_offset offset0.requires_grad = grad_offset
mask0.requires_grad = grad_mask mask0.requires_grad = grad_mask
...@@ -223,11 +220,11 @@ def check_time_cost(im2col_step=128): ...@@ -223,11 +220,11 @@ def check_time_cost(im2col_step=128):
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask /= mask.sum(-1, keepdim=True) mask /= mask.sum(-1, keepdim=True)
mask = mask.reshape(N, H_out, W_out, M*P) mask = mask.reshape(N, H_out, W_out, M * P)
print( print(
f'>>> time cost: im2col_step {im2col_step}; input {input.shape}; points {P} ') f'>>> time cost: im2col_step {im2col_step}; input {input.shape}; points {P} ')
repeat = 100 repeat = 100
......
...@@ -5,29 +5,19 @@ ...@@ -5,29 +5,19 @@
# --------------------------------------------- # ---------------------------------------------
import copy import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.cnn import Linear, bias_init_with_prob from mmcv.runner import BaseModule, auto_fp16, force_fp32
from mmcv.utils import TORCH_VERSION, digit_version
from mmdet.core import (multi_apply, multi_apply, reduce_mean)
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models import HEADS from mmdet.models import HEADS
from mmdet.models.builder import build_loss
from mmdet.models.dense_heads import DETRHead from mmdet.models.dense_heads import DETRHead
from mmdet3d.core.bbox.coders import build_bbox_coder from mmdet.models.utils import build_transformer
from projects.mmdet3d_plugin.core.bbox.util import normalize_bbox from mmdet.models.utils.transformer import inverse_sigmoid
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32, auto_fp16
from projects.mmdet3d_plugin.models.utils.bricks import run_time from projects.mmdet3d_plugin.models.utils.bricks import run_time
import numpy as np
import mmcv
import cv2 as cv
from projects.mmdet3d_plugin.models.utils.visual import save_tensor from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils import build_transformer
from mmdet.models.builder import build_loss
from mmcv.runner import BaseModule, force_fp32
@HEADS.register_module() @HEADS.register_module()
class BEVFormerOccHead(BaseModule): class BEVFormerOccHead(BaseModule):
...@@ -61,15 +51,14 @@ class BEVFormerOccHead(BaseModule): ...@@ -61,15 +51,14 @@ class BEVFormerOccHead(BaseModule):
self.bev_h = bev_h self.bev_h = bev_h
self.bev_w = bev_w self.bev_w = bev_w
self.fp16_enabled = False self.fp16_enabled = False
self.num_classes=kwargs['num_classes'] self.num_classes = kwargs['num_classes']
self.use_mask=use_mask self.use_mask = use_mask
self.with_box_refine = with_box_refine self.with_box_refine = with_box_refine
self.as_two_stage = as_two_stage self.as_two_stage = as_two_stage
if self.as_two_stage: if self.as_two_stage:
transformer['as_two_stage'] = self.as_two_stage transformer['as_two_stage'] = self.as_two_stage
self.pc_range = pc_range self.pc_range = pc_range
self.real_w = self.pc_range[3] - self.pc_range[0] self.real_w = self.pc_range[3] - self.pc_range[0]
self.real_h = self.pc_range[4] - self.pc_range[1] self.real_h = self.pc_range[4] - self.pc_range[1]
...@@ -151,7 +140,7 @@ class BEVFormerOccHead(BaseModule): ...@@ -151,7 +140,7 @@ class BEVFormerOccHead(BaseModule):
outs = { outs = {
'bev_embed': bev_embed, 'bev_embed': bev_embed,
'occ':occ_outs, 'occ': occ_outs,
} }
return outs return outs
...@@ -166,25 +155,25 @@ class BEVFormerOccHead(BaseModule): ...@@ -166,25 +155,25 @@ class BEVFormerOccHead(BaseModule):
gt_bboxes_ignore=None, gt_bboxes_ignore=None,
img_metas=None): img_metas=None):
loss_dict=dict() loss_dict = dict()
occ=preds_dicts['occ'] occ = preds_dicts['occ']
assert voxel_semantics.min()>=0 and voxel_semantics.max()<=17 assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
losses = self.loss_single(voxel_semantics,mask_camera,occ) losses = self.loss_single(voxel_semantics, mask_camera, occ)
loss_dict['loss_occ']=losses loss_dict['loss_occ'] = losses
return loss_dict return loss_dict
def loss_single(self,voxel_semantics,mask_camera,preds): def loss_single(self, voxel_semantics, mask_camera, preds):
voxel_semantics=voxel_semantics.long() voxel_semantics = voxel_semantics.long()
if self.use_mask: if self.use_mask:
voxel_semantics=voxel_semantics.reshape(-1) voxel_semantics = voxel_semantics.reshape(-1)
preds=preds.reshape(-1,self.num_classes) preds = preds.reshape(-1, self.num_classes)
mask_camera=mask_camera.reshape(-1) mask_camera = mask_camera.reshape(-1)
num_total_samples=mask_camera.sum() num_total_samples = mask_camera.sum()
loss_occ=self.loss_occ(preds,voxel_semantics,mask_camera, avg_factor=num_total_samples) loss_occ = self.loss_occ(preds, voxel_semantics, mask_camera, avg_factor=num_total_samples)
else: else:
voxel_semantics = voxel_semantics.reshape(-1) voxel_semantics = voxel_semantics.reshape(-1)
preds = preds.reshape(-1, self.num_classes) preds = preds.reshape(-1, self.num_classes)
loss_occ = self.loss_occ(preds, voxel_semantics,) loss_occ = self.loss_occ(preds, voxel_semantics, )
return loss_occ return loss_occ
@force_fp32(apply_to=('preds')) @force_fp32(apply_to=('preds'))
...@@ -199,9 +188,8 @@ class BEVFormerOccHead(BaseModule): ...@@ -199,9 +188,8 @@ class BEVFormerOccHead(BaseModule):
# return self.transformer.get_occ( # return self.transformer.get_occ(
# preds_dicts, img_metas, rescale=rescale) # preds_dicts, img_metas, rescale=rescale)
# print(img_metas[0].keys()) # print(img_metas[0].keys())
occ_out=preds_dicts['occ'] occ_out = preds_dicts['occ']
occ_score=occ_out.softmax(-1) occ_score = occ_out.softmax(-1)
occ_score=occ_score.argmax(-1) occ_score = occ_score.argmax(-1)
return occ_score return occ_score
...@@ -4,17 +4,13 @@ ...@@ -4,17 +4,13 @@
# Modified by Xiaoyu Tian # Modified by Xiaoyu Tian
# --------------------------------------------- # ---------------------------------------------
import copy
import torch import torch
from mmcv.runner import force_fp32, auto_fp16 from mmcv.runner import auto_fp16
from mmdet.models import DETECTORS
from mmdet3d.core import bbox3d2result
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmdet.models import DETECTORS
from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
import time
import copy
import numpy as np
import mmdet3d
from projects.mmdet3d_plugin.models.utils.bricks import run_time
@DETECTORS.register_module() @DETECTORS.register_module()
......
from mmcv.runner.hooks.hook import HOOKS, Hook from mmcv.runner.hooks.hook import HOOKS, Hook
from projects.mmdet3d_plugin.models.utils import run_time
@HOOKS.register_module() @HOOKS.register_module()
class TransferWeight(Hook): class TransferWeight(Hook):
def __init__(self, every_n_inters=1): def __init__(self, every_n_inters=1):
self.every_n_inters=every_n_inters self.every_n_inters = every_n_inters
def after_train_iter(self, runner): def after_train_iter(self, runner):
if self.every_n_inner_iters(runner, self.every_n_inters): if self.every_n_inner_iters(runner, self.every_n_inters):
runner.eval_model.load_state_dict(runner.model.state_dict()) runner.eval_model.load_state_dict(runner.model.state_dict())
from .transformer import PerceptionTransformer
from .spatial_cross_attention import SpatialCrossAttention, MSDeformableAttention3D
from .temporal_self_attention import TemporalSelfAttention
from .encoder import BEVFormerEncoder, BEVFormerLayer
from .decoder import DetectionTransformerDecoder from .decoder import DetectionTransformerDecoder
from .encoder import BEVFormerEncoder, BEVFormerLayer
from .spatial_cross_attention import (MSDeformableAttention3D,
SpatialCrossAttention)
from .temporal_self_attention import TemporalSelfAttention
from .transformer import PerceptionTransformer
from .transformer_occ import TransformerOcc from .transformer_occ import TransformerOcc
...@@ -8,18 +8,16 @@ import copy ...@@ -8,18 +8,16 @@ import copy
import warnings import warnings
import torch import torch
import torch.nn as nn from mmcv import ConfigDict
from mmcv.cnn import build_norm_layer
from mmcv import ConfigDict, deprecated_api_warning from mmcv.cnn.bricks.registry import TRANSFORMER_LAYER
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.cnn.bricks.registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file # Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try: try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401 from mmcv.ops.multi_scale_deform_attn import \
MultiScaleDeformableAttention # noqa F401
warnings.warn( warnings.warn(
ImportWarning( ImportWarning(
'``MultiScaleDeformableAttention`` has been moved to ' '``MultiScaleDeformableAttention`` has been moved to '
...@@ -31,7 +29,8 @@ except ImportError: ...@@ -31,7 +29,8 @@ except ImportError:
warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from ' warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
'``mmcv.ops.multi_scale_deform_attn``, ' '``mmcv.ops.multi_scale_deform_attn``, '
'You should install ``mmcv-full`` if you need this module. ') 'You should install ``mmcv-full`` if you need this module. ')
from mmcv.cnn.bricks.transformer import build_feedforward_network, build_attention from mmcv.cnn.bricks.transformer import (build_attention,
build_feedforward_network)
@TRANSFORMER_LAYER.register_module() @TRANSFORMER_LAYER.register_module()
......
...@@ -4,28 +4,21 @@ ...@@ -4,28 +4,21 @@
# Modified by Zhiqi Li # Modified by Zhiqi Li
# --------------------------------------------- # ---------------------------------------------
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch import math
import mmcv
import cv2 as cv
import copy
import warnings import warnings
from matplotlib import pyplot as plt
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn import xavier_init, constant_init from mmcv.cnn.bricks.registry import ATTENTION, TRANSFORMER_LAYER_SEQUENCE
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import TransformerLayerSequence from mmcv.cnn.bricks.transformer import TransformerLayerSequence
import math from mmcv.ops.multi_scale_deform_attn import \
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential multi_scale_deformable_attn_pytorch
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning, from mmcv.runner.base_module import BaseModule
to_2tuple) from mmcv.utils import deprecated_api_warning, ext_loader
from mmcv.utils import ext_loader from .multi_scale_deformable_attn_function import \
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \ MultiScaleDeformableAttnFunction_fp32
MultiScaleDeformableAttnFunction_fp16
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
......
# --------------------------------------------- # ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# --------------------------------------------- # ---------------------------------------------
# Modified by Zhiqi Li # Modified by Zhiqi Li
# --------------------------------------------- # ---------------------------------------------
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from .custom_base_transformer_layer import MyCustomBaseTransformerLayer
import copy import copy
import warnings import warnings
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmcv.runner import force_fp32, auto_fp16
import numpy as np import numpy as np
import torch import torch
import cv2 as cv from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
import mmcv TRANSFORMER_LAYER_SEQUENCE)
from mmcv.utils import TORCH_VERSION, digit_version from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmcv.utils import ext_loader from mmcv.runner import auto_fp16, force_fp32
from mmcv.utils import TORCH_VERSION, digit_version, ext_loader
from .custom_base_transformer_layer import MyCustomBaseTransformerLayer
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@TRANSFORMER_LAYER_SEQUENCE.register_module() @TRANSFORMER_LAYER_SEQUENCE.register_module()
class BEVFormerEncoder(TransformerLayerSequence): class BEVFormerEncoder(TransformerLayerSequence):
""" """
Attention with both self and cross Attention with both self and cross
Implements the decoder in DETR transformer. Implements the decoder in DETR transformer.
...@@ -71,7 +66,7 @@ class BEVFormerEncoder(TransformerLayerSequence): ...@@ -71,7 +66,7 @@ class BEVFormerEncoder(TransformerLayerSequence):
device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H
ref_3d = torch.stack((xs, ys, zs), -1) ref_3d = torch.stack((xs, ys, zs), -1)
ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1) ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
ref_3d = ref_3d[None].repeat(bs, 1, 1, 1) #shape: (bs,num_points_in_pillar,h*w,3) ref_3d = ref_3d[None].repeat(bs, 1, 1, 1) # shape: (bs,num_points_in_pillar,h*w,3)
return ref_3d return ref_3d
# reference points on 2D bev plane, used in temporal self-attention (TSA). # reference points on 2D bev plane, used in temporal self-attention (TSA).
...@@ -91,7 +86,7 @@ class BEVFormerEncoder(TransformerLayerSequence): ...@@ -91,7 +86,7 @@ class BEVFormerEncoder(TransformerLayerSequence):
# This function must use fp32!!! # This function must use fp32!!!
@force_fp32(apply_to=('reference_points', 'img_metas')) @force_fp32(apply_to=('reference_points', 'img_metas'))
def point_sampling(self, reference_points, pc_range, img_metas): def point_sampling(self, reference_points, pc_range, img_metas):
ego2lidar=img_metas[0]['ego2lidar'] ego2lidar = img_metas[0]['ego2lidar']
lidar2img = [] lidar2img = []
for img_meta in img_metas: for img_meta in img_metas:
...@@ -113,17 +108,19 @@ class BEVFormerEncoder(TransformerLayerSequence): ...@@ -113,17 +108,19 @@ class BEVFormerEncoder(TransformerLayerSequence):
reference_points = torch.cat( reference_points = torch.cat(
(reference_points, torch.ones_like(reference_points[..., :1])), -1) (reference_points, torch.ones_like(reference_points[..., :1])), -1)
reference_points = reference_points.permute(1, 0, 2, 3) #shape: (num_points_in_pillar,bs,h*w,4) reference_points = reference_points.permute(1, 0, 2, 3) # shape: (num_points_in_pillar,bs,h*w,4)
D, B, num_query = reference_points.size()[:3] # D=num_points_in_pillar , num_query=h*w D, B, num_query = reference_points.size()[:3] # D=num_points_in_pillar , num_query=h*w
num_cam = lidar2img.size(1) num_cam = lidar2img.size(1)
reference_points = reference_points.view( reference_points = reference_points.view(
D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1) #shape: (num_points_in_pillar,bs,num_cam,h*w,4) D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(
-1) # shape: (num_points_in_pillar,bs,num_cam,h*w,4)
lidar2img = lidar2img.view( lidar2img = lidar2img.view(
1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1) 1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
ego2lidar=ego2lidar.view(1,1,1,1,4,4).repeat(D,1,num_cam,num_query,1,1) ego2lidar = ego2lidar.view(1, 1, 1, 1, 4, 4).repeat(D, 1, num_cam, num_query, 1, 1)
reference_points_cam = torch.matmul(torch.matmul(lidar2img.to(torch.float32),ego2lidar.to(torch.float32)),reference_points.to(torch.float32)).squeeze(-1) reference_points_cam = torch.matmul(torch.matmul(lidar2img.to(torch.float32), ego2lidar.to(torch.float32)),
reference_points.to(torch.float32)).squeeze(-1)
eps = 1e-5 eps = 1e-5
bev_mask = (reference_points_cam[..., 2:3] > eps) bev_mask = (reference_points_cam[..., 2:3] > eps)
...@@ -143,8 +140,8 @@ class BEVFormerEncoder(TransformerLayerSequence): ...@@ -143,8 +140,8 @@ class BEVFormerEncoder(TransformerLayerSequence):
bev_mask = bev_mask.new_tensor( bev_mask = bev_mask.new_tensor(
np.nan_to_num(bev_mask.cpu().numpy())) np.nan_to_num(bev_mask.cpu().numpy()))
reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4) #shape: (num_cam,bs,h*w,num_points_in_pillar,2) reference_points_cam = reference_points_cam.permute(2, 1, 3, 0,
4) # shape: (num_cam,bs,h*w,num_points_in_pillar,2)
bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1) bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
...@@ -188,7 +185,8 @@ class BEVFormerEncoder(TransformerLayerSequence): ...@@ -188,7 +185,8 @@ class BEVFormerEncoder(TransformerLayerSequence):
intermediate = [] intermediate = []
ref_3d = self.get_reference_points( ref_3d = self.get_reference_points(
bev_h, bev_w, self.pc_range[5]-self.pc_range[2], self.num_points_in_pillar, dim='3d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype) bev_h, bev_w, self.pc_range[5] - self.pc_range[2], self.num_points_in_pillar, dim='3d',
bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
ref_2d = self.get_reference_points( ref_2d = self.get_reference_points(
bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype) bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
...@@ -206,12 +204,12 @@ class BEVFormerEncoder(TransformerLayerSequence): ...@@ -206,12 +204,12 @@ class BEVFormerEncoder(TransformerLayerSequence):
if prev_bev is not None: if prev_bev is not None:
prev_bev = prev_bev.permute(1, 0, 2) prev_bev = prev_bev.permute(1, 0, 2)
prev_bev = torch.stack( prev_bev = torch.stack(
[prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1) [prev_bev, bev_query], 1).reshape(bs * 2, len_bev, -1)
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape( hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
bs*2, len_bev, num_bev_level, 2) bs * 2, len_bev, num_bev_level, 2)
else: else:
hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape( hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
bs*2, len_bev, num_bev_level, 2) bs * 2, len_bev, num_bev_level, 2)
for lid, layer in enumerate(self.layers): for lid, layer in enumerate(self.layers):
output = layer( output = layer(
......
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
# --------------------------------------------- # ---------------------------------------------
import torch import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.autograd.function import Function, once_differentiable
from mmcv.utils import ext_loader from mmcv.utils import ext_loader
from torch.autograd.function import Function, once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
......
# --------------------------------------------- # ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# --------------------------------------------- # ---------------------------------------------
# Modified by Zhiqi Li # Modified by Zhiqi Li
# --------------------------------------------- # ---------------------------------------------
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch import math
import warnings import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn import xavier_init, constant_init from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import build_attention from mmcv.cnn.bricks.transformer import build_attention
import math from mmcv.ops.multi_scale_deform_attn import \
from mmcv.runner import force_fp32, auto_fp16 multi_scale_deformable_attn_pytorch
from mmcv.runner import force_fp32
from mmcv.runner.base_module import BaseModule
from mmcv.utils import ext_loader
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential from .multi_scale_deformable_attn_function import \
MultiScaleDeformableAttnFunction_fp32
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
MultiScaleDeformableAttnFunction_fp16
from projects.mmdet3d_plugin.models.utils.bricks import run_time
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
...@@ -150,7 +147,8 @@ class SpatialCrossAttention(BaseModule): ...@@ -150,7 +147,8 @@ class SpatialCrossAttention(BaseModule):
for i, reference_points_per_img in enumerate(reference_points_cam): for i, reference_points_per_img in enumerate(reference_points_cam):
index_query_per_img = indexes[i] index_query_per_img = indexes[i]
queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img] queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img]
reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[j, index_query_per_img] reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[
j, index_query_per_img]
num_cams, l, bs, embed_dims = key.shape num_cams, l, bs, embed_dims = key.shape
...@@ -159,9 +157,13 @@ class SpatialCrossAttention(BaseModule): ...@@ -159,9 +157,13 @@ class SpatialCrossAttention(BaseModule):
value = value.permute(2, 0, 1, 3).reshape( value = value.permute(2, 0, 1, 3).reshape(
bs * self.num_cams, l, self.embed_dims) bs * self.num_cams, l, self.embed_dims)
queries = self.deformable_attention(query=queries_rebatch.view(bs*self.num_cams, max_len, self.embed_dims), key=key, value=value, queries = self.deformable_attention(query=queries_rebatch.view(bs * self.num_cams, max_len, self.embed_dims),
reference_points=reference_points_rebatch.view(bs*self.num_cams, max_len, D, 2), spatial_shapes=spatial_shapes, key=key, value=value,
level_start_index=level_start_index).view(bs, self.num_cams, max_len, self.embed_dims) reference_points=reference_points_rebatch.view(bs * self.num_cams, max_len,
D, 2),
spatial_shapes=spatial_shapes,
level_start_index=level_start_index).view(bs, self.num_cams, max_len,
self.embed_dims)
for j in range(bs): for j in range(bs):
for i, index_query_per_img in enumerate(indexes): for i, index_query_per_img in enumerate(indexes):
slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)] slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)]
......
...@@ -4,20 +4,21 @@ ...@@ -4,20 +4,21 @@
# Modified by Zhiqi Li # Modified by Zhiqi Li
# --------------------------------------------- # ---------------------------------------------
from projects.mmdet3d_plugin.models.utils.bricks import run_time import math
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import warnings import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import xavier_init, constant_init from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION from mmcv.cnn.bricks.registry import ATTENTION
import math from mmcv.ops.multi_scale_deform_attn import \
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential multi_scale_deformable_attn_pytorch
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning, from mmcv.runner.base_module import BaseModule
to_2tuple)
from mmcv.utils import ext_loader from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import \
MultiScaleDeformableAttnFunction_fp32
ext_module = ext_loader.load_ext( ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
...@@ -97,9 +98,9 @@ class TemporalSelfAttention(BaseModule): ...@@ -97,9 +98,9 @@ class TemporalSelfAttention(BaseModule):
self.num_points = num_points self.num_points = num_points
self.num_bev_queue = num_bev_queue self.num_bev_queue = num_bev_queue
self.sampling_offsets = nn.Linear( self.sampling_offsets = nn.Linear(
embed_dims*self.num_bev_queue, num_bev_queue*num_heads * num_levels * num_points * 2) embed_dims * self.num_bev_queue, num_bev_queue * num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims*self.num_bev_queue, self.attention_weights = nn.Linear(embed_dims * self.num_bev_queue,
num_bev_queue*num_heads * num_levels * num_points) num_bev_queue * num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims) self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims) self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weights() self.init_weights()
...@@ -114,7 +115,7 @@ class TemporalSelfAttention(BaseModule): ...@@ -114,7 +115,7 @@ class TemporalSelfAttention(BaseModule):
grid_init = (grid_init / grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view( grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1, self.num_heads, 1, 1,
2).repeat(1, self.num_levels*self.num_bev_queue, self.num_points, 1) 2).repeat(1, self.num_levels * self.num_bev_queue, self.num_points, 1)
for i in range(self.num_points): for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1 grid_init[:, :, i, :] *= i + 1
...@@ -177,7 +178,7 @@ class TemporalSelfAttention(BaseModule): ...@@ -177,7 +178,7 @@ class TemporalSelfAttention(BaseModule):
if value is None: if value is None:
assert self.batch_first assert self.batch_first
bs, len_bev, c = query.shape bs, len_bev, c = query.shape
value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c) value = torch.stack([query, query], 1).reshape(bs * 2, len_bev, c)
# value = torch.cat([query, query], 0) # value = torch.cat([query, query], 0)
...@@ -200,7 +201,7 @@ class TemporalSelfAttention(BaseModule): ...@@ -200,7 +201,7 @@ class TemporalSelfAttention(BaseModule):
if key_padding_mask is not None: if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0) value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.reshape(bs*self.num_bev_queue, value = value.reshape(bs * self.num_bev_queue,
num_value, self.num_heads, -1) num_value, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query) sampling_offsets = self.sampling_offsets(query)
...@@ -216,10 +217,10 @@ class TemporalSelfAttention(BaseModule): ...@@ -216,10 +217,10 @@ class TemporalSelfAttention(BaseModule):
self.num_levels, self.num_levels,
self.num_points) self.num_points)
attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5)\ attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5) \
.reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous() .reshape(bs * self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous()
sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6)\ sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6) \
.reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2) .reshape(bs * self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2)
if reference_points.shape[-1] == 2: if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack( offset_normalizer = torch.stack(
......
...@@ -9,18 +9,15 @@ import torch ...@@ -9,18 +9,15 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import xavier_init from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner import auto_fp16
from mmcv.runner.base_module import BaseModule from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_ from torch.nn.init import normal_
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from mmcv.runner.base_module import BaseModule
from torchvision.transforms.functional import rotate from torchvision.transforms.functional import rotate
from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention from .decoder import CustomMSDeformableAttention
from projects.mmdet3d_plugin.models.utils.bricks import run_time from .spatial_cross_attention import MSDeformableAttention3D
from mmcv.runner import force_fp32, auto_fp16 from .temporal_self_attention import TemporalSelfAttention
@TRANSFORMER.register_module() @TRANSFORMER.register_module()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment