from functools import partialmethod
from typing import Optional, List

import torch
import torch.nn as nn

from .common import Linear, chunk_layer
from unicore.utils import (
    permute_final_dims,
)
from unicore.modules import (
    softmax_dropout,
    LayerNorm,
)


def gen_attn_mask(mask, neg_inf):
    assert neg_inf < -1e4
    attn_mask = torch.zeros_like(mask)
    attn_mask[mask == 0] = neg_inf
    return attn_mask


class Attention(nn.Module):
    def __init__(
        self,
        q_dim: int,
        k_dim: int,
        v_dim: int,
        head_dim: int,
        num_heads: int,
        gating: bool = True,
    ):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        total_dim = head_dim * self.num_heads
        self.gating = gating
        self.linear_q = Linear(q_dim, total_dim, bias=False, init="glorot")
        self.linear_k = Linear(k_dim, total_dim, bias=False, init="glorot")
        self.linear_v = Linear(v_dim, total_dim, bias=False, init="glorot")
        self.linear_o = Linear(total_dim, q_dim, init="final")
        self.linear_g = None
        if self.gating:
            self.linear_g = Linear(q_dim, total_dim, init="gating")
        # precompute the 1/sqrt(head_dim)
        self.norm = head_dim**-0.5

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        g = None
        if self.linear_g is not None:
            # gating, use raw query input
            g = self.linear_g(q)

        q = self.linear_q(q)
        q *= self.norm
        k = self.linear_k(k)
        v = self.linear_v(v)

        q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
        k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
        v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3)

        attn = torch.matmul(q, k.transpose(-1, -2))
        del q, k
        attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias)
        o = torch.matmul(attn, v)
        del attn, v

        o = o.transpose(-2, -3).contiguous()
        o = o.view(*o.shape[:-2], -1)

        if g is not None:
            o = torch.sigmoid(g) * o

        # merge heads
        o = nn.functional.linear(o, self.linear_o.weight)
        return o

    def get_output_bias(self):
        return self.linear_o.bias


class GlobalAttention(nn.Module):
    def __init__(self, input_dim, head_dim, num_heads, inf, eps):
        super(GlobalAttention, self).__init__()
        self.num_heads = num_heads
        self.inf = inf
        self.eps = eps
        self.linear_q = Linear(
            input_dim, head_dim * num_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(input_dim, head_dim, bias=False, init="glorot")
        self.linear_v = Linear(input_dim, head_dim, bias=False, init="glorot")
        self.linear_g = Linear(input_dim, head_dim * num_heads, init="gating")
        self.linear_o = Linear(head_dim * num_heads, input_dim, init="final")
        self.sigmoid = nn.Sigmoid()
        # precompute the 1/sqrt(head_dim)
        self.norm = head_dim**-0.5

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # gating
        g = self.sigmoid(self.linear_g(x))
        k = self.linear_k(x)
        v = self.linear_v(x)
        q = torch.sum(x * mask.unsqueeze(-1), dim=-2) / (
            torch.sum(mask, dim=-1, keepdims=True) + self.eps
        )
        q = self.linear_q(q)
        q *= self.norm
        q = q.view(q.shape[:-1] + (self.num_heads, -1))

        attn = torch.matmul(q, k.transpose(-1, -2))
        del q, k
        attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :]
        attn = softmax_dropout(attn, 0, self.training, mask=attn_mask)
        o = torch.matmul(
            attn,
            v,
        )
        del attn, v

        g = g.view(g.shape[:-1] + (self.num_heads, -1))
        o = o.unsqueeze(-3) * g
        del g
        # merge heads
        o = o.reshape(o.shape[:-2] + (-1,))
        return self.linear_o(o)
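
# --- Illustrative sketch (not part of the original module) --------------------
# A minimal shape check for the gated Attention block above; the batch/sequence
# sizes are arbitrary assumptions. Note that forward() applies linear_o without
# its bias, which is exposed separately through get_output_bias() so callers can
# add it themselves.
def _attention_shape_example():
    attn = Attention(q_dim=64, k_dim=64, v_dim=64, head_dim=16, num_heads=4)
    x = torch.randn(2, 32, 64)  # [batch, seq, q_dim]
    out = attn(q=x, k=x, v=x)  # no mask/bias: plain gated self-attention
    out = out + attn.get_output_bias()  # output bias is added outside forward()
    assert out.shape == (2, 32, 64)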

def gen_msa_attn_mask(mask, inf, gen_col_mask=True):
    row_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :]
    if gen_col_mask:
        col_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, None, :]
        return row_mask, col_mask
    else:
        return row_mask


class MSAAttention(nn.Module):
    def __init__(
        self,
        d_in,
        d_hid,
        num_heads,
        pair_bias=False,
        d_pair=None,
    ):
        super(MSAAttention, self).__init__()
        self.pair_bias = pair_bias
        self.layer_norm_m = LayerNorm(d_in)
        self.layer_norm_z = None
        self.linear_z = None
        if self.pair_bias:
            self.layer_norm_z = LayerNorm(d_pair)
            self.linear_z = Linear(d_pair, num_heads, bias=False, init="normal")
        self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        return chunk_layer(
            self._attn_forward,
            {"m": m, "mask": mask, "bias": bias},
            chunk_size=chunk_size,
            num_batch_dims=len(m.shape[:-2]),
        )

    @torch.jit.ignore
    def _attn_chunk_forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = 2560,
    ) -> torch.Tensor:
        m = self.layer_norm_m(m)
        num_chunk = (m.shape[-3] + chunk_size - 1) // chunk_size
        outputs = []
        for i in range(num_chunk):
            chunk_start = i * chunk_size
            chunk_end = min(m.shape[-3], chunk_start + chunk_size)
            cur_m = m[..., chunk_start:chunk_end, :, :]
            cur_mask = (
                mask[..., chunk_start:chunk_end, :, :, :] if mask is not None else None
            )
            outputs.append(
                self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)
            )
        return torch.concat(outputs, dim=-3)

    def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None):
        m = self.layer_norm_m(m)
        return self.mha(q=m, k=m, v=m, mask=mask, bias=bias)

    def forward(
        self,
        m: torch.Tensor,
        z: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        bias = None
        if self.pair_bias:
            z = self.layer_norm_z(z)
            bias = (
                permute_final_dims(self.linear_z(z), (2, 0, 1))
                .unsqueeze(-4)
                .contiguous()
            )

        if chunk_size is not None:
            m = self._chunk(m, attn_mask, bias, chunk_size)
        else:
            attn_chunk_size = 2560
            if m.shape[-3] <= attn_chunk_size:
                m = self._attn_forward(m, attn_mask, bias)
            else:
                # reduce the peak memory cost in extra_msa_stack
                return self._attn_chunk_forward(
                    m, attn_mask, bias, chunk_size=attn_chunk_size
                )
        return m

    def get_output_bias(self):
        return self.mha.get_output_bias()


class MSARowAttentionWithPairBias(MSAAttention):
    def __init__(self, d_msa, d_pair, d_hid, num_heads):
        super(MSARowAttentionWithPairBias, self).__init__(
            d_msa,
            d_hid,
            num_heads,
            pair_bias=True,
            d_pair=d_pair,
        )


class MSAColumnAttention(MSAAttention):
    def __init__(self, d_msa, d_hid, num_heads):
        super(MSAColumnAttention, self).__init__(
            d_in=d_msa,
            d_hid=d_hid,
            num_heads=num_heads,
            pair_bias=False,
            d_pair=None,
        )

    def forward(
        self,
        m: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        m = m.transpose(-2, -3)
        m = super().forward(m, attn_mask=attn_mask, chunk_size=chunk_size)
        m = m.transpose(-2, -3)
        return m
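
# --- Illustrative sketch (not part of the original module) --------------------
# A hypothetical smoke test for the MSA attention wrappers above, assuming an MSA
# of 8 sequences by 16 residues. gen_msa_attn_mask yields a row mask of shape
# [*, N_seq, 1, 1, N_res] and a column mask of shape [*, N_res, 1, 1, N_seq],
# both of which broadcast against the per-head attention logits.
def _msa_attention_example():
    d_msa, d_pair = 32, 16
    m = torch.randn(1, 8, 16, d_msa)    # [*, N_seq, N_res, d_msa]
    z = torch.randn(1, 16, 16, d_pair)  # [*, N_res, N_res, d_pair]
    msa_mask = torch.ones(1, 8, 16)     # [*, N_seq, N_res]
    row_mask, col_mask = gen_msa_attn_mask(msa_mask, inf=1e9)

    row_attn = MSARowAttentionWithPairBias(d_msa, d_pair, d_hid=8, num_heads=4)
    col_attn = MSAColumnAttention(d_msa, d_hid=8, num_heads=4)

    # residual updates; the deferred output bias is added explicitly here
    m = m + row_attn(m, z=z, attn_mask=row_mask) + row_attn.get_output_bias()
    m = m + col_attn(m, attn_mask=col_mask) + col_attn.get_output_bias()
    assert m.shape == (1, 8, 16, d_msa)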

class MSAColumnGlobalAttention(nn.Module):
    def __init__(
        self,
        d_in,
        d_hid,
        num_heads,
        inf=1e9,
        eps=1e-10,
    ):
        super(MSAColumnGlobalAttention, self).__init__()
        self.layer_norm_m = LayerNorm(d_in)
        self.global_attention = GlobalAttention(
            d_in,
            d_hid,
            num_heads,
            inf=inf,
            eps=eps,
        )

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self._attn_forward,
            {"m": m, "mask": mask},
            chunk_size=chunk_size,
            num_batch_dims=len(m.shape[:-2]),
        )

    def _attn_forward(self, m, mask):
        m = self.layer_norm_m(m)
        return self.global_attention(m, mask=mask)

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        m = m.transpose(-2, -3)
        mask = mask.transpose(-1, -2)

        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size)
        else:
            m = self._attn_forward(m, mask=mask)

        m = m.transpose(-2, -3)
        return m


def gen_tri_attn_mask(mask, inf):
    start_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :]
    end_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, None, :]
    return start_mask, end_mask


class TriangleAttention(nn.Module):
    def __init__(
        self,
        d_in,
        d_hid,
        num_heads,
        starting,
    ):
        super(TriangleAttention, self).__init__()
        self.starting = starting
        self.layer_norm = LayerNorm(d_in)
        self.linear = Linear(d_in, num_heads, bias=False, init="normal")
        self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        return chunk_layer(
            self.mha,
            {"q": x, "k": x, "v": x, "mask": mask, "bias": bias},
            chunk_size=chunk_size,
            num_batch_dims=len(x.shape[:-2]),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        if not self.starting:
            x = x.transpose(-2, -3)

        x = self.layer_norm(x)
        triangle_bias = (
            permute_final_dims(self.linear(x), (2, 0, 1)).unsqueeze(-4).contiguous()
        )

        if chunk_size is not None:
            x = self._chunk(x, attn_mask, triangle_bias, chunk_size)
        else:
            x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias)

        if not self.starting:
            x = x.transpose(-2, -3)
        return x

    def get_output_bias(self):
        return self.mha.get_output_bias()


class TriangleAttentionStarting(TriangleAttention):
    __init__ = partialmethod(TriangleAttention.__init__, starting=True)


class TriangleAttentionEnding(TriangleAttention):
    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
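
# --- Illustrative sketch (not part of the original module) --------------------
# A hypothetical usage sketch for triangle attention on a pair representation of
# 16 residues (sizes are assumptions). gen_tri_attn_mask returns the start/end
# masks consumed by the starting and ending variants, respectively.
def _triangle_attention_example():
    d_pair = 32
    z = torch.randn(1, 16, 16, d_pair)  # [*, N_res, N_res, d_pair]
    pair_mask = torch.ones(1, 16, 16)   # [*, N_res, N_res]
    start_mask, end_mask = gen_tri_attn_mask(pair_mask, inf=1e9)

    tri_start = TriangleAttentionStarting(d_pair, d_hid=8, num_heads=4)
    tri_end = TriangleAttentionEnding(d_pair, d_hid=8, num_heads=4)

    # residual updates; the deferred output bias is added explicitly here
    z = z + tri_start(z, attn_mask=start_mask) + tri_start.get_output_bias()
    z = z + tri_end(z, attn_mask=end_mask) + tri_end.get_output_bias()
    assert z.shape == (1, 16, 16, d_pair)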