# separable_conv.py
from torch import nn


class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    The layers are stored in ``self.conv`` as an ``nn.Sequential`` in this
    order:

    0. ``nn.Conv2d`` with ``groups=in_channels`` and as many output channels
       as input channels (the depthwise stage).
    1. ``nn.BatchNorm2d`` over ``in_channels``.
    2. ``nn.ReLU`` when ``onnx_compatible`` is true, otherwise ``nn.ReLU6``.
    3. ``nn.Conv2d`` with ``kernel_size=1`` mapping ``in_channels`` to
       ``out_channels`` (the pointwise stage).

    No normalisation or activation is applied after the pointwise stage.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the pointwise stage.
        kernel_size: Kernel size of the depthwise stage.
        stride: Stride of the depthwise stage.
        padding: Padding of the depthwise stage.
        onnx_compatible: Selects ``nn.ReLU`` instead of ``nn.ReLU6`` as the
            activation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, onnx_compatible=False):
        super().__init__()

        # Layers are created in the same order they run in, so the positions
        # inside the Sequential (and therefore the parameter names) are fixed.
        depthwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
        )
        batch_norm = nn.BatchNorm2d(in_channels)

        # NOTE(review): presumably plain ReLU is chosen because ReLU6 was not
        # exportable by the targeted ONNX tooling -- confirm with the caller.
        if onnx_compatible:
            activation = nn.ReLU()
        else:
            activation = nn.ReLU6()

        pointwise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

        self.conv = nn.Sequential(depthwise, batch_norm, activation, pointwise)

    def forward(self, x):
        """Run ``x`` through the depthwise, norm, activation and pointwise layers."""
        result = self.conv(x)
        return result