import torch.nn as nn

from ..functions.masked_conv import masked_conv2d


class MaskedConv2d(nn.Conv2d):
    """A MaskedConv2d which inherits the official Conv2d.

    The masked forward doesn't implement the backward function and only
    supports the stride parameter to be 1 currently.

    Without a mask, ``forward`` falls back to :class:`torch.nn.Conv2d`
    unchanged. With a mask, the computation is delegated to
    ``masked_conv2d``, which is handed only the input, mask, weight, bias
    and padding. Stride and dilation are not forwarded to it, so non-unit
    values are rejected on the masked path instead of being silently
    ignored.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super(MaskedConv2d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride, padding,
                                           dilation, groups, bias)

    def forward(self, input, mask=None):
        """Apply the convolution, optionally restricted by ``mask``.

        Args:
            input (Tensor): Input feature map. Its expected shape is
                whatever ``nn.Conv2d`` / ``masked_conv2d`` accept.
                NOTE(review): ``masked_conv2d`` may impose extra
                constraints (e.g. batch size); confirm in its source.
            mask (Tensor, optional): Mask passed through to
                ``masked_conv2d``. If ``None``, a plain ``nn.Conv2d``
                forward is performed.

        Returns:
            Tensor: The convolution output.

        Raises:
            ValueError: If ``mask`` is given and the module was configured
                with a stride or dilation other than 1.
        """
        if mask is None:
            # Fallback to the normal Conv2d, which honors every setting.
            return super(MaskedConv2d, self).forward(input)

        # masked_conv2d receives no stride/dilation argument, so it cannot
        # honor them; fail loudly rather than return output computed as if
        # stride and dilation were 1 (which would disagree with the
        # unmasked path above).
        if tuple(self.stride) != (1, 1):
            raise ValueError(
                'MaskedConv2d only supports stride 1 when a mask is given, '
                'got stride={}'.format(self.stride))
        if tuple(self.dilation) != (1, 1):
            raise ValueError(
                'MaskedConv2d only supports dilation 1 when a mask is given, '
                'got dilation={}'.format(self.dilation))
        # NOTE(review): `groups` is not forwarded either; confirm that
        # masked_conv2d handles groups != 1 (or rejects it) correctly.
        return masked_conv2d(input, mask, self.weight, self.bias,
                             self.padding)