import sys

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

sys.path.append('../')
from util import sigma2alpha


class MFNBase(nn.Module):
    """
    Multiplicative filter network base class.

    Expects the child class to define the 'filters' attribute, which should be
    a nn.ModuleList of n_layers+1 filters with output equal to hidden_size.
    """

    def __init__(
        self, hidden_size, out_size, n_layers, weight_scale, bias=True, output_act=False
    ):
        super().__init__()

        self.linear = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size, bias) for _ in range(n_layers)]
        )
        self.output_linear = nn.Linear(hidden_size, out_size)
        self.output_act = output_act

        for lin in self.linear:
            lin.weight.data.uniform_(
                -np.sqrt(weight_scale / hidden_size),
                np.sqrt(weight_scale / hidden_size),
            )

    def forward(self, x):
        # MFN recurrence: z_1 = g_1(x); z_i = g_i(x) * (W_{i-1} z_{i-1} + b_{i-1}),
        # followed by a final linear readout.
        out = self.filters[0](x)
        for i in range(1, len(self.filters)):
            out = self.filters[i](x) * self.linear[i - 1](out)
        out = self.output_linear(out)

        if self.output_act:
            out = torch.sin(out)

        return out


class FourierLayer(nn.Module):
    """
    Sine filter as used in FourierNet.
    """

    def __init__(self, in_features, out_features, weight_scale):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Scale the frequencies and draw phases uniformly from [-pi, pi).
        self.linear.weight.data *= weight_scale  # gamma
        self.linear.bias.data.uniform_(-np.pi, np.pi)

    def forward(self, x):
        return torch.sin(self.linear(x))


class FourierNet(MFNBase):
    def __init__(
        self,
        in_size,
        hidden_size,
        out_size,
        n_layers=3,
        input_scale=256.0,
        weight_scale=1.0,
        bias=True,
        output_act=False,
    ):
        super().__init__(
            hidden_size, out_size, n_layers, weight_scale, bias, output_act
        )
        self.filters = nn.ModuleList(
            [
                FourierLayer(in_size, hidden_size, input_scale / np.sqrt(n_layers + 1))
                for _ in range(n_layers + 1)
            ]
        )


class GaborLayer(nn.Module):
    """
    Gabor-like filter as used in GaborNet.
    """

    def __init__(self, in_features, out_features, weight_scale, alpha=1.0, beta=1.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Per-filter center mu in [-1, 1]^in_features and inverse-width gamma
        # sampled from a Gamma(alpha, beta) distribution.
        self.mu = nn.Parameter(2 * torch.rand(out_features, in_features) - 1)
        self.gamma = nn.Parameter(
            torch.distributions.gamma.Gamma(alpha, beta).sample((out_features,))
        )
        self.linear.weight.data *= weight_scale * torch.sqrt(self.gamma[:, None])
        self.linear.bias.data.uniform_(-np.pi, np.pi)

    def forward(self, x):
        # Squared distance ||x - mu||^2, expanded so x - mu is never materialized.
        D = (
            (x ** 2).sum(-1)[..., None]
            + (self.mu ** 2).sum(-1)[None, :]
            - 2 * x @ self.mu.T
        )
        # Sinusoid modulated by a Gaussian envelope centered at mu.
        return torch.sin(self.linear(x)) * torch.exp(-0.5 * D * self.gamma[None, :])


class GaborNet(MFNBase):
    def __init__(
        self,
        in_size,
        hidden_size,
        out_size,
        n_layers=3,
        input_scale=256.0,
        weight_scale=1.0,
        alpha=6.0,
        beta=1.0,
        bias=True,
        output_act=False,
    ):
        super().__init__(
            hidden_size, out_size, n_layers, weight_scale, bias, output_act
        )
        self.filters = nn.ModuleList(
            [
                GaborLayer(
                    in_size,
                    hidden_size,
                    input_scale / np.sqrt(n_layers + 1),
                    alpha / (n_layers + 1),
                    beta,
                )
                for _ in range(n_layers + 1)
            ]
        )

    def gradient(self, x):
        # Only for the color MLP: differentiate the alpha value derived from
        # the last output channel with respect to the input coordinates.
        x.requires_grad_(True)
        y = self.forward(x)[..., -1:]
        y = F.softplus(y - 1.)
        y = sigma2alpha(y)
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)
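

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API): the
# input dimensionality, hidden size, and output size below are assumptions
# chosen for demonstration. It exercises a forward pass of both networks and
# the gradient() helper of GaborNet, which additionally relies on the
# project-local util.sigma2alpha imported above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical batch of 2D coordinates in [-1, 1]^2.
    coords = 2 * torch.rand(1024, 2) - 1

    fourier = FourierNet(in_size=2, hidden_size=256, out_size=3)
    print(fourier(coords).shape)  # torch.Size([1024, 3])

    # out_size=4 mimics an RGB + density head; the last channel feeds gradient().
    gabor = GaborNet(in_size=2, hidden_size=256, out_size=4)
    print(gabor(coords).shape)  # torch.Size([1024, 4])

    # Gradient of the derived alpha w.r.t. the input coordinates.
    grads = gabor.gradient(coords)
    print(grads.shape)  # torch.Size([1024, 1, 2])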