Unverified Commit ac5ed37f authored by Zhiqi Li's avatar Zhiqi Li Committed by GitHub
Browse files

Update: support DCNv4 in InternImage! (#277)



* Update: support DCNv4 in InternImage!

* fix mask softmax bug.

---------
Co-authored-by: default avatarZhiqi Li <zhiqil@nvidia>
parent aaac6990
...@@ -54,6 +54,7 @@ The official implementation of ...@@ -54,6 +54,7 @@ The official implementation of
## News ## News
- `Jan 22, 2024`: 🚀 Support [DCNv4](https://github.com/OpenGVLab/DCNv4) in InternImage!
- `Mar 14, 2023`: 🚀 "INTERN-2.5" is released! - `Mar 14, 2023`: 🚀 "INTERN-2.5" is released!
- `Feb 28, 2023`: 🚀 InternImage is accepted to CVPR 2023! - `Feb 28, 2023`: 🚀 InternImage is accepted to CVPR 2023!
- `Nov 18, 2022`: 🚀 InternImage-XL merged into [BEVFormer v2](https://arxiv.org/abs/2211.10439) achieves state-of-the-art performance of `63.4 NDS` on nuScenes Camera Only. - `Nov 18, 2022`: 🚀 InternImage-XL merged into [BEVFormer v2](https://arxiv.org/abs/2211.10439) achieves state-of-the-art performance of `63.4 NDS` on nuScenes Camera Only.
......
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py',
'../_base_/default_runtime.py'
]
pretrained = 'https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_t_1k_224.pth'
model = dict(
backbone=dict(
_delete_=True,
type='InternImage',
core_op='DCNv3',
channels=64,
depths=[4, 4, 18, 4],
groups=[4, 8, 16, 32],
mlp_ratio=4.,
drop_path_rate=0.2,
norm_layer='LN',
layer_scale=1.0,
offset_scale=1.0,
post_norm=False,
with_cp=False,
out_indices=(0, 1, 2, 3),
use_dcn_v4_op=True,
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
neck=dict(
type='FPN',
in_channels=[64, 128, 256, 512],
out_channels=256,
num_outs=5))
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
optimizer = dict(
_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.05,
constructor='CustomLayerDecayOptimizerConstructor',
paramwise_cfg=dict(num_layers=30, layer_decay_rate=1.0,
depths=[4, 4, 18, 4]))
optimizer_config = dict(grad_clip=None)
# fp16 = dict(loss_scale=dict(init_scale=512))
evaluation = dict(save_best='auto')
checkpoint_config = dict(
interval=1,
max_keep_ckpts=3,
save_last=True,
)
\ No newline at end of file
...@@ -15,7 +15,7 @@ from mmdet.utils import get_root_logger ...@@ -15,7 +15,7 @@ from mmdet.utils import get_root_logger
from mmdet.models.builder import BACKBONES from mmdet.models.builder import BACKBONES
import torch.nn.functional as F import torch.nn.functional as F
from ops_dcnv3 import modules as opsm from ops_dcnv3 import modules as dcnv3
class to_channels_first(nn.Module): class to_channels_first(nn.Module):
...@@ -365,7 +365,8 @@ class InternImageLayer(nn.Module): ...@@ -365,7 +365,8 @@ class InternImageLayer(nn.Module):
with_cp=False, with_cp=False,
dw_kernel_size=None, # for InternImage-H/G dw_kernel_size=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G center_feature_scale=False,
use_dcn_v4_op=False): # for InternImage-H/G
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.groups = groups self.groups = groups
...@@ -385,7 +386,8 @@ class InternImageLayer(nn.Module): ...@@ -385,7 +386,8 @@ class InternImageLayer(nn.Module):
act_layer=act_layer, act_layer=act_layer,
norm_layer=norm_layer, norm_layer=norm_layer,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
center_feature_scale=center_feature_scale) # for InternImage-H/G center_feature_scale=center_feature_scale,
use_dcn_v4_op=use_dcn_v4_op) # for InternImage-H/G
self.drop_path = DropPath(drop_path) if drop_path > 0. \ self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity() else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN') self.norm2 = build_norm_layer(channels, 'LN')
...@@ -469,7 +471,8 @@ class InternImageBlock(nn.Module): ...@@ -469,7 +471,8 @@ class InternImageBlock(nn.Module):
dw_kernel_size=None, # for InternImage-H/G dw_kernel_size=None, # for InternImage-H/G
post_norm_block_ids=None, # for InternImage-H/G post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G center_feature_scale=False, # for InternImage-H/G
use_dcn_v4_op=False):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.depth = depth self.depth = depth
...@@ -493,7 +496,8 @@ class InternImageBlock(nn.Module): ...@@ -493,7 +496,8 @@ class InternImageBlock(nn.Module):
with_cp=with_cp, with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G center_feature_scale=center_feature_scale, # for InternImage-H/G
use_dcn_v4_op=use_dcn_v4_op
) for i in range(depth) ) for i in range(depth)
]) ])
if not self.post_norm or center_feature_scale: if not self.post_norm or center_feature_scale:
...@@ -569,6 +573,7 @@ class InternImage(nn.Module): ...@@ -569,6 +573,7 @@ class InternImage(nn.Module):
level2_post_norm_block_ids=None, # for InternImage-H/G level2_post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G center_feature_scale=False, # for InternImage-H/G
use_dcn_v4_op=False,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
init_cfg=None, init_cfg=None,
**kwargs): **kwargs):
...@@ -591,6 +596,7 @@ class InternImage(nn.Module): ...@@ -591,6 +596,7 @@ class InternImage(nn.Module):
logger.info(f"level2_post_norm: {level2_post_norm}") logger.info(f"level2_post_norm: {level2_post_norm}")
logger.info(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}") logger.info(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}")
logger.info(f"res_post_norm: {res_post_norm}") logger.info(f"res_post_norm: {res_post_norm}")
logger.info(f"use_dcn_v4_op: {use_dcn_v4_op}")
in_chans = 3 in_chans = 3
self.patch_embed = StemLayer(in_chans=in_chans, self.patch_embed = StemLayer(in_chans=in_chans,
...@@ -611,7 +617,7 @@ class InternImage(nn.Module): ...@@ -611,7 +617,7 @@ class InternImage(nn.Module):
post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and ( post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
i == 2) else None # for InternImage-H/G i == 2) else None # for InternImage-H/G
level = InternImageBlock( level = InternImageBlock(
core_op=getattr(opsm, core_op), core_op=getattr(dcnv3, core_op),
channels=int(channels * 2**i), channels=int(channels * 2**i),
depth=depths[i], depth=depths[i],
groups=groups[i], groups=groups[i],
...@@ -628,7 +634,8 @@ class InternImage(nn.Module): ...@@ -628,7 +634,8 @@ class InternImage(nn.Module):
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G center_feature_scale=center_feature_scale, # for InternImage-H/G
use_dcn_v4_op=use_dcn_v4_op,
) )
self.levels.append(level) self.levels.append(level)
...@@ -687,7 +694,7 @@ class InternImage(nn.Module): ...@@ -687,7 +694,7 @@ class InternImage(nn.Module):
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)
def _init_deform_weights(self, m): def _init_deform_weights(self, m):
if isinstance(m, getattr(opsm, self.core_op)): if isinstance(m, getattr(dcnv3, self.core_op)):
m._reset_parameters() m._reset_parameters()
def forward(self, x): def forward(self, x):
......
...@@ -14,7 +14,11 @@ from torch import nn ...@@ -14,7 +14,11 @@ from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_ from torch.nn.init import xavier_uniform_, constant_
from ..functions import DCNv3Function, dcnv3_core_pytorch from ..functions import DCNv3Function, dcnv3_core_pytorch
try:
from DCNv4.functions import DCNv4Function
except:
warnings.warn('Now, we support DCNv4 in InternImage.')
import math
class to_channels_first(nn.Module): class to_channels_first(nn.Module):
...@@ -228,7 +232,9 @@ class DCNv3(nn.Module): ...@@ -228,7 +232,9 @@ class DCNv3(nn.Module):
offset_scale=1.0, offset_scale=1.0,
act_layer='GELU', act_layer='GELU',
norm_layer='LN', norm_layer='LN',
center_feature_scale=False): center_feature_scale=False,
use_dcn_v4_op=False,
):
""" """
DCNv3 Module DCNv3 Module
:param channels :param channels
...@@ -264,7 +270,9 @@ class DCNv3(nn.Module): ...@@ -264,7 +270,9 @@ class DCNv3(nn.Module):
self.group_channels = channels // group self.group_channels = channels // group
self.offset_scale = offset_scale self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale self.center_feature_scale = center_feature_scale
self.use_dcn_v4_op = use_dcn_v4_op
self.dw_conv = nn.Sequential( self.dw_conv = nn.Sequential(
nn.Conv2d( nn.Conv2d(
channels, channels,
...@@ -321,18 +329,44 @@ class DCNv3(nn.Module): ...@@ -321,18 +329,44 @@ class DCNv3(nn.Module):
x1 = self.dw_conv(x1) x1 = self.dw_conv(x1)
offset = self.offset(x1) offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1) mask = self.mask(x1).reshape(N, H, W, self.group, -1)
mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype)
x = DCNv3Function.apply(
x, offset, mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale,
256)
if not self.use_dcn_v4_op:
mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype)
x = DCNv3Function.apply(
x, offset, mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale,
256)
else:
# DCNv4 combines offset and weight mask into one tensor `offset_mask`.
# The following code is to align DCNv3 and DCNv4
offset = offset.view(N, H, W, self.group, -1)
mask = F.softmax(mask, -1)
mask = mask.view(N, H, W, self.group, -1)
offset_mask = torch.cat([offset, mask], -1).view(N, H, W, -1).contiguous()
# For efficiency, the last dimension of the offset_mask tensor in dcnv4 is a multiple of 8.
K3 = offset_mask.size(-1)
K3_pad = int(math.ceil(K3/8)*8)
pad_dim = K3_pad - K3
offset_mask = torch.cat([offset_mask, offset_mask.new_zeros([*offset_mask.size()[:3], pad_dim])], -1)
x = DCNv4Function.apply(
x, offset_mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale,
256,
False
)
if self.center_feature_scale: if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module( center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
......
...@@ -15,7 +15,7 @@ from mmseg.utils import get_root_logger ...@@ -15,7 +15,7 @@ from mmseg.utils import get_root_logger
from mmseg.models.builder import BACKBONES from mmseg.models.builder import BACKBONES
import torch.nn.functional as F import torch.nn.functional as F
from ops_dcnv3 import modules as opsm from ops_dcnv3 import modules as dcnv3
class to_channels_first(nn.Module): class to_channels_first(nn.Module):
...@@ -365,7 +365,8 @@ class InternImageLayer(nn.Module): ...@@ -365,7 +365,8 @@ class InternImageLayer(nn.Module):
with_cp=False, with_cp=False,
dw_kernel_size=None, # for InternImage-H/G dw_kernel_size=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G center_feature_scale=False,
use_dcn_v4_op=False): # for InternImage-H/G
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.groups = groups self.groups = groups
...@@ -385,7 +386,8 @@ class InternImageLayer(nn.Module): ...@@ -385,7 +386,8 @@ class InternImageLayer(nn.Module):
act_layer=act_layer, act_layer=act_layer,
norm_layer=norm_layer, norm_layer=norm_layer,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
center_feature_scale=center_feature_scale) # for InternImage-H/G center_feature_scale=center_feature_scale,
use_dcn_v4_op=use_dcn_v4_op) # for InternImage-H/G
self.drop_path = DropPath(drop_path) if drop_path > 0. \ self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity() else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN') self.norm2 = build_norm_layer(channels, 'LN')
...@@ -469,7 +471,8 @@ class InternImageBlock(nn.Module): ...@@ -469,7 +471,8 @@ class InternImageBlock(nn.Module):
dw_kernel_size=None, # for InternImage-H/G dw_kernel_size=None, # for InternImage-H/G
post_norm_block_ids=None, # for InternImage-H/G post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False): # for InternImage-H/G center_feature_scale=False, # for InternImage-H/G
use_dcn_v4_op=False):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.depth = depth self.depth = depth
...@@ -493,7 +496,8 @@ class InternImageBlock(nn.Module): ...@@ -493,7 +496,8 @@ class InternImageBlock(nn.Module):
with_cp=with_cp, with_cp=with_cp,
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G center_feature_scale=center_feature_scale, # for InternImage-H/G
use_dcn_v4_op=use_dcn_v4_op
) for i in range(depth) ) for i in range(depth)
]) ])
if not self.post_norm or center_feature_scale: if not self.post_norm or center_feature_scale:
...@@ -569,6 +573,7 @@ class InternImage(nn.Module): ...@@ -569,6 +573,7 @@ class InternImage(nn.Module):
level2_post_norm_block_ids=None, # for InternImage-H/G level2_post_norm_block_ids=None, # for InternImage-H/G
res_post_norm=False, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G
center_feature_scale=False, # for InternImage-H/G center_feature_scale=False, # for InternImage-H/G
use_dcn_v4_op=False,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
init_cfg=None, init_cfg=None,
**kwargs): **kwargs):
...@@ -591,6 +596,7 @@ class InternImage(nn.Module): ...@@ -591,6 +596,7 @@ class InternImage(nn.Module):
logger.info(f"level2_post_norm: {level2_post_norm}") logger.info(f"level2_post_norm: {level2_post_norm}")
logger.info(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}") logger.info(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}")
logger.info(f"res_post_norm: {res_post_norm}") logger.info(f"res_post_norm: {res_post_norm}")
logger.info(f"use_dcn_v4_op: {use_dcn_v4_op}")
in_chans = 3 in_chans = 3
self.patch_embed = StemLayer(in_chans=in_chans, self.patch_embed = StemLayer(in_chans=in_chans,
...@@ -611,7 +617,7 @@ class InternImage(nn.Module): ...@@ -611,7 +617,7 @@ class InternImage(nn.Module):
post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and ( post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
i == 2) else None # for InternImage-H/G i == 2) else None # for InternImage-H/G
level = InternImageBlock( level = InternImageBlock(
core_op=getattr(opsm, core_op), core_op=getattr(dcnv3, core_op),
channels=int(channels * 2**i), channels=int(channels * 2**i),
depth=depths[i], depth=depths[i],
groups=groups[i], groups=groups[i],
...@@ -628,7 +634,8 @@ class InternImage(nn.Module): ...@@ -628,7 +634,8 @@ class InternImage(nn.Module):
dw_kernel_size=dw_kernel_size, # for InternImage-H/G dw_kernel_size=dw_kernel_size, # for InternImage-H/G
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
res_post_norm=res_post_norm, # for InternImage-H/G res_post_norm=res_post_norm, # for InternImage-H/G
center_feature_scale=center_feature_scale # for InternImage-H/G center_feature_scale=center_feature_scale, # for InternImage-H/G
use_dcn_v4_op=use_dcn_v4_op,
) )
self.levels.append(level) self.levels.append(level)
...@@ -687,7 +694,7 @@ class InternImage(nn.Module): ...@@ -687,7 +694,7 @@ class InternImage(nn.Module):
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)
def _init_deform_weights(self, m): def _init_deform_weights(self, m):
if isinstance(m, getattr(opsm, self.core_op)): if isinstance(m, getattr(dcnv3, self.core_op)):
m._reset_parameters() m._reset_parameters()
def forward(self, x): def forward(self, x):
......
...@@ -14,7 +14,11 @@ from torch import nn ...@@ -14,7 +14,11 @@ from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_ from torch.nn.init import xavier_uniform_, constant_
from ..functions import DCNv3Function, dcnv3_core_pytorch from ..functions import DCNv3Function, dcnv3_core_pytorch
try:
from DCNv4.functions import DCNv4Function
except:
warnings.warn('Now, we support DCNv4 in InternImage.')
import math
class to_channels_first(nn.Module): class to_channels_first(nn.Module):
...@@ -228,7 +232,9 @@ class DCNv3(nn.Module): ...@@ -228,7 +232,9 @@ class DCNv3(nn.Module):
offset_scale=1.0, offset_scale=1.0,
act_layer='GELU', act_layer='GELU',
norm_layer='LN', norm_layer='LN',
center_feature_scale=False): center_feature_scale=False,
use_dcn_v4_op=False,
):
""" """
DCNv3 Module DCNv3 Module
:param channels :param channels
...@@ -264,7 +270,9 @@ class DCNv3(nn.Module): ...@@ -264,7 +270,9 @@ class DCNv3(nn.Module):
self.group_channels = channels // group self.group_channels = channels // group
self.offset_scale = offset_scale self.offset_scale = offset_scale
self.center_feature_scale = center_feature_scale self.center_feature_scale = center_feature_scale
self.use_dcn_v4_op = use_dcn_v4_op
self.dw_conv = nn.Sequential( self.dw_conv = nn.Sequential(
nn.Conv2d( nn.Conv2d(
channels, channels,
...@@ -321,18 +329,44 @@ class DCNv3(nn.Module): ...@@ -321,18 +329,44 @@ class DCNv3(nn.Module):
x1 = self.dw_conv(x1) x1 = self.dw_conv(x1)
offset = self.offset(x1) offset = self.offset(x1)
mask = self.mask(x1).reshape(N, H, W, self.group, -1) mask = self.mask(x1).reshape(N, H, W, self.group, -1)
mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype)
x = DCNv3Function.apply(
x, offset, mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale,
256)
if not self.use_dcn_v4_op:
mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype)
x = DCNv3Function.apply(
x, offset, mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale,
256)
else:
# DCNv4 combines offset and weight mask into one tensor `offset_mask`.
# The following code is to align DCNv3 and DCNv4
offset = offset.view(N, H, W, self.group, -1)
mask = F.softmax(mask, -1)
mask = mask.view(N, H, W, self.group, -1)
offset_mask = torch.cat([offset, mask], -1).view(N, H, W, -1).contiguous()
# For efficiency, the last dimension of the offset_mask tensor in dcnv4 is a multiple of 8.
K3 = offset_mask.size(-1)
K3_pad = int(math.ceil(K3/8)*8)
pad_dim = K3_pad - K3
offset_mask = torch.cat([offset_mask, offset_mask.new_zeros([*offset_mask.size()[:3], pad_dim])], -1)
x = DCNv4Function.apply(
x, offset_mask,
self.kernel_size, self.kernel_size,
self.stride, self.stride,
self.pad, self.pad,
self.dilation, self.dilation,
self.group, self.group_channels,
self.offset_scale,
256,
False
)
if self.center_feature_scale: if self.center_feature_scale:
center_feature_scale = self.center_feature_scale_module( center_feature_scale = self.center_feature_scale_module(
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment