Commit c04f261a authored by dongchy920's avatar dongchy920
Browse files

InstruceBLIP

parents
Pipeline #1594 canceled with stages
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 1, 1),
strides=(1, 2, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='UPerHead',
in_channels=[256, 512, 1024, 2048],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
# model settings
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='UniFormer',
embed_dim=[64, 128, 320, 512],
layers=[3, 4, 8, 3],
head_dim=64,
mlp_ratio=4.,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1),
decode_head=dict(
type='UPerHead',
in_channels=[64, 128, 320, 512],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=320,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
\ No newline at end of file
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=16000)
evaluation = dict(interval=16000, metric='mIoU')
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=20000)
checkpoint_config = dict(by_epoch=False, interval=2000)
evaluation = dict(interval=2000, metric='mIoU')
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=40000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=4000, metric='mIoU')
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=80000)
checkpoint_config = dict(by_epoch=False, interval=8000)
evaluation = dict(interval=8000, metric='mIoU')
_base_ = [
'../../configs/_base_/models/upernet_uniformer.py',
'../../configs/_base_/datasets/ade20k.py',
'../../configs/_base_/default_runtime.py',
'../../configs/_base_/schedules/schedule_160k.py'
]
model = dict(
backbone=dict(
type='UniFormer',
embed_dim=[64, 128, 320, 512],
layers=[3, 4, 8, 3],
head_dim=64,
drop_path_rate=0.25,
windows=False,
hybrid=False
),
decode_head=dict(
in_channels=[64, 128, 320, 512],
num_classes=150
),
auxiliary_head=dict(
in_channels=320,
num_classes=150
))
# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(_delete_=True, policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0, min_lr=0.0, by_epoch=False)
data=dict(samples_per_gpu=2)
\ No newline at end of file
#!/usr/bin/env bash
work_path=$(dirname $0)
PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=8 \
tools/train.py ${work_path}/config.py \
--launcher pytorch \
--options model.backbone.pretrained_path='your_model_path/uniformer_small_in1k.pth' \
--work-dir ${work_path}/ckpt \
2>&1 | tee -a ${work_path}/log.txt
#!/usr/bin/env bash
work_path=$(dirname $0)
PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=8 \
tools/test.py ${work_path}/test_config_h32.py \
${work_path}/ckpt/latest.pth \
--launcher pytorch \
--eval mIoU \
2>&1 | tee -a ${work_path}/log.txt
_base_ = [
'../../configs/_base_/models/upernet_uniformer.py',
'../../configs/_base_/datasets/ade20k.py',
'../../configs/_base_/default_runtime.py',
'../../configs/_base_/schedules/schedule_160k.py'
]
model = dict(
backbone=dict(
type='UniFormer',
embed_dim=[64, 128, 320, 512],
layers=[3, 4, 8, 3],
head_dim=64,
drop_path_rate=0.25,
windows=False,
hybrid=False,
),
decode_head=dict(
in_channels=[64, 128, 320, 512],
num_classes=150
),
auxiliary_head=dict(
in_channels=320,
num_classes=150
))
# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(_delete_=True, policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0, min_lr=0.0, by_epoch=False)
data=dict(samples_per_gpu=2)
\ No newline at end of file
_base_ = [
'../../configs/_base_/models/upernet_uniformer.py',
'../../configs/_base_/datasets/ade20k.py',
'../../configs/_base_/default_runtime.py',
'../../configs/_base_/schedules/schedule_160k.py'
]
model = dict(
backbone=dict(
type='UniFormer',
embed_dim=[64, 128, 320, 512],
layers=[3, 4, 8, 3],
head_dim=64,
drop_path_rate=0.25,
windows=False,
hybrid=True,
window_size=32
),
decode_head=dict(
in_channels=[64, 128, 320, 512],
num_classes=150
),
auxiliary_head=dict(
in_channels=320,
num_classes=150
))
# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(_delete_=True, policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0, min_lr=0.0, by_epoch=False)
data=dict(samples_per_gpu=2)
\ No newline at end of file
_base_ = [
'../../configs/_base_/models/upernet_uniformer.py',
'../../configs/_base_/datasets/ade20k.py',
'../../configs/_base_/default_runtime.py',
'../../configs/_base_/schedules/schedule_160k.py'
]
model = dict(
backbone=dict(
type='UniFormer',
embed_dim=[64, 128, 320, 512],
layers=[3, 4, 8, 3],
head_dim=64,
drop_path_rate=0.25,
windows=True,
hybrid=False,
window_size=32
),
decode_head=dict(
in_channels=[64, 128, 320, 512],
num_classes=150
),
auxiliary_head=dict(
in_channels=320,
num_classes=150
))
# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)}))
lr_config = dict(_delete_=True, policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0, min_lr=0.0, by_epoch=False)
data=dict(samples_per_gpu=2)
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .arraymisc import *
from .fileio import *
from .image import *
from .utils import *
from .version import *
from .video import *
from .visualization import *
# The following modules are not imported to this level, so mmcv may be used
# without PyTorch.
# - runner
# - parallel
# - op
# Copyright (c) OpenMMLab. All rights reserved.
from .quantization import dequantize, quantize
__all__ = ['quantize', 'dequantize']
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
"""Quantize an array of (-inf, inf) to [0, levels-1].
Args:
arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped.
levels (int): Quantization levels.
dtype (np.type): The type of the quantized array.
Returns:
tuple: Quantized array.
"""
if not (isinstance(levels, int) and levels > 1):
raise ValueError(
f'levels must be a positive integer, but got {levels}')
if min_val >= max_val:
raise ValueError(
f'min_val ({min_val}) must be smaller than max_val ({max_val})')
arr = np.clip(arr, min_val, max_val) - min_val
quantized_arr = np.minimum(
np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
return quantized_arr
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
"""Dequantize an array.
Args:
arr (ndarray): Input array.
min_val (scalar): Minimum value to be clipped.
max_val (scalar): Maximum value to be clipped.
levels (int): Quantization levels.
dtype (np.type): The type of the dequantized array.
Returns:
tuple: Dequantized array.
"""
if not (isinstance(levels, int) and levels > 1):
raise ValueError(
f'levels must be a positive integer, but got {levels}')
if min_val >= max_val:
raise ValueError(
f'min_val ({min_val}) must be smaller than max_val ({max_val})')
dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
min_val) / levels + min_val
return dequantized_arr
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
# yapf: disable
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
ConvTranspose2d, ConvTranspose3d, ConvWS2d,
DepthwiseSeparableConvModule, GeneralizedAttention,
HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_plugin_layer,
build_upsample_layer, conv_ws_2d, is_norm)
from .builder import MODELS, build_model_from_cfg
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
XavierInit, bias_init_with_prob, caffe2_xavier_init,
constant_init, fuse_conv_bn, get_model_complexity_info,
initialize, kaiming_init, normal_init, trunc_normal_init,
uniform_init, xavier_init)
from .vgg import VGG, make_vgg_layer
__all__ = [
'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
]
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import torch.nn as nn
class AlexNet(nn.Module):
"""AlexNet backbone.
Args:
num_classes (int): number of classes for classification.
"""
def __init__(self, num_classes=-1):
super(AlexNet, self).__init__()
self.num_classes = num_classes
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
if self.num_classes > 0:
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
from ..runner import load_checkpoint
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
# use default initializer
pass
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
x = self.features(x)
if self.num_classes > 0:
x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x)
return x
# Copyright (c) OpenMMLab. All rights reserved.
from .activation import build_activation_layer
from .context_block import ContextBlock
from .conv import build_conv_layer
from .conv2d_adaptive_padding import Conv2dAdaptivePadding
from .conv_module import ConvModule
from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
from .drop import Dropout, DropPath
from .generalized_attention import GeneralizedAttention
from .hsigmoid import HSigmoid
from .hswish import HSwish
from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer
from .plugin import build_plugin_layer
from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
from .scale import Scale
from .swish import Swish
from .upsample import build_upsample_layer
from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
Linear, MaxPool2d, MaxPool3d)
__all__ = [
'ConvModule', 'build_activation_layer', 'build_conv_layer',
'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.uniformer.mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
from .registry import ACTIVATION_LAYERS
for module in [
nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
nn.Sigmoid, nn.Tanh
]:
ACTIVATION_LAYERS.register_module(module=module)
@ACTIVATION_LAYERS.register_module(name='Clip')
@ACTIVATION_LAYERS.register_module()
class Clamp(nn.Module):
"""Clamp activation layer.
This activation function is to clamp the feature map value within
:math:`[min, max]`. More details can be found in ``torch.clamp()``.
Args:
min (Number | optional): Lower-bound of the range to be clamped to.
Default to -1.
max (Number | optional): Upper-bound of the range to be clamped to.
Default to 1.
"""
def __init__(self, min=-1., max=1.):
super(Clamp, self).__init__()
self.min = min
self.max = max
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: Clamped tensor.
"""
return torch.clamp(x, min=self.min, max=self.max)
class GELU(nn.Module):
r"""Applies the Gaussian Error Linear Units function:
.. math::
\text{GELU}(x) = x * \Phi(x)
where :math:`\Phi(x)` is the Cumulative Distribution Function for
Gaussian Distribution.
Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
dimensions
- Output: :math:`(N, *)`, same shape as the input
.. image:: scripts/activation_images/GELU.png
Examples::
>>> m = nn.GELU()
>>> input = torch.randn(2)
>>> output = m(input)
"""
def forward(self, input):
return F.gelu(input)
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.4')):
ACTIVATION_LAYERS.register_module(module=GELU)
else:
ACTIVATION_LAYERS.register_module(module=nn.GELU)
def build_activation_layer(cfg):
"""Build activation layer.
Args:
cfg (dict): The activation layer config, which should contain:
- type (str): Layer type.
- layer args: Args needed to instantiate an activation layer.
Returns:
nn.Module: Created activation layer.
"""
return build_from_cfg(cfg, ACTIVATION_LAYERS)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS
def last_zero_init(m):
if isinstance(m, nn.Sequential):
constant_init(m[-1], val=0)
else:
constant_init(m, val=0)
@PLUGIN_LAYERS.register_module()
class ContextBlock(nn.Module):
"""ContextBlock module in GCNet.
See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
(https://arxiv.org/abs/1904.11492) for details.
Args:
in_channels (int): Channels of the input feature map.
ratio (float): Ratio of channels of transform bottleneck
pooling_type (str): Pooling method for context modeling.
Options are 'att' and 'avg', stand for attention pooling and
average pooling respectively. Default: 'att'.
fusion_types (Sequence[str]): Fusion method for feature fusion,
Options are 'channels_add', 'channel_mul', stand for channelwise
addition and multiplication respectively. Default: ('channel_add',)
"""
_abbr_ = 'context_block'
def __init__(self,
in_channels,
ratio,
pooling_type='att',
fusion_types=('channel_add', )):
super(ContextBlock, self).__init__()
assert pooling_type in ['avg', 'att']
assert isinstance(fusion_types, (list, tuple))
valid_fusion_types = ['channel_add', 'channel_mul']
assert all([f in valid_fusion_types for f in fusion_types])
assert len(fusion_types) > 0, 'at least one fusion should be used'
self.in_channels = in_channels
self.ratio = ratio
self.planes = int(in_channels * ratio)
self.pooling_type = pooling_type
self.fusion_types = fusion_types
if pooling_type == 'att':
self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
self.softmax = nn.Softmax(dim=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2d(1)
if 'channel_add' in fusion_types:
self.channel_add_conv = nn.Sequential(
nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
else:
self.channel_add_conv = None
if 'channel_mul' in fusion_types:
self.channel_mul_conv = nn.Sequential(
nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
else:
self.channel_mul_conv = None
self.reset_parameters()
def reset_parameters(self):
if self.pooling_type == 'att':
kaiming_init(self.conv_mask, mode='fan_in')
self.conv_mask.inited = True
if self.channel_add_conv is not None:
last_zero_init(self.channel_add_conv)
if self.channel_mul_conv is not None:
last_zero_init(self.channel_mul_conv)
def spatial_pool(self, x):
batch, channel, height, width = x.size()
if self.pooling_type == 'att':
input_x = x
# [N, C, H * W]
input_x = input_x.view(batch, channel, height * width)
# [N, 1, C, H * W]
input_x = input_x.unsqueeze(1)
# [N, 1, H, W]
context_mask = self.conv_mask(x)
# [N, 1, H * W]
context_mask = context_mask.view(batch, 1, height * width)
# [N, 1, H * W]
context_mask = self.softmax(context_mask)
# [N, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N, 1, C, 1]
context = torch.matmul(input_x, context_mask)
# [N, C, 1, 1]
context = context.view(batch, channel, 1, 1)
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x):
# [N, C, 1, 1]
context = self.spatial_pool(x)
out = x
if self.channel_mul_conv is not None:
# [N, C, 1, 1]
channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
if self.channel_add_conv is not None:
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment