"""3D convolution building blocks for video networks.

Every builder exposes the same constructor shape
``(in_planes, out_planes, midplanes, stride, padding)`` and a static
``get_downsample_stride`` helper, so the classes can be swapped for one
another by code that only relies on that shared interface.
"""

import torch.nn as nn


__all__ = ["Conv3DSimple", "Conv2Plus1D", "Conv3DNoTemporal"]


class Conv3DSimple(nn.Conv3d):
    """Bias-free ``3 x 3 x 3`` convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        midplanes: Accepted for signature parity with ``Conv2Plus1D``;
            not used by this class.
        stride: Stride forwarded unchanged to ``nn.Conv3d``.
        padding: Padding forwarded unchanged to ``nn.Conv3d``.
    """

    def __init__(self, in_planes, out_planes, midplanes=None, stride=1,
                 padding=1):
        super(Conv3DSimple, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride):
        """Return ``(stride, stride, stride)``: the same stride on all axes."""
        return (stride, stride, stride)


class Conv2Plus1D(nn.Sequential):
    """Factorised convolution: a ``1 x 3 x 3`` conv, then a ``3 x 1 x 1`` conv.

    The two bias-free convolutions are separated by ``BatchNorm3d`` and an
    in-place ``ReLU``. The first convolution strides and pads only the last
    two axes; the second strides and pads only the first kernel axis.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        midplanes: Channel count between the two convolutions (required).
        stride: Stride applied by the first conv on the last two axes and
            by the second conv on the first axis.
        padding: Padding applied by the first conv on the last two axes and
            by the second conv on the first axis.
    """

    def __init__(self, in_planes, out_planes, midplanes, stride=1, padding=1):
        layers = [
            # Kernel size 1 on the first axis: mixes the last two axes only.
            nn.Conv3d(
                in_planes,
                midplanes,
                kernel_size=(1, 3, 3),
                stride=(1, stride, stride),
                padding=(0, padding, padding),
                bias=False,
            ),
            nn.BatchNorm3d(midplanes),
            nn.ReLU(inplace=True),
            # Kernel size 1 on the last two axes: mixes the first axis only.
            nn.Conv3d(
                midplanes,
                out_planes,
                kernel_size=(3, 1, 1),
                stride=(stride, 1, 1),
                padding=(padding, 0, 0),
                bias=False,
            ),
        ]
        super(Conv2Plus1D, self).__init__(*layers)

    @staticmethod
    def get_downsample_stride(stride):
        """Return ``(stride, stride, stride)``: the same stride on all axes."""
        return (stride, stride, stride)


class Conv3DNoTemporal(nn.Conv3d):
    """Bias-free ``1 x 3 x 3`` convolution that never strides the first axis.

    The first kernel axis (the temporal one, going by the class name) uses
    kernel size 1, stride 1 and padding 0.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        midplanes: Accepted for signature parity with ``Conv2Plus1D``;
            not used by this class.
        stride: Stride applied to the last two axes.
        padding: Padding applied to the last two axes.
    """

    def __init__(self, in_planes, out_planes, midplanes=None, stride=1,
                 padding=1):
        super(Conv3DNoTemporal, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride):
        """Return ``(1, stride, stride)``: the first axis is not downsampled."""
        return (1, stride, stride)