Commit 41b18fd8 authored by zhe chen

Use pre-commit to reformat code


parent ff20ea39
......@@ -4,15 +4,15 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import, division, print_function
import warnings
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_
from torch import nn
from torch.nn.init import constant_, xavier_uniform_
from ..functions import DCNv3Function, dcnv3_core_pytorch
......@@ -72,7 +72,7 @@ def build_act_layer(act_layer):
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
"invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))
return (n & (n - 1) == 0) and n != 0
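
Editor's note: the check above relies on a standard bit trick. Subtracting 1 from a power of two flips its single set bit and every bit below it, so `n & (n - 1)` is zero exactly when `n` has one set bit; `n != 0` rules out zero. A minimal sanity check:

```python
# n & (n - 1) clears the lowest set bit, so it is 0 only when n has a
# single set bit, i.e. n is a power of two; "n != 0" excludes zero.
for n in [1, 2, 3, 4, 6, 8, 16, 24]:
    print(n, (n & (n - 1) == 0) and n != 0)
# -> 1 True, 2 True, 3 False, 4 True, 6 False, 8 True, 16 True, 24 False
```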
......@@ -124,7 +124,7 @@ class DCNv3_pytorch(nn.Module):
if not _is_power_of_2(_d_per_group):
warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale
self.channels = channels
......@@ -251,7 +251,7 @@ class DCNv3(nn.Module):
if not _is_power_of_2(_d_per_group):
warnings.warn(
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
'which is more efficient in our CUDA implementation.')
self.offset_scale = offset_scale
self.channels = channels
......
......@@ -4,39 +4,34 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import os
import glob
import os
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension
from setuptools import find_packages
from setuptools import setup
requirements = ["torch", "torchvision"]
requirements = ['torch', 'torchvision']
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src")
extensions_dir = os.path.join(this_dir, 'src')
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
sources = main_file + source_cpu
extension = CppExtension
extra_compile_args = {"cxx": []}
extra_compile_args = {'cxx': []}
define_macros = []
if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
define_macros += [('WITH_CUDA', None)]
extra_compile_args['nvcc'] = [
# "-DCUDA_HAS_FP16=1",
# "-D__CUDA_NO_HALF_OPERATORS__",
# "-D__CUDA_NO_HALF_CONVERSIONS__",
......@@ -49,7 +44,7 @@ def get_extensions():
include_dirs = [extensions_dir]
ext_modules = [
extension(
"DCNv3",
'DCNv3',
sources,
include_dirs=include_dirs,
define_macros=define_macros,
......@@ -60,16 +55,16 @@ def get_extensions():
setup(
name="DCNv3",
version="1.0",
author="InternImage",
url="https://github.com/OpenGVLab/InternImage",
name='DCNv3',
version='1.0',
author='InternImage',
url='https://github.com/OpenGVLab/InternImage',
description=
"PyTorch Wrapper for CUDA Functions of DCNv3",
'PyTorch Wrapper for CUDA Functions of DCNv3',
packages=find_packages(exclude=(
"configs",
"tests",
'configs',
'tests',
)),
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
)
......@@ -4,16 +4,11 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import, division, print_function
import time
import torch
import torch.nn as nn
import math
from torch.autograd import gradcheck
import torch
from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
H_in, W_in = 8, 8
......@@ -32,11 +27,11 @@ torch.manual_seed(3)
@torch.no_grad()
def check_forward_equal_with_pytorch_double():
input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask /= mask.sum(-1, keepdim=True)
mask = mask.reshape(N, H_out, W_out, M*P)
mask = mask.reshape(N, H_out, W_out, M * P)
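
Editor's note: the test builds the DCNv3 modulation mask as strictly positive weights that sum to 1 over the P sampling points of each group, then flattens to the layout the op expects. A minimal sketch, with small shapes assumed for illustration:

```python
import torch

# Illustrative shapes (assumptions): batch N, output H x W, M groups,
# P sampling points per group.
N, H_out, W_out, M, P = 2, 6, 6, 2, 9
mask = torch.rand(N, H_out, W_out, M, P) + 1e-5  # strictly positive
mask /= mask.sum(-1, keepdim=True)               # normalize over the P points
assert torch.allclose(mask.sum(-1), torch.ones(N, H_out, W_out, M))
mask = mask.reshape(N, H_out, W_out, M * P)      # layout expected by the op
```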
output_pytorch = dcnv3_core_pytorch(
input.double(),
......@@ -57,16 +52,17 @@ def check_forward_equal_with_pytorch_double():
max_rel_err = ((output_cuda - output_pytorch).abs() /
output_pytorch.abs()).max()
print('>>> forward double')
print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
print(
f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
@torch.no_grad()
def check_forward_equal_with_pytorch_float():
input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask /= mask.sum(-1, keepdim=True)
mask = mask.reshape(N, H_out, W_out, M*P)
mask = mask.reshape(N, H_out, W_out, M * P)
output_pytorch = dcnv3_core_pytorch(
input,
......@@ -87,7 +83,8 @@ def check_forward_equal_with_pytorch_float():
max_rel_err = ((output_cuda - output_pytorch).abs() /
output_pytorch.abs()).max()
print('>>> forward float')
print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
print(
f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_offset=True, grad_mask=True):
......@@ -98,11 +95,11 @@ def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_o
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
D = channels
input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask0 /= mask0.sum(-1, keepdim=True)
mask0 = mask0.reshape(N, H_out, W_out, M*P)
mask0 = mask0.reshape(N, H_out, W_out, M * P)
input0.requires_grad = grad_input
offset0.requires_grad = grad_offset
mask0.requires_grad = grad_mask
......@@ -161,11 +158,11 @@ def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_of
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
D = channels
input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask0 /= mask0.sum(-1, keepdim=True)
mask0 = mask0.reshape(N, H_out, W_out, M*P)
mask0 = mask0.reshape(N, H_out, W_out, M * P)
input0.requires_grad = grad_input
offset0.requires_grad = grad_offset
mask0.requires_grad = grad_mask
......@@ -223,11 +220,11 @@ def check_time_cost(im2col_step=128):
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
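
Editor's note: the `H_out`/`W_out` expressions used throughout these tests are the standard convolution output-size formula, with padding p, dilation d, kernel sizes K_h/K_w, and stride s:

```latex
H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2p - \bigl(d\,(K_h - 1) + 1\bigr)}{s} \right\rfloor + 1,
\qquad
W_{\text{out}} = \left\lfloor \frac{W_{\text{in}} + 2p - \bigl(d\,(K_w - 1) + 1\bigr)}{s} \right\rfloor + 1
```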
input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01
offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10
input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
mask /= mask.sum(-1, keepdim=True)
mask = mask.reshape(N, H_out, W_out, M*P)
mask = mask.reshape(N, H_out, W_out, M * P)
print(
f'>>> time cost: im2col_step {im2col_step}; input {input.shape}; points {P} ')
repeat = 100
......
......@@ -5,29 +5,19 @@
# ---------------------------------------------
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Linear, bias_init_with_prob
from mmcv.utils import TORCH_VERSION, digit_version
from mmdet.core import (multi_apply, multi_apply, reduce_mean)
from mmdet.models.utils.transformer import inverse_sigmoid
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from mmdet.models import HEADS
from mmdet.models.builder import build_loss
from mmdet.models.dense_heads import DETRHead
from mmdet3d.core.bbox.coders import build_bbox_coder
from projects.mmdet3d_plugin.core.bbox.util import normalize_bbox
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmcv.runner import force_fp32, auto_fp16
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid
from projects.mmdet3d_plugin.models.utils.bricks import run_time
import numpy as np
import mmcv
import cv2 as cv
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils import build_transformer
from mmdet.models.builder import build_loss
from mmcv.runner import BaseModule, force_fp32
@HEADS.register_module()
class BEVFormerOccHead(BaseModule):
......@@ -61,15 +51,14 @@ class BEVFormerOccHead(BaseModule):
self.bev_h = bev_h
self.bev_w = bev_w
self.fp16_enabled = False
self.num_classes=kwargs['num_classes']
self.use_mask=use_mask
self.num_classes = kwargs['num_classes']
self.use_mask = use_mask
self.with_box_refine = with_box_refine
self.as_two_stage = as_two_stage
if self.as_two_stage:
transformer['as_two_stage'] = self.as_two_stage
self.pc_range = pc_range
self.real_w = self.pc_range[3] - self.pc_range[0]
self.real_h = self.pc_range[4] - self.pc_range[1]
......@@ -151,7 +140,7 @@ class BEVFormerOccHead(BaseModule):
outs = {
'bev_embed': bev_embed,
'occ':occ_outs,
'occ': occ_outs,
}
return outs
......@@ -166,25 +155,25 @@ class BEVFormerOccHead(BaseModule):
gt_bboxes_ignore=None,
img_metas=None):
loss_dict=dict()
occ=preds_dicts['occ']
assert voxel_semantics.min()>=0 and voxel_semantics.max()<=17
losses = self.loss_single(voxel_semantics,mask_camera,occ)
loss_dict['loss_occ']=losses
loss_dict = dict()
occ = preds_dicts['occ']
assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
losses = self.loss_single(voxel_semantics, mask_camera, occ)
loss_dict['loss_occ'] = losses
return loss_dict
def loss_single(self,voxel_semantics,mask_camera,preds):
voxel_semantics=voxel_semantics.long()
def loss_single(self, voxel_semantics, mask_camera, preds):
voxel_semantics = voxel_semantics.long()
if self.use_mask:
voxel_semantics=voxel_semantics.reshape(-1)
preds=preds.reshape(-1,self.num_classes)
mask_camera=mask_camera.reshape(-1)
num_total_samples=mask_camera.sum()
loss_occ=self.loss_occ(preds,voxel_semantics,mask_camera, avg_factor=num_total_samples)
voxel_semantics = voxel_semantics.reshape(-1)
preds = preds.reshape(-1, self.num_classes)
mask_camera = mask_camera.reshape(-1)
num_total_samples = mask_camera.sum()
loss_occ = self.loss_occ(preds, voxel_semantics, mask_camera, avg_factor=num_total_samples)
else:
voxel_semantics = voxel_semantics.reshape(-1)
preds = preds.reshape(-1, self.num_classes)
loss_occ = self.loss_occ(preds, voxel_semantics,)
loss_occ = self.loss_occ(preds, voxel_semantics, )
return loss_occ
@force_fp32(apply_to=('preds'))
......@@ -199,9 +188,8 @@ class BEVFormerOccHead(BaseModule):
# return self.transformer.get_occ(
# preds_dicts, img_metas, rescale=rescale)
# print(img_metas[0].keys())
occ_out=preds_dicts['occ']
occ_score=occ_out.softmax(-1)
occ_score=occ_score.argmax(-1)
occ_out = preds_dicts['occ']
occ_score = occ_out.softmax(-1)
occ_score = occ_score.argmax(-1)
return occ_score
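
Editor's note: since softmax is strictly monotonic along the class dimension, the argmax of the softmax equals the argmax of the raw logits, so the softmax here only matters if the probabilities themselves are needed. A quick check:

```python
import torch

logits = torch.randn(2, 4, 18)  # e.g. (bs, voxels, num_classes); shapes assumed
assert torch.equal(logits.softmax(-1).argmax(-1), logits.argmax(-1))
```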
......@@ -4,17 +4,13 @@
# Modified by Xiaoyu Tian
# ---------------------------------------------
import copy
import torch
from mmcv.runner import force_fp32, auto_fp16
from mmdet.models import DETECTORS
from mmdet3d.core import bbox3d2result
from mmcv.runner import auto_fp16
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmdet.models import DETECTORS
from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
import time
import copy
import numpy as np
import mmdet3d
from projects.mmdet3d_plugin.models.utils.bricks import run_time
@DETECTORS.register_module()
......
from mmcv.runner.hooks.hook import HOOKS, Hook
from projects.mmdet3d_plugin.models.utils import run_time
@HOOKS.register_module()
class TransferWeight(Hook):
def __init__(self, every_n_inters=1):
self.every_n_inters=every_n_inters
self.every_n_inters = every_n_inters
def after_train_iter(self, runner):
if self.every_n_inner_iters(runner, self.every_n_inters):
runner.eval_model.load_state_dict(runner.model.state_dict())
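
Editor's note: this hook copies the training model's weights into a separate `runner.eval_model` every `every_n_inters` iterations. In mmcv-style configs it would typically be enabled via `custom_hooks`; the entry below is a sketch only, with the field name taken from the constructor above:

```python
# Hypothetical config snippet registering the hook (mmcv custom_hooks style):
custom_hooks = [
    dict(type='TransferWeight', every_n_inters=1),
]
```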
from .transformer import PerceptionTransformer
from .spatial_cross_attention import SpatialCrossAttention, MSDeformableAttention3D
from .temporal_self_attention import TemporalSelfAttention
from .encoder import BEVFormerEncoder, BEVFormerLayer
from .decoder import DetectionTransformerDecoder
from .encoder import BEVFormerEncoder, BEVFormerLayer
from .spatial_cross_attention import (MSDeformableAttention3D,
SpatialCrossAttention)
from .temporal_self_attention import TemporalSelfAttention
from .transformer import PerceptionTransformer
from .transformer_occ import TransformerOcc
......@@ -8,18 +8,16 @@ import copy
import warnings
import torch
import torch.nn as nn
from mmcv import ConfigDict, deprecated_api_warning
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.cnn.bricks.registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
from mmcv import ConfigDict
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.registry import TRANSFORMER_LAYER
from mmcv.runner.base_module import BaseModule, ModuleList
# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401
from mmcv.ops.multi_scale_deform_attn import \
MultiScaleDeformableAttention # noqa F401
warnings.warn(
ImportWarning(
'``MultiScaleDeformableAttention`` has been moved to '
......@@ -31,7 +29,8 @@ except ImportError:
warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
'``mmcv.ops.multi_scale_deform_attn``, '
'You should install ``mmcv-full`` if you need this module. ')
from mmcv.cnn.bricks.transformer import build_feedforward_network, build_attention
from mmcv.cnn.bricks.transformer import (build_attention,
build_feedforward_network)
@TRANSFORMER_LAYER.register_module()
......
......@@ -4,28 +4,21 @@
# Modified by Zhiqi Li
# ---------------------------------------------
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import mmcv
import cv2 as cv
import copy
import math
import warnings
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION, TRANSFORMER_LAYER_SEQUENCE
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
import math
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
to_2tuple)
from mmcv.ops.multi_scale_deform_attn import \
multi_scale_deformable_attn_pytorch
from mmcv.runner.base_module import BaseModule
from mmcv.utils import deprecated_api_warning, ext_loader
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
MultiScaleDeformableAttnFunction_fp16
from .multi_scale_deformable_attn_function import \
MultiScaleDeformableAttnFunction_fp32
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
......
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from .custom_base_transformer_layer import MyCustomBaseTransformerLayer
import copy
import warnings
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmcv.runner import force_fp32, auto_fp16
import numpy as np
import torch
import cv2 as cv
import mmcv
from mmcv.utils import TORCH_VERSION, digit_version
from mmcv.utils import ext_loader
from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmcv.runner import auto_fp16, force_fp32
from mmcv.utils import TORCH_VERSION, digit_version, ext_loader
from .custom_base_transformer_layer import MyCustomBaseTransformerLayer
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class BEVFormerEncoder(TransformerLayerSequence):
"""
Attention with both self and cross.
Implements the encoder in BEVFormer.
......@@ -71,7 +66,7 @@ class BEVFormerEncoder(TransformerLayerSequence):
device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H
ref_3d = torch.stack((xs, ys, zs), -1)
ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
ref_3d = ref_3d[None].repeat(bs, 1, 1, 1) #shape: (bs,num_points_in_pillar,h*w,3)
ref_3d = ref_3d[None].repeat(bs, 1, 1, 1) # shape: (bs,num_points_in_pillar,h*w,3)
return ref_3d
# reference points on 2D bev plane, used in temporal self-attention (TSA).
......@@ -91,7 +86,7 @@ class BEVFormerEncoder(TransformerLayerSequence):
# This function must use fp32!!!
@force_fp32(apply_to=('reference_points', 'img_metas'))
def point_sampling(self, reference_points, pc_range, img_metas):
ego2lidar=img_metas[0]['ego2lidar']
ego2lidar = img_metas[0]['ego2lidar']
lidar2img = []
for img_meta in img_metas:
......@@ -113,17 +108,19 @@ class BEVFormerEncoder(TransformerLayerSequence):
reference_points = torch.cat(
(reference_points, torch.ones_like(reference_points[..., :1])), -1)
reference_points = reference_points.permute(1, 0, 2, 3) #shape: (num_points_in_pillar,bs,h*w,4)
reference_points = reference_points.permute(1, 0, 2, 3) # shape: (num_points_in_pillar,bs,h*w,4)
D, B, num_query = reference_points.size()[:3] # D=num_points_in_pillar , num_query=h*w
num_cam = lidar2img.size(1)
reference_points = reference_points.view(
D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1) #shape: (num_points_in_pillar,bs,num_cam,h*w,4)
D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(
-1) # shape: (num_points_in_pillar,bs,num_cam,h*w,4)
lidar2img = lidar2img.view(
1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
ego2lidar=ego2lidar.view(1,1,1,1,4,4).repeat(D,1,num_cam,num_query,1,1)
reference_points_cam = torch.matmul(torch.matmul(lidar2img.to(torch.float32),ego2lidar.to(torch.float32)),reference_points.to(torch.float32)).squeeze(-1)
ego2lidar = ego2lidar.view(1, 1, 1, 1, 4, 4).repeat(D, 1, num_cam, num_query, 1, 1)
reference_points_cam = torch.matmul(torch.matmul(lidar2img.to(torch.float32), ego2lidar.to(torch.float32)),
reference_points.to(torch.float32)).squeeze(-1)
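
Editor's note: the chained `matmul` above projects each homogeneous 3D reference point from the ego frame into camera image coordinates by composing the two transforms:

```latex
\tilde{x}_{\text{cam}} = T_{\text{lidar2img}} \, T_{\text{ego2lidar}} \, \tilde{x}_{\text{ref}},
\qquad \tilde{x}_{\text{ref}} = (x, y, z, 1)^{\top}
```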
eps = 1e-5
bev_mask = (reference_points_cam[..., 2:3] > eps)
......@@ -143,8 +140,8 @@ class BEVFormerEncoder(TransformerLayerSequence):
bev_mask = bev_mask.new_tensor(
np.nan_to_num(bev_mask.cpu().numpy()))
reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4) #shape: (num_cam,bs,h*w,num_points_in_pillar,2)
reference_points_cam = reference_points_cam.permute(2, 1, 3, 0,
4) # shape: (num_cam,bs,h*w,num_points_in_pillar,2)
bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
......@@ -188,7 +185,8 @@ class BEVFormerEncoder(TransformerLayerSequence):
intermediate = []
ref_3d = self.get_reference_points(
bev_h, bev_w, self.pc_range[5]-self.pc_range[2], self.num_points_in_pillar, dim='3d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
bev_h, bev_w, self.pc_range[5] - self.pc_range[2], self.num_points_in_pillar, dim='3d',
bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
ref_2d = self.get_reference_points(
bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
......@@ -206,12 +204,12 @@ class BEVFormerEncoder(TransformerLayerSequence):
if prev_bev is not None:
prev_bev = prev_bev.permute(1, 0, 2)
prev_bev = torch.stack(
[prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1)
[prev_bev, bev_query], 1).reshape(bs * 2, len_bev, -1)
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
bs*2, len_bev, num_bev_level, 2)
bs * 2, len_bev, num_bev_level, 2)
else:
hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
bs*2, len_bev, num_bev_level, 2)
bs * 2, len_bev, num_bev_level, 2)
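
Editor's note: the stacking above is the temporal self-attention batching trick: previous and current BEV features are interleaved along a new axis and folded into the batch dimension, so one attention call covers both. A minimal sketch under assumed shapes:

```python
import torch

bs, len_bev, dim = 1, 4, 8  # illustrative sizes
prev_bev = torch.rand(bs, len_bev, dim)
bev_query = torch.rand(bs, len_bev, dim)
# (bs, 2, len_bev, dim) -> (bs*2, len_bev, dim): samples i and i+1 share a scene
stacked = torch.stack([prev_bev, bev_query], 1).reshape(bs * 2, len_bev, dim)
assert stacked.shape == (bs * 2, len_bev, dim)
```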
for lid, layer in enumerate(self.layers):
output = layer(
......
......@@ -5,9 +5,10 @@
# ---------------------------------------------
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.autograd.function import Function, once_differentiable
from mmcv.utils import ext_loader
from torch.autograd.function import Function, once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
......
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import build_attention
import math
from mmcv.runner import force_fp32, auto_fp16
from mmcv.ops.multi_scale_deform_attn import \
multi_scale_deformable_attn_pytorch
from mmcv.runner import force_fp32
from mmcv.runner.base_module import BaseModule
from mmcv.utils import ext_loader
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from .multi_scale_deformable_attn_function import \
MultiScaleDeformableAttnFunction_fp32
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
MultiScaleDeformableAttnFunction_fp16
from projects.mmdet3d_plugin.models.utils.bricks import run_time
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
......@@ -150,7 +147,8 @@ class SpatialCrossAttention(BaseModule):
for i, reference_points_per_img in enumerate(reference_points_cam):
index_query_per_img = indexes[i]
queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img]
reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[j, index_query_per_img]
reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[
j, index_query_per_img]
num_cams, l, bs, embed_dims = key.shape
......@@ -159,9 +157,13 @@ class SpatialCrossAttention(BaseModule):
value = value.permute(2, 0, 1, 3).reshape(
bs * self.num_cams, l, self.embed_dims)
queries = self.deformable_attention(query=queries_rebatch.view(bs*self.num_cams, max_len, self.embed_dims), key=key, value=value,
reference_points=reference_points_rebatch.view(bs*self.num_cams, max_len, D, 2), spatial_shapes=spatial_shapes,
level_start_index=level_start_index).view(bs, self.num_cams, max_len, self.embed_dims)
queries = self.deformable_attention(query=queries_rebatch.view(bs * self.num_cams, max_len, self.embed_dims),
key=key, value=value,
reference_points=reference_points_rebatch.view(bs * self.num_cams, max_len,
D, 2),
spatial_shapes=spatial_shapes,
level_start_index=level_start_index).view(bs, self.num_cams, max_len,
self.embed_dims)
for j in range(bs):
for i, index_query_per_img in enumerate(indexes):
slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)]
......
......@@ -4,20 +4,21 @@
# Modified by Zhiqi Li
# ---------------------------------------------
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
import math
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
import math
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
to_2tuple)
from mmcv.ops.multi_scale_deform_attn import \
multi_scale_deformable_attn_pytorch
from mmcv.runner.base_module import BaseModule
from mmcv.utils import ext_loader
from .multi_scale_deformable_attn_function import \
MultiScaleDeformableAttnFunction_fp32
ext_module = ext_loader.load_ext(
'_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
......@@ -97,9 +98,9 @@ class TemporalSelfAttention(BaseModule):
self.num_points = num_points
self.num_bev_queue = num_bev_queue
self.sampling_offsets = nn.Linear(
embed_dims*self.num_bev_queue, num_bev_queue*num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims*self.num_bev_queue,
num_bev_queue*num_heads * num_levels * num_points)
embed_dims * self.num_bev_queue, num_bev_queue * num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims * self.num_bev_queue,
num_bev_queue * num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weights()
......@@ -114,7 +115,7 @@ class TemporalSelfAttention(BaseModule):
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels*self.num_bev_queue, self.num_points, 1)
2).repeat(1, self.num_levels * self.num_bev_queue, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
......@@ -177,7 +178,7 @@ class TemporalSelfAttention(BaseModule):
if value is None:
assert self.batch_first
bs, len_bev, c = query.shape
value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c)
value = torch.stack([query, query], 1).reshape(bs * 2, len_bev, c)
# value = torch.cat([query, query], 0)
......@@ -200,7 +201,7 @@ class TemporalSelfAttention(BaseModule):
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.reshape(bs*self.num_bev_queue,
value = value.reshape(bs * self.num_bev_queue,
num_value, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query)
......@@ -216,10 +217,10 @@ class TemporalSelfAttention(BaseModule):
self.num_levels,
self.num_points)
attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5)\
.reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous()
sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6)\
.reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5) \
.reshape(bs * self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous()
sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6) \
.reshape(bs * self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
......
......@@ -9,18 +9,15 @@ import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner import auto_fp16
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from mmcv.runner.base_module import BaseModule
from torchvision.transforms.functional import rotate
from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from mmcv.runner import force_fp32, auto_fp16
from .spatial_cross_attention import MSDeformableAttention3D
from .temporal_self_attention import TemporalSelfAttention
@TRANSFORMER.register_module()
......
......@@ -7,21 +7,18 @@
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn import ConvModule, xavier_init
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmcv.runner import auto_fp16
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from torch.nn.init import normal_
from projects.mmdet3d_plugin.models.utils.visual import save_tensor
from mmcv.runner.base_module import BaseModule
from torchvision.transforms.functional import rotate
from .temporal_self_attention import TemporalSelfAttention
from .spatial_cross_attention import MSDeformableAttention3D
from .decoder import CustomMSDeformableAttention
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from mmcv.runner import force_fp32, auto_fp16
from mmcv.cnn import PLUGIN_LAYERS, Conv2d,Conv3d, ConvModule, caffe2_xavier_init
from .spatial_cross_attention import MSDeformableAttention3D
from .temporal_self_attention import TemporalSelfAttention
@TRANSFORMER.register_module()
class TransformerOcc(BaseModule):
......@@ -53,7 +50,7 @@ class TransformerOcc(BaseModule):
num_classes=18,
out_dim=32,
pillar_h=16,
act_cfg=dict(type='ReLU',inplace=True),
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='BN', ),
norm_cfg_3d=dict(type='BN3d', ),
**kwargs):
......@@ -70,10 +67,10 @@ class TransformerOcc(BaseModule):
self.use_can_bus = use_can_bus
self.can_bus_norm = can_bus_norm
self.use_cams_embeds = use_cams_embeds
self.use_3d=use_3d
self.use_conv=use_conv
self.use_3d = use_3d
self.use_conv = use_conv
self.pillar_h = pillar_h
self.out_dim=out_dim
self.out_dim = out_dim
if not use_3d:
if use_conv:
use_bias = norm_cfg is None
......@@ -89,24 +86,24 @@ class TransformerOcc(BaseModule):
act_cfg=act_cfg),
ConvModule(
self.embed_dims,
self.embed_dims*2,
self.embed_dims * 2,
kernel_size=3,
stride=1,
padding=1,
bias=use_bias,
norm_cfg=norm_cfg,
act_cfg=act_cfg),)
act_cfg=act_cfg), )
else:
self.decoder = nn.Sequential(
nn.Linear(self.embed_dims, self.embed_dims * 2),
nn.Softplus(),
nn.Linear(self.embed_dims * 2, self.embed_dims*2),
nn.Linear(self.embed_dims * 2, self.embed_dims * 2),
)
else:
use_bias_3d = norm_cfg_3d is None
self.middle_dims=self.embed_dims//pillar_h
self.middle_dims = self.embed_dims // pillar_h
self.decoder = nn.Sequential(
ConvModule(
self.middle_dims,
......@@ -130,9 +127,9 @@ class TransformerOcc(BaseModule):
act_cfg=act_cfg),
)
self.predicter = nn.Sequential(
nn.Linear(self.out_dim, self.out_dim*2),
nn.Linear(self.out_dim, self.out_dim * 2),
nn.Softplus(),
nn.Linear(self.out_dim*2,num_classes),
nn.Linear(self.out_dim * 2, num_classes),
)
self.two_stage_num_proposals = two_stage_num_proposals
self.init_layers()
......@@ -217,7 +214,7 @@ class TransformerOcc(BaseModule):
prev_bev = prev_bev.permute(1, 0, 2)
elif len(prev_bev.shape) == 4:
prev_bev = prev_bev.view(bs,-1,bev_h * bev_w).permute(2, 0, 1)
prev_bev = prev_bev.view(bs, -1, bev_h * bev_w).permute(2, 0, 1)
if self.rotate_prev_bev:
for i in range(bs):
# num_prev_bev = prev_bev.size(1)
......@@ -337,16 +334,16 @@ class TransformerOcc(BaseModule):
bs = mlvl_feats[0].size(0)
bev_embed = bev_embed.permute(0, 2, 1).view(bs, -1, bev_h, bev_w)
if self.use_3d:
outputs=self.decoder(bev_embed.view(bs,-1,self.pillar_h,bev_h, bev_w))
outputs=outputs.permute(0,4,3,2,1)
outputs = self.decoder(bev_embed.view(bs, -1, self.pillar_h, bev_h, bev_w))
outputs = outputs.permute(0, 4, 3, 2, 1)
elif self.use_conv:
outputs = self.decoder(bev_embed)
outputs = outputs.view(bs, -1,self.pillar_h, bev_h, bev_w).permute(0,3,4,2, 1)
outputs = outputs.view(bs, -1, self.pillar_h, bev_h, bev_w).permute(0, 3, 4, 2, 1)
else:
outputs = self.decoder(bev_embed.permute(0,2,3,1))
outputs = outputs.view(bs, bev_h, bev_w,self.pillar_h,self.out_dim)
outputs = self.decoder(bev_embed.permute(0, 2, 3, 1))
outputs = outputs.view(bs, bev_h, bev_w, self.pillar_h, self.out_dim)
outputs = self.predicter(outputs)
# print('outputs',type(outputs))
return bev_embed, outputs
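
Editor's note on the `else` branch above: the linear decoder expands each BEV cell's feature to `pillar_h * out_dim` channels, the `view` splits that into a pillar of `pillar_h` voxels with `out_dim` features each, and `predicter` maps every voxel feature to `num_classes` logits. A self-contained sketch under assumed sizes (note `embed_dims * 2 == pillar_h * out_dim` must hold):

```python
import torch
import torch.nn as nn

# Illustrative sizes: embed_dims * 2 == pillar_h * out_dim (256*2 == 16*32).
bs, bev_h, bev_w = 1, 4, 4
embed_dims, pillar_h, out_dim, num_classes = 256, 16, 32, 18

decoder = nn.Sequential(
    nn.Linear(embed_dims, embed_dims * 2), nn.Softplus(),
    nn.Linear(embed_dims * 2, embed_dims * 2))
predicter = nn.Sequential(
    nn.Linear(out_dim, out_dim * 2), nn.Softplus(),
    nn.Linear(out_dim * 2, num_classes))

bev_embed = torch.rand(bs, embed_dims, bev_h, bev_w)
x = decoder(bev_embed.permute(0, 2, 3, 1))       # (bs, H, W, embed_dims*2)
x = x.view(bs, bev_h, bev_w, pillar_h, out_dim)  # split channels into pillar voxels
logits = predicter(x)                            # (bs, H, W, pillar_h, num_classes)
assert logits.shape == (bs, bev_h, bev_w, pillar_h, num_classes)
```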