Unverified Commit 42f03d84 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

Refactor plugins: move from ops to cnn (#380)

* Refator plugins: move from ops to cnn

* minior update

* minior update
parent 82211b40
from .activation import build_activation_layer
from .context_block import ContextBlock
from .conv import build_conv_layer
from .conv_module import ConvModule
from .generalized_attention import GeneralizedAttention
from .hsigmoid import HSigmoid
from .hswish import HSwish
from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer
from .plugin import build_plugin_layer
from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, UPSAMPLE_LAYERS)
PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
from .scale import Scale
from .upsample import build_upsample_layer
__all__ = [
'ConvModule', 'build_activation_layer', 'build_conv_layer',
'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d',
'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'Scale'
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale'
]
import torch
from torch import nn
from ..cnn import constant_init, kaiming_init
from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS
def last_zero_init(m):
......@@ -11,6 +12,7 @@ def last_zero_init(m):
constant_init(m, val=0)
@PLUGIN_LAYERS.register_module()
class ContextBlock(nn.Module):
"""ContextBlock module in GCNet.
......@@ -20,11 +22,16 @@ class ContextBlock(nn.Module):
Args:
in_channels (int): Channels of the input feature map.
ratio (float): Ratio of channels of transform bottleneck
pooling_type (str): Pooling method for context modeling
fusion_types (list[str]|tuple[str]): Fusion method for feature fusion,
options: 'channels_add', 'channel_mul'
pooling_type (str): Pooling method for context modeling.
Options are 'att' and 'avg', stand for attention pooling and
average pooling respectively. Default: 'att'.
fusion_types (Sequence[str]): Fusion method for feature fusion,
Options are 'channels_add', 'channel_mul', stand for channelwise
addition and multiplication respectively. Default: ('channel_add',)
"""
_abbr_ = 'context_block'
def __init__(self,
in_channels,
ratio,
......
......@@ -7,8 +7,10 @@ from .activation import build_activation_layer
from .conv import build_conv_layer
from .norm import build_norm_layer
from .padding import build_padding_layer
from .registry import PLUGIN_LAYERS
@PLUGIN_LAYERS.register_module()
class ConvModule(nn.Module):
"""A conv block that bundles conv/norm/activation layers.
......@@ -54,6 +56,8 @@ class ConvModule(nn.Module):
Default: ('conv', 'norm', 'act').
"""
_abbr_ = 'conv_block'
def __init__(self,
in_channels,
out_channels,
......
......@@ -5,9 +5,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..cnn import kaiming_init
from ..utils import kaiming_init
from .registry import PLUGIN_LAYERS
@PLUGIN_LAYERS.register_module()
class GeneralizedAttention(nn.Module):
"""GeneralizedAttention module.
......@@ -16,13 +18,18 @@ class GeneralizedAttention(nn.Module):
Args:
in_channels (int): Channels of the input feature map.
spatial_range (int): The spatial range.
-1 indicates no spatial range constraint.
spatial_range (int): The spatial range. -1 indicates no spatial range
constraint. Default: -1.
num_heads (int): The head number of empirical_attention module.
Default: 9.
position_embedding_dim (int): The position embedding dimension.
Default: -1.
position_magnitude (int): A multiplier acting on coord difference.
Default: 1.
kv_stride (int): The feature stride acting on key/value feature map.
Default: 2.
q_stride (int): The feature stride acting on query feature map.
Default: 1.
attention_type (str): A binary indicator string for indicating which
items in generalized empirical_attention module are used.
'1000' indicates 'query and key content' (appr - appr) item,
......@@ -30,8 +37,11 @@ class GeneralizedAttention(nn.Module):
(appr - position) item,
'0010' indicates 'key content only' (bias - appr) item,
'0001' indicates 'relative position only' (bias - position) item.
Default: '1111'.
"""
_abbr_ = 'gen_attention_block'
def __init__(self,
in_channels,
spatial_range=-1,
......@@ -161,16 +171,16 @@ class GeneralizedAttention(nn.Module):
device,
feat_dim,
wave_length=1000):
h_idxs = torch.linspace(0, h - 1, h).cuda(device)
h_idxs = torch.linspace(0, h - 1, h).to(device)
h_idxs = h_idxs.view((h, 1)) * q_stride
w_idxs = torch.linspace(0, w - 1, w).cuda(device)
w_idxs = torch.linspace(0, w - 1, w).to(device)
w_idxs = w_idxs.view((w, 1)) * q_stride
h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).cuda(device)
h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(device)
h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).cuda(device)
w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(device)
w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
# (h, h_kv, 1)
......@@ -181,9 +191,9 @@ class GeneralizedAttention(nn.Module):
w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
w_diff *= self.position_magnitude
feat_range = torch.arange(0, feat_dim / 4).cuda(device)
feat_range = torch.arange(0, feat_dim / 4).to(device)
dim_mat = torch.Tensor([wave_length]).cuda(device)
dim_mat = torch.Tensor([wave_length]).to(device)
dim_mat = dim_mat**((4. / feat_dim) * feat_range)
dim_mat = dim_mat.view((1, 1, -1))
......@@ -370,6 +380,15 @@ class GeneralizedAttention(nn.Module):
view(n, self.v_dim * self.num_heads, h, w)
out = self.proj_conv(out)
# output is downsampled, upsample back to input size
if self.q_downsample is not None:
out = F.interpolate(
out,
size=x_input.shape[2:],
mode='bilinear',
align_corners=False)
out = self.gamma * out + x_input
return out
......
......@@ -5,6 +5,7 @@ import torch.nn as nn
from ..utils import constant_init, normal_init
from .conv_module import ConvModule
from .registry import PLUGIN_LAYERS
class _NonLocalNd(nn.Module, metaclass=ABCMeta):
......@@ -185,6 +186,7 @@ class NonLocal1d(_NonLocalNd):
self.phi = nn.Sequential(self.phi, max_pool_layer)
@PLUGIN_LAYERS.register_module()
class NonLocal2d(_NonLocalNd):
"""2D Non-local module.
......@@ -197,6 +199,8 @@ class NonLocal2d(_NonLocalNd):
Default: dict(type='Conv2d').
"""
_abbr_ = 'nonlocal_block'
def __init__(self,
in_channels,
sub_sample=False,
......
......@@ -26,7 +26,7 @@ def infer_abbr(class_type):
the norm type in variable names, e.g, self.bn1, self.gn. This method will
infer the abbreviation to map class types to abbreviations.
Rule 1: If the class has the property "abbr", return the property.
Rule 1: If the class has the property "_abbr_", return the property.
Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
"in" respectively.
......@@ -44,8 +44,8 @@ def infer_abbr(class_type):
if not inspect.isclass(class_type):
raise TypeError(
f'class_type must be a type, but got {type(class_type)}')
if hasattr(class_type, 'abbr'):
return class_type.abbr
if hasattr(class_type, '_abbr_'):
return class_type._abbr_
if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
return 'in'
elif issubclass(class_type, _BatchNorm):
......
import inspect
import re
from .registry import PLUGIN_LAYERS
def infer_abbr(class_type):
    """Infer abbreviation from the class name.

    This method will infer the abbreviation to map class types to
    abbreviations.

    Rule 1: If the class has the property "_abbr_", return the property.
    Rule 2: Otherwise, the abbreviation falls back to snake case of class
    name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.

    Args:
        class_type (type): The plugin layer type.

    Returns:
        str: The inferred abbreviation.
    """

    def camel2snake(word):
        """Convert a camel case word into snake case.

        Modified from `inflection lib
        <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.

        Example::

            >>> camel2snake("FancyBlock")
            'fancy_block'
        """
        # Split a run of capitals from a following capitalized word,
        # e.g. 'HTTPBlock' -> 'HTTP_Block'.
        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
        # Insert an underscore between a lowercase/digit and an uppercase.
        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
        word = word.replace('-', '_')
        return word.lower()

    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_
    else:
        return camel2snake(class_type.__name__)
def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build plugin layer.

    Args:
        cfg (None or dict): cfg should contain:
            type (str): identify plugin layer type.
            layer args: args needed to instantiate a plugin layer.
        postfix (int, str): appended into norm abbreviation to
            create named layer. Default: ''.

    Returns:
        tuple[str, nn.Module]:
            name (str): abbreviation + postfix
            layer (nn.Module): created plugin layer
    """
    # Validate the config before touching the registry.
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    args = cfg.copy()
    plugin_type = args.pop('type')
    if plugin_type not in PLUGIN_LAYERS:
        raise KeyError(f'Unrecognized plugin type {plugin_type}')

    plugin_cls = PLUGIN_LAYERS.get(plugin_type)
    assert isinstance(postfix, (int, str))
    # The layer name is the inferred abbreviation plus the user postfix.
    name = infer_abbr(plugin_cls) + str(postfix)
    return name, plugin_cls(**kwargs, **args)
......@@ -5,3 +5,4 @@ NORM_LAYERS = Registry('norm layer')
ACTIVATION_LAYERS = Registry('activation layer')
PADDING_LAYERS = Registry('padding layer')
UPSAMPLE_LAYERS = Registry('upsample layer')
PLUGIN_LAYERS = Registry('plugin layer')
from .bbox import bbox_overlaps
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .context_block import ContextBlock
from .conv_ws import ConvWS2d, conv_ws_2d
from .corner_pool import CornerPool
from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
......@@ -9,14 +8,12 @@ from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
ModulatedDeformRoIPoolPack, deform_roi_pool)
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss)
from .generalized_attention import GeneralizedAttention
from .info import get_compiler_version, get_compiling_cuda_version
from .masked_conv import MaskedConv2d, masked_conv2d
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
modulated_deform_conv2d)
from .nms import batched_nms, nms, nms_match, soft_nms
from .plugin import build_plugin_layer
from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
from .psa_mask import PSAMask
......@@ -27,16 +24,14 @@ from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
__all__ = [
'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
'carafe_naive', 'ContextBlock', 'ConvWS2d', 'conv_ws_2d', 'CornerPool',
'DeformConv2d', 'DeformConv2dPack', 'deform_conv2d', 'DeformRoIPool',
'DeformRoIPoolPack', 'ModulatedDeformRoIPoolPack', 'deform_roi_pool',
'SigmoidFocalLoss', 'SoftmaxFocalLoss', 'sigmoid_focal_loss',
'softmax_focal_loss', 'GeneralizedAttention', 'get_compiler_version',
'get_compiling_cuda_version', 'MaskedConv2d', 'masked_conv2d',
'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'carafe_naive', 'ConvWS2d', 'conv_ws_2d', 'CornerPool', 'DeformConv2d',
'DeformConv2dPack', 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
'get_compiler_version', 'get_compiling_cuda_version', 'MaskedConv2d',
'masked_conv2d', 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
'build_plugin_layer', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
'SyncBatchNorm', 'Conv2d', 'ConvTranspose2d', 'Linear', 'MaxPool2d',
'CrissCrossAttention', 'PSAMask', 'point_sample',
'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign'
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign'
]
from ..cnn import ConvModule, NonLocal2d
from .context_block import ContextBlock
from .generalized_attention import GeneralizedAttention
# Static mapping from plugin type name to (abbreviation, module class).
# NOTE(review): presumably superseded by the PLUGIN_LAYERS registry in
# cnn/bricks — confirm before extending this table.
plugin_cfg = {
    # format: layer_type: (abbreviation, module)
    'ContextBlock': ('context_block', ContextBlock),
    'GeneralizedAttention': ('gen_attention_block', GeneralizedAttention),
    'NonLocal2d': ('nonlocal_block', NonLocal2d),
    'ConvModule': ('conv_block', ConvModule),
}
def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build plugin layer.

    Args:
        cfg (dict): cfg should contain:
            type (str): identify plugin layer type.
            layer args: args needed to instantiate a plugin layer.
        postfix (int, str): appended into norm abbreviation to
            create named layer. Default: ''.

    Returns:
        tuple[str, nn.Module]:
            name (str): abbreviation + postfix
            layer (nn.Module): created plugin layer
    """
    # ``cfg`` must be a dict carrying at least the plugin ``type`` key.
    assert isinstance(cfg, dict) and 'type' in cfg
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in plugin_cfg:
        raise KeyError(f'Unrecognized plugin type {layer_type}')
    abbr, plugin_layer = plugin_cfg[layer_type]

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    # Remaining cfg entries plus kwargs become constructor arguments.
    layer = plugin_layer(**kwargs, **cfg_)
    return name, layer
......@@ -3,11 +3,12 @@ import torch
import torch.nn as nn
from mmcv.cnn.bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, build_activation_layer,
build_conv_layer, build_norm_layer,
build_padding_layer, build_upsample_layer,
is_norm)
from mmcv.cnn.bricks.norm import infer_abbr
PADDING_LAYERS, PLUGIN_LAYERS,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer,
build_plugin_layer, build_upsample_layer, is_norm)
from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr
from mmcv.cnn.bricks.plugin import infer_abbr as infer_plugin_abbr
from mmcv.cnn.bricks.upsample import PixelShufflePack
from mmcv.utils.parrots_wrapper import _BatchNorm
......@@ -56,41 +57,41 @@ def test_build_conv_layer():
assert layer.out_channels == kwargs['out_channels']
def test_infer_abbr():
def test_infer_norm_abbr():
with pytest.raises(TypeError):
# class_type must be a class
infer_abbr(0)
infer_norm_abbr(0)
class MyNorm:
abbr = 'mn'
_abbr_ = 'mn'
assert infer_abbr(MyNorm) == 'mn'
assert infer_norm_abbr(MyNorm) == 'mn'
class FancyBatchNorm:
pass
assert infer_abbr(FancyBatchNorm) == 'bn'
assert infer_norm_abbr(FancyBatchNorm) == 'bn'
class FancyInstanceNorm:
pass
assert infer_abbr(FancyInstanceNorm) == 'in'
assert infer_norm_abbr(FancyInstanceNorm) == 'in'
class FancyLayerNorm:
pass
assert infer_abbr(FancyLayerNorm) == 'ln'
assert infer_norm_abbr(FancyLayerNorm) == 'ln'
class FancyGroupNorm:
pass
assert infer_abbr(FancyGroupNorm) == 'gn'
assert infer_norm_abbr(FancyGroupNorm) == 'gn'
class FancyNorm:
pass
assert infer_abbr(FancyNorm) == 'norm'
assert infer_norm_abbr(FancyNorm) == 'norm'
def test_build_norm_layer():
......@@ -296,3 +297,77 @@ def test_is_norm():
with pytest.raises(TypeError):
layer = nn.BatchNorm1d(3)
is_norm(layer, exclude=('BN', ))
def test_infer_plugin_abbr():
    """Abbreviation inference: explicit ``_abbr_`` wins, else snake case."""
    # Non-class inputs are rejected with TypeError.
    with pytest.raises(TypeError):
        infer_plugin_abbr(0)

    class MyPlugin:
        _abbr_ = 'mp'

    class FancyPlugin:
        pass

    # An explicit ``_abbr_`` attribute is returned verbatim ...
    assert infer_plugin_abbr(MyPlugin) == 'mp'
    # ... otherwise the class name is converted to snake case.
    assert infer_plugin_abbr(FancyPlugin) == 'fancy_plugin'
