import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class GraphAttentionLayer(nn.Module):
    """Graph attention layer.

    Derives a pairwise node-attention map, aggregates node features with it,
    and adds an attention-free projection of the raw features.

    Shapes: (#bs, #node, in_dim) -> (#bs, #node, out_dim).
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)
        # projection (with / without attention)
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)
        # batch norm, applied over flattened (bs * node, out_dim)
        self.bn = nn.BatchNorm1d(out_dim)
        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)
        # activation
        self.act = nn.SELU(inplace=True)

    def forward(self, x):
        """
        x :(#bs, #node, #dim) -> (#bs, #node, out_dim)
        """
        # apply input dropout
        x = self.input_drop(x)
        # derive attention map: (#bs, #node, #node, 1)
        att_map = self._derive_att_map(x)
        # attention-weighted projection
        x = self._project(x, att_map)
        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)
        return x

    def _pairwise_mul_nodes(self, x):
        """Element-wise product of every pair of nodes (for the attention map).

        x :(#bs, #node, #dim) -> (#bs, #node, #node, #dim)
        """
        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)
        return x * x_mirror

    def _derive_att_map(self, x):
        """Attention map over node pairs, softmax-normalized along dim -2.

        x :(#bs, #node, #dim) -> (#bs, #node, #node, 1)
        """
        att_map = self._pairwise_mul_nodes(x)
        # (#bs, #node, #node, out_dim)
        att_map = torch.tanh(self.att_proj(att_map))
        # collapse feature dim with a learned weight: (#bs, #node, #node, 1)
        att_map = torch.matmul(att_map, self.att_weight)
        att_map = F.softmax(att_map, dim=-2)
        return att_map

    def _project(self, x, att_map):
        # attention-weighted aggregation of neighbours + plain projection
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)
        return x1 + x2

    def _apply_BN(self, x):
        # BatchNorm1d expects (N, C): flatten batch/node dims, then restore
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)
        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class GraphPool(nn.Module):
    """Attention-based top-k graph pooling.

    Scores each node with a linear projection + sigmoid and keeps the
    highest-scoring fraction ``k`` of nodes (at least 2).
    """

    def __init__(self, k, in_dim, p):
        super().__init__()
        self.k = k  # ratio (0, 1] of nodes to keep
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        Z = self.drop(h)
        weights = self.proj(Z)
        scores = self.sigmoid(weights)
        new_h = self.top_k_graph(scores, h, self.k)
        return new_h

    def top_k_graph(self, scores, h, k):
        """
        args
        =====
        scores: attention-based weights (#bs, #node, 1)
        h: graph data (#bs, #node, #dim)
        k: ratio of remaining nodes, (float)

        returns
        =====
        h: graph pool applied data (#bs, #node', #dim)
        """
        # keep at least 2 nodes regardless of the ratio
        n_nodes = max(int(h.size(1) * k), 2)
        n_feat = h.size(2)
        _, idx = torch.topk(scores, n_nodes, dim=1)
        idx = idx.expand(-1, -1, n_feat)
        # gate features by their scores before gathering the kept nodes
        h = h * scores
        h = torch.gather(h, 1, idx)
        return h


class CONV(nn.Module):
    """Fixed (non-learnable) sinc band-pass filterbank convolution with
    mel-spaced cut-off frequencies (SincNet-style front end)."""

    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, out_channels, kernel_size, sample_rate=16000,
                 in_channels=1, stride=1, padding=0, dilation=1,
                 bias=False, groups=1, mask=False):
        super().__init__()
        if in_channels != 1:
            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        # mel-spaced band edges over [0, sample_rate / 2]
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.mel = filbandwidthsf  # band edges in Hz (out_channels + 1 values)
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            # ideal band-pass = difference of two sinc low-pass filters
            hHigh = (2 * fmax / self.sample_rate) * \
                np.sinc(2 * fmax * self.hsupp / self.sample_rate)
            hLow = (2 * fmin / self.sample_rate) * \
                np.sinc(2 * fmin * self.hsupp / self.sample_rate)
            hideal = hHigh - hLow
            # Hamming window to reduce spectral leakage
            self.band_pass[i, :] = Tensor(np.hamming(
                self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        """Apply the sinc filterbank; if ``mask``, zero a random contiguous
        band of filters (frequency-masking augmentation)."""
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            A = int(np.random.uniform(0, 20))
            A0 = random.randint(0, band_pass_filter.shape[0] - A)
            band_pass_filter[A0:A0 + A, :] = 0

        self.filters = (band_pass_filter).view(self.out_channels, 1,
                                               self.kernel_size)
        return F.conv1d(x, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        bias=None, groups=1)


class Residual_block(nn.Module):
    """2-D convolutional residual block with SELU pre-activation and a
    (1, 3) max-pool at the output; ``first=True`` skips the first bn/act."""

    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first
        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)
        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)
        if nb_filts[0] != nb_filts[1]:
            # channel counts differ: project the identity branch
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)
        else:
            self.downsample = False
        self.mp = nn.MaxPool2d((1, 3))

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)
        else:
            out = x
        # BUGFIX: was `self.conv1(x)`, which silently discarded the
        # bn1 + SELU pre-activation computed into `out` above.
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.selu(out)
        out = self.conv2(out)
        if self.downsample:
            identity = self.conv_downsample(identity)
        out += identity
        out = self.mp(out)
        return out


class Model(nn.Module):
    """Spectro-temporal graph attention network over a sinc-conv front end.

    Two residual encoders produce temporal (T) and spectral (S) node
    sequences; each passes through graph attention + pooling, the two are
    fused multiplicatively, attended again (ST), and projected to 2 logits.
    Returns ``(embedding, logits)``.
    """

    def __init__(self, d_args):
        super().__init__()
        self.d_args = d_args
        filts = d_args["filts"]
        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=d_args["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.selu = nn.SELU(inplace=True)
        self.encoder_T = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))
        self.encoder_S = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))
        self.GAT_layer_T = GraphAttentionLayer(64, 32)
        self.GAT_layer_S = GraphAttentionLayer(64, 32)
        self.GAT_layer_ST = GraphAttentionLayer(32, 16)
        self.pool_T = GraphPool(0.64, 32, 0.3)
        self.pool_S = GraphPool(0.81, 32, 0.3)
        self.pool_ST = GraphPool(0.64, 16, 0.3)
        # NOTE(review): the projection in/out sizes (14, 23, 12, 7) are tied
        # to the pooled node counts for the expected input length — confirm
        # against the training config before changing input duration.
        self.proj_T = nn.Linear(14, 12)
        self.proj_S = nn.Linear(23, 12)
        self.proj_ST = nn.Linear(16, 1)
        self.out_layer = nn.Linear(7, 2)

    def forward(self, x, Freq_aug=False):
        """
        x :(#bs, #samples) raw waveform; Freq_aug enables filter masking.
        returns (embedding (#bs, 7), logits (#bs, 2))
        """
        nb_samp1 = x.shape[0]
        len_seq = x.shape[1]
        x = x.view(nb_samp1, 1, len_seq)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # temporal branch
        e_T = self.encoder_T(x)  # (#bs, #filt, #spec, #seq)
        e_T, _ = torch.max(torch.abs(e_T), dim=3)  # max along time
        gat_T = self.GAT_layer_T(e_T.transpose(1, 2))
        pool_T = self.pool_T(gat_T)  # (#bs, #node, #dim)
        out_T = self.proj_T(pool_T.transpose(1, 2))

        # spectral branch
        e_S = self.encoder_S(x)
        e_S, _ = torch.max(torch.abs(e_S), dim=2)  # max along freq
        gat_S = self.GAT_layer_S(e_S.transpose(1, 2))
        pool_S = self.pool_S(gat_S)
        out_S = self.proj_S(pool_S.transpose(1, 2))

        # spectro-temporal fusion
        gat_ST = torch.mul(out_T, out_S)
        gat_ST = self.GAT_layer_ST(gat_ST.transpose(1, 2))
        pool_ST = self.pool_ST(gat_ST)
        proj_ST = self.proj_ST(pool_ST).flatten(1)
        output = self.out_layer(proj_ST)
        return proj_ST, output