from collections import OrderedDict
import math

import torch.nn as nn

# Public API of this module: the SENet class and its model factories.
__all__ = [
    'SENet',
    'senet154',
    'se_resnet50',
    'se_resnet101',
    'se_resnet152',
    'se_resnext50_32x4d',
    'se_resnext101_32x4d',
    'se_resnext101_64x4d',
]


class SEModule(nn.Module):
    """Squeeze-and-Excitation gate.

    The input is global-average-pooled to 1x1 and passed through a two-layer
    1x1-convolution bottleneck (``channels -> channels // reduction ->
    channels``) with a ReLU in between. The result goes through a sigmoid,
    and the input is multiplied by that gate.

    Args:
        channels (int): Number of input (and output) channels.
        reduction (int): Divisor applied to ``channels`` for the hidden width.
    """

    def __init__(self, channels, reduction):
        super().__init__()
        squeezed = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, squeezed, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled by the learned channel gate."""
        pooled = self.avg_pool(x)
        # fc1 returns a fresh tensor, so the in-place ReLU never touches ``x``.
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(pooled))))
        return x * gate


class Bottleneck(nn.Module):
    """Shared forward pass for the SE residual bottleneck blocks.

    Subclasses must set ``conv1``/``bn1``, ``conv2``/``bn2``, ``conv3``/``bn3``,
    ``relu``, ``se_module`` and ``downsample`` (a module or ``None``).
    """

    def forward(self, x):
        """Apply conv-bn(-relu) x3, the SE gate, the residual add and a ReLU."""
        out = x
        # The first two conv/bn stages are each followed by a ReLU.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        # The third stage has no activation before the residual add.
        out = self.bn3(self.conv3(out))

        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)

        return self.relu(self.se_module(out) + shortcut)


class SEBottleneck(Bottleneck):
    """Bottleneck block used by SENet-154.

    The channel path is ``inplanes -> planes * 2 -> planes * 4 -> planes * 4``.
    The grouped 3x3 convolution carries the stride, and the output is
    gated by an ``SEModule``.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups,
                 reduction,
                 stride=1,
                 downsample=None):
        super().__init__()
        narrow = planes * 2
        wide = planes * 4

        self.conv1 = nn.Conv2d(inplanes, narrow, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(narrow)
        self.conv2 = nn.Conv2d(narrow, wide, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(wide)
        self.conv3 = nn.Conv2d(wide, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(wide, reduction=reduction)
        # Optional shortcut projection supplied by SENet._make_layer.
        self.downsample = downsample
        self.stride = stride


class SEResNetBottleneck(Bottleneck):
    """Bottleneck block used by the SE-ResNet models.

    The channel path is ``inplanes -> planes -> planes -> planes * 4``. The
    stride is applied by the first 1x1 convolution, not by the 3x3 one.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups,
                 reduction,
                 stride=1,
                 downsample=None):
        super().__init__()
        out_planes = planes * 4

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1,
                               stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
                               groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_planes, reduction=reduction)
        # Optional shortcut projection supplied by SENet._make_layer.
        self.downsample = downsample
        self.stride = stride


class SEResNeXtBottleneck(Bottleneck):
    """Bottleneck block used by the SE-ResNeXt models.

    The inner width is ``floor(planes * base_width / 64) * groups``. The
    grouped 3x3 convolution carries the stride, and the block outputs
    ``planes * 4`` channels.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups,
                 reduction,
                 stride=1,
                 downsample=None,
                 base_width=4):
        super().__init__()
        # ``groups`` groups, each floor(planes * base_width / 64) channels wide.
        width = math.floor(planes * (base_width / 64)) * groups
        out_planes = planes * 4

        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_planes, reduction=reduction)
        # Optional shortcut projection supplied by SENet._make_layer.
        self.downsample = downsample
        self.stride = stride


class SENet(nn.Module):
    """Squeeze-and-Excitation network: stem, four residual stages, classifier.

    Args:
        block: Residual block class. It is called as
            ``block(inplanes, planes, groups, reduction, stride, downsample)``
            and must expose an integer ``expansion`` class attribute. Each
            stage outputs ``planes * block.expansion`` channels.
        layers (list of int): Number of blocks in each of the four stages.
        groups (int): Convolution groups forwarded to every block.
        reduction (int): SE reduction ratio forwarded to every block.
        dropout_p (float or None): Dropout probability applied before
            ``last_linear``. ``None`` disables dropout.
        inplanes (int): Number of channels produced by the stem (``layer0``).
        input_3x3 (bool): If True, the stem is three 3x3 convolutions;
            otherwise it is a single 7x7 convolution.
        downsample_kernel_size (int): Kernel size of the shortcut projection
            convolution in ``layer2``-``layer4`` (``layer1`` always uses 1).
        downsample_padding (int): Padding of that projection convolution
            (``layer1`` always uses 0).
        num_classes (int): Output size of ``last_linear``.
    """

    def __init__(self,
                 block,
                 layers,
                 groups,
                 reduction,
                 dropout_p=0.2,
                 inplanes=128,
                 input_3x3=True,
                 downsample_kernel_size=3,
                 downsample_padding=1,
                 num_classes=1000):
        super(SENet, self).__init__()
        # Running channel count consumed and updated by _make_layer.
        self.inplanes = inplanes
        self.layer0 = self._make_layer0(inplanes, input_3x3)
        # layer1 runs at stride 1, so any shortcut projection it needs only
        # changes the channel count: always a 1x1 convolution without padding.
        self.layer1 = self._make_layer(block,
                                       planes=64,
                                       blocks=layers[0],
                                       groups=groups,
                                       reduction=reduction,
                                       downsample_kernel_size=1,
                                       downsample_padding=0)
        self.layer2 = self._make_layer(
            block,
            planes=128,
            blocks=layers[1],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding)
        self.layer3 = self._make_layer(
            block,
            planes=256,
            blocks=layers[2],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding)
        self.layer4 = self._make_layer(
            block,
            planes=512,
            blocks=layers[3],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding)
        # Global average pooling to 1x1. On 7x7 feature maps (a 224x224 input
        # with the blocks in this file) this equals nn.AvgPool2d(7, stride=1).
        # Unlike a fixed 7x7 pool, it also works for other input resolutions.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
        self.last_linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer0(self, inplanes, input_3x3):
        """Build the stem: conv/BN/ReLU layers followed by a 3x3/2 max pool.

        Module names (``conv1``, ``bn1``, ..., ``pool``) are fixed so that
        saved state_dicts keep loading.
        """
        if input_3x3:
            modules = [
                ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
                                    bias=False)),
                ('bn1', nn.BatchNorm2d(64)),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
                                    bias=False)),
                ('bn2', nn.BatchNorm2d(64)),
                ('relu2', nn.ReLU(inplace=True)),
                ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
                                    bias=False)),
                ('bn3', nn.BatchNorm2d(inplanes)),
                ('relu3', nn.ReLU(inplace=True)),
            ]
        else:
            modules = [
                ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
                                    padding=3, bias=False)),
                ('bn1', nn.BatchNorm2d(inplanes)),
                ('relu1', nn.ReLU(inplace=True)),
            ]
        # ceil_mode=True keeps a trailing partial pooling window that the
        # default floor mode would discard.
        modules.append(('pool', nn.MaxPool2d(3, stride=2, ceil_mode=True)))
        return nn.Sequential(OrderedDict(modules))

    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    groups,
                    reduction,
                    stride=1,
                    downsample_kernel_size=1,
                    downsample_padding=0):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride``. When the stride or the
        channel count changes, the first block also receives a conv + BN
        shortcut projection. ``self.inplanes`` is updated to the stage's
        output channel count.
        """
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,
                          out_channels,
                          kernel_size=downsample_kernel_size,
                          stride=stride,
                          padding=downsample_padding,
                          bias=False),
                nn.BatchNorm2d(out_channels),
            )

        layers = [
            block(self.inplanes, planes, groups, reduction, stride,
                  downsample)
        ]
        self.inplanes = out_channels
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))

        return nn.Sequential(*layers)

    def features(self, x):
        """Run the stem and the four residual stages; return the last feature map."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def logits(self, x):
        """Pool, optionally apply dropout, flatten, and apply ``last_linear``."""
        x = self.avg_pool(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = x.flatten(1)
        x = self.last_linear(x)
        return x

    def forward(self, x):
        """Return class logits for a batch of images."""
        x = self.features(x)
        x = self.logits(x)
        return x


def senet154(**kwargs):
    """Build SENet-154.

    Uses ``SEBottleneck`` blocks with stage depths [3, 8, 36, 3], 64 groups,
    SE reduction 16, dropout 0.2, and SENet's default stem settings. Extra
    keyword arguments are forwarded to ``SENet``.
    """
    return SENet(SEBottleneck, [3, 8, 36, 3],
                 groups=64, reduction=16, dropout_p=0.2, **kwargs)


def se_resnet50(**kwargs):
    """Build SE-ResNet-50.

    Uses ``SEResNetBottleneck`` blocks with stage depths [3, 4, 6, 3], a
    single 7x7 stem convolution with 64 channels, 1x1 shortcut projections
    and no dropout. Extra keyword arguments are forwarded to ``SENet``.
    """
    stage_depths = [3, 4, 6, 3]
    return SENet(
        SEResNetBottleneck,
        stage_depths,
        groups=1,
        reduction=16,
        dropout_p=None,
        inplanes=64,
        input_3x3=False,
        downsample_kernel_size=1,
        downsample_padding=0,
        **kwargs
    )


def se_resnet101(**kwargs):
    """Build SE-ResNet-101.

    Uses ``SEResNetBottleneck`` blocks with stage depths [3, 4, 23, 3], a
    single 7x7 stem convolution with 64 channels, 1x1 shortcut projections
    and no dropout. Extra keyword arguments are forwarded to ``SENet``.
    """
    resnet_stem = dict(inplanes=64,
                       input_3x3=False,
                       downsample_kernel_size=1,
                       downsample_padding=0)
    return SENet(SEResNetBottleneck, [3, 4, 23, 3],
                 groups=1, reduction=16, dropout_p=None,
                 **resnet_stem, **kwargs)


def se_resnet152(**kwargs):
    """Build SE-ResNet-152.

    Uses ``SEResNetBottleneck`` blocks with stage depths [3, 8, 36, 3], a
    single 7x7 stem convolution with 64 channels, 1x1 shortcut projections
    and no dropout. Extra keyword arguments are forwarded to ``SENet``.
    """
    stage_depths = [3, 8, 36, 3]
    return SENet(
        SEResNetBottleneck,
        stage_depths,
        groups=1,
        reduction=16,
        dropout_p=None,
        inplanes=64,
        input_3x3=False,
        downsample_kernel_size=1,
        downsample_padding=0,
        **kwargs
    )


def se_resnext50_32x4d(**kwargs):
    """Build SE-ResNeXt-50 (32x4d).

    Uses ``SEResNeXtBottleneck`` blocks (default base width 4) with stage
    depths [3, 4, 6, 3] and 32 groups. It has a single 7x7 stem convolution
    with 64 channels, 1x1 shortcut projections and no dropout. Extra keyword
    arguments are forwarded to ``SENet``.
    """
    resnet_stem = dict(inplanes=64,
                       input_3x3=False,
                       downsample_kernel_size=1,
                       downsample_padding=0)
    return SENet(SEResNeXtBottleneck, [3, 4, 6, 3],
                 groups=32, reduction=16, dropout_p=None,
                 **resnet_stem, **kwargs)


def se_resnext101_32x4d(**kwargs):
    """Build SE-ResNeXt-101 (32x4d).

    Uses ``SEResNeXtBottleneck`` blocks (default base width 4) with stage
    depths [3, 4, 23, 3] and 32 groups. It has a single 7x7 stem convolution
    with 64 channels, 1x1 shortcut projections and no dropout. Extra keyword
    arguments are forwarded to ``SENet``.
    """
    stage_depths = [3, 4, 23, 3]
    return SENet(
        SEResNeXtBottleneck,
        stage_depths,
        groups=32,
        reduction=16,
        dropout_p=None,
        inplanes=64,
        input_3x3=False,
        downsample_kernel_size=1,
        downsample_padding=0,
        **kwargs
    )


def se_resnext101_64x4d(**kwargs):
    """Build SE-ResNeXt-101 (64x4d).

    Uses ``SEResNeXtBottleneck`` blocks (default base width 4) with stage
    depths [3, 4, 23, 3] and 64 groups. It has a single 7x7 stem convolution
    with 64 channels, 1x1 shortcut projections and no dropout. Extra keyword
    arguments are forwarded to ``SENet``.
    """
    resnet_stem = dict(inplanes=64,
                       input_3x3=False,
                       downsample_kernel_size=1,
                       downsample_padding=0)
    return SENet(SEResNeXtBottleneck, [3, 4, 23, 3],
                 groups=64, reduction=16, dropout_p=None,
                 **resnet_stem, **kwargs)