def test_build_plugin_layer():
    """build_plugin_layer: input validation and per-type construction."""
    # cfg must be a dict
    with pytest.raises(TypeError):
        build_plugin_layer('Plugin')

    # `type` must be in cfg
    with pytest.raises(KeyError):
        build_plugin_layer(dict())

    # unsupported plugin type
    with pytest.raises(KeyError):
        build_plugin_layer(dict(type='FancyPlugin'))

    # postfix must be int or str
    with pytest.raises(AssertionError):
        build_plugin_layer(dict(type='ConvModule'), postfix=[1, 2])

    postfixes = ['', '_test', 1]

    # test ContextBlock
    for postfix in postfixes:
        name, layer = build_plugin_layer(
            dict(type='ContextBlock'),
            postfix=postfix,
            in_channels=16,
            ratio=1. / 4)
        assert name == 'context_block' + str(postfix)
        assert isinstance(layer, PLUGIN_LAYERS.module_dict['ContextBlock'])

    # test GeneralizedAttention
    for postfix in postfixes:
        name, layer = build_plugin_layer(
            dict(type='GeneralizedAttention'),
            postfix=postfix,
            in_channels=16)
        assert name == 'gen_attention_block' + str(postfix)
        assert isinstance(layer,
                          PLUGIN_LAYERS.module_dict['GeneralizedAttention'])

    # test NonLocal2d
    for postfix in postfixes:
        name, layer = build_plugin_layer(
            dict(type='NonLocal2d'), postfix=postfix, in_channels=16)
        assert name == 'nonlocal_block' + str(postfix)
        assert isinstance(layer, PLUGIN_LAYERS.module_dict['NonLocal2d'])

    # test ConvModule
    for postfix in postfixes:
        name, layer = build_plugin_layer(
            dict(type='ConvModule'),
            postfix=postfix,
            in_channels=16,
            out_channels=4,
            kernel_size=3)
        assert name == 'conv_block' + str(postfix)
        assert isinstance(layer, PLUGIN_LAYERS.module_dict['ConvModule'])
import pytest
import torch
from mmcv.cnn.bricks import ContextBlock
def test_context_block():
    """ContextBlock: argument validation, pooling modes and fusion types."""
    with pytest.raises(AssertionError):
        # pooling_type should be in ['att', 'avg']
        ContextBlock(16, 1. / 4, pooling_type='unsupport_type')

    with pytest.raises(AssertionError):
        # fusion_types should be of type list or tuple
        ContextBlock(16, 1. / 4, fusion_types='unsupport_type')

    with pytest.raises(AssertionError):
        # fusion_types should be in ['channel_add', 'channel_mul']
        ContextBlock(16, 1. / 4, fusion_types=('unsupport_type', ))

    # shared random input; each block is only expected to keep the shape
    imgs = torch.randn(2, 16, 20, 20)

    # pooling_type='att': a 1-channel mask conv performs attention pooling
    block = ContextBlock(16, 1. / 4, pooling_type='att')
    out = block(imgs)
    assert block.conv_mask.in_channels == 16
    assert block.conv_mask.out_channels == 1
    assert out.shape == imgs.shape

    # pooling_type='avg': an average-pool module is used instead
    block = ContextBlock(16, 1. / 4, pooling_type='avg')
    out = block(imgs)
    assert hasattr(block, 'avg_pool')
    assert out.shape == imgs.shape

    # fusion_types=('channel_add',): only the additive branch is built
    block = ContextBlock(16, 1. / 4, fusion_types=('channel_add', ))
    out = block(imgs)
    assert block.channel_add_conv is not None
    assert block.channel_mul_conv is None
    assert out.shape == imgs.shape

    # fusion_types=('channel_mul',): only the multiplicative branch is built
    block = ContextBlock(16, 1. / 4, fusion_types=('channel_mul', ))
    out = block(imgs)
    assert block.channel_add_conv is None
    assert block.channel_mul_conv is not None
    assert out.shape == imgs.shape

    # fusion_types=('channel_add', 'channel_mul'): both branches are built
    block = ContextBlock(
        16, 1. / 4, fusion_types=('channel_add', 'channel_mul'))
    out = block(imgs)
    assert block.channel_add_conv is not None
    assert block.channel_mul_conv is not None
    assert out.shape == imgs.shape
import torch
from mmcv.cnn.bricks import GeneralizedAttention
def test_generalized_attention():
    """GeneralizedAttention: attention types, spatial range and strides.

    NOTE: renamed from ``test_context_block`` — the original name was a
    copy-paste from the ContextBlock test and did not match the module
    under test. A duplicated ``key_conv`` assertion was also removed.
    """
    # test attention_type='1000'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='1000')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.key_conv.in_channels == 16
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0100'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0100')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0010'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0010')
    assert gen_attention_block.key_conv.in_channels == 16
    assert hasattr(gen_attention_block, 'appr_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0001'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0001')
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    assert hasattr(gen_attention_block, 'geom_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test spatial_range >= 0
    imgs = torch.randn(2, 256, 20, 20)
    gen_attention_block = GeneralizedAttention(256, spatial_range=10)
    assert hasattr(gen_attention_block, 'local_constraint_map')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test q_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, q_stride=2)
    assert gen_attention_block.q_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test kv_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, kv_stride=2)
    assert gen_attention_block.kv_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment