import torch.nn as nn
from typing import Dict
from unicore.modules import LayerNorm
from .common import Linear
from .confidence import predicted_lddt, predicted_tm_score, predicted_aligned_error


class AuxiliaryHeads(nn.Module):
    """Collection of auxiliary prediction heads applied to the trunk outputs.

    The pLDDT, distogram and masked-MSA heads are always built. The
    experimentally-resolved and PAE heads are built only when their config
    section has ``enabled`` set.

    Args:
        config: Head configuration. Each sub-section (``plddt``,
            ``distogram``, ``masked_msa``, ``experimentally_resolved``,
            ``pae``) is splatted as keyword arguments into the matching head
            constructor. It is accessed both by item (``config["plddt"]``)
            and by attribute (``config.pae.enabled``), so it must support
            both. Presumably an ``ml_collections.ConfigDict`` — TODO confirm.
    """

    def __init__(self, config):
        super().__init__()

        # Heads are constructed in a fixed order. The attribute names become
        # the state_dict key prefixes, so they must stay unchanged.
        self.plddt = PredictedLDDTHead(**config["plddt"])
        self.distogram = DistogramHead(**config["distogram"])
        self.masked_msa = MaskedMSAHead(**config["masked_msa"])

        if config.experimentally_resolved.enabled:
            self.experimentally_resolved = ExperimentallyResolvedHead(
                **config["experimentally_resolved"]
            )

        if config.pae.enabled:
            self.pae = PredictedAlignedErrorHead(**config.pae)

        self.config = config

    def forward(self, outputs):
        """Run every enabled head and collect logits plus confidence metrics.

        Args:
            outputs: Dict of trunk outputs. The following entries are read:
                ``outputs["sm"]["single"]`` for the pLDDT head;
                ``outputs["pair"]`` for the distogram and PAE heads;
                ``outputs["msa"]`` for the masked-MSA head;
                ``outputs["single"]`` for the experimentally-resolved head,
                only when it is enabled;
                ``outputs["asym_id"]``, only when PAE is enabled and
                ``iptm_weight > 0``.

        Returns:
            Dict with ``plddt_logits``, ``plddt``, ``distogram_logits`` and
            ``masked_msa_logits``. It also contains
            ``experimentally_resolved_logits`` and the PAE-derived entries
            when those heads are enabled.
        """
        results = {}

        lddt_logits = self.plddt(outputs["sm"]["single"])
        results["plddt_logits"] = lddt_logits
        # The confidence score is computed on detached logits, so no gradient
        # flows through it.
        results["plddt"] = predicted_lddt(lddt_logits.detach())

        results["distogram_logits"] = self.distogram(outputs["pair"])
        results["masked_msa_logits"] = self.masked_msa(outputs["msa"])

        if self.config.experimentally_resolved.enabled:
            results["experimentally_resolved_logits"] = self.experimentally_resolved(
                outputs["single"]
            )

        if self.config.pae.enabled:
            self._add_pae_outputs(results, outputs)

        return results

    def _add_pae_outputs(self, results, outputs):
        """Run the PAE head and write its logits and derived scores into *results*.

        ``results["pae_logits"]`` keeps the raw (non-detached) logits. All
        derived metrics are computed from a detached copy. They are:
        the entries returned by ``predicted_aligned_error``, ``ptm``, and,
        when ``iptm_weight`` (default 0.0) is positive, ``iptm`` and the
        blended ``iptm+ptm``.
        """
        pae_cfg = self.config.pae

        raw_logits = self.pae(outputs["pair"])
        results["pae_logits"] = raw_logits
        frozen = raw_logits.detach()

        results.update(predicted_aligned_error(frozen, **pae_cfg))
        results["ptm"] = predicted_tm_score(frozen, interface=False, **pae_cfg)

        weight = pae_cfg.get("iptm_weight", 0.0)
        if not weight > 0.0:
            return

        # asym_id is only looked up here, so it is required only for the
        # interface score.
        results["iptm"] = predicted_tm_score(
            frozen,
            interface=True,
            asym_id=outputs["asym_id"],
            **pae_cfg,
        )
        # Convex blend of interface and global TM-score estimates.
        results["iptm+ptm"] = (
            weight * results["iptm"] + (1.0 - weight) * results["ptm"]
        )


class PredictedLDDTHead(nn.Module):
    """pLDDT logits head: LayerNorm -> (Linear -> GELU) x2 -> Linear.

    Args:
        num_bins: Number of output logits per position.
        d_in: Feature size of the input.
        d_hid: Hidden width of the two intermediate linear layers.
    """

    def __init__(self, num_bins, d_in, d_hid):
        super().__init__()

        self.num_bins, self.d_in, self.d_hid = num_bins, d_in, d_hid

        self.layer_norm = LayerNorm(d_in)

        # `init` selects the initialisation scheme provided by `Linear`
        # (defined in .common); its exact semantics are not visible here.
        self.linear_1 = Linear(d_in, d_hid, init="relu")
        self.linear_2 = Linear(d_hid, d_hid, init="relu")
        self.act = nn.GELU()
        self.linear_3 = Linear(d_hid, num_bins, init="final")

    def forward(self, s):
        hidden = self.layer_norm(s)
        for layer in (self.linear_1, self.linear_2):
            hidden = self.act(layer(hidden))
        return self.linear_3(hidden)


class EnhancedHeadBase(nn.Module):
    """Linear output head with an optional pre-projection block.

    When ``disable_enhance_head`` is falsy, the input first goes through
    LayerNorm -> Linear(d_in, d_in) -> GELU, and then through the final
    Linear(d_in, d_out). Otherwise only the final linear layer is applied.

    Args:
        d_in: Feature size of the input.
        d_out: Number of output logits.
        disable_enhance_head: If truthy, skip the pre-projection block.
    """

    def __init__(self, d_in, d_out, disable_enhance_head):
        super().__init__()
        use_enhance = not disable_enhance_head
        # The pre-projection block is stored as None when disabled; forward()
        # checks `layer_norm` to decide whether to apply it.
        self.layer_norm = LayerNorm(d_in) if use_enhance else None
        self.linear_in = Linear(d_in, d_in, init="relu") if use_enhance else None
        self.act = nn.GELU()
        self.linear = Linear(d_in, d_out, init="final")

    def apply_alphafold_original_mode(self):
        """Drop the pre-projection block so forward() uses only the final linear layer.

        The name suggests this reproduces the original AlphaFold single-linear
        heads — TODO confirm. The `layer_norm` / `linear_in` parameters are
        discarded.
        """
        self.layer_norm = None
        self.linear_in = None

    def forward(self, x):
        if self.layer_norm is None:
            return self.linear(x)
        projected = self.act(self.linear_in(self.layer_norm(x)))
        return self.linear(projected)


class DistogramHead(EnhancedHeadBase):
    """Distogram logits head whose output is symmetrised over dims -2 and -3.

    Assumes x is a pair tensor of shape (..., N_res, N_res, d_pair), so that
    dims -3 and -2 are the two residue axes — TODO confirm. Extra keyword
    arguments are accepted and ignored.
    """

    def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs):
        super().__init__(
            d_in=d_pair,
            d_out=num_bins,
            disable_enhance_head=disable_enhance_head,
        )

    def forward(self, x):
        raw = super().forward(x)
        # Adding the transpose makes logits[..., i, j, :] == logits[..., j, i, :].
        return raw + raw.transpose(-2, -3)


class PredictedAlignedErrorHead(EnhancedHeadBase):
    """PAE logits head over the pair representation.

    Extra config keys (e.g. ``enabled``, ``iptm_weight``) are accepted and
    ignored.
    """

    def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs):
        super().__init__(
            d_in=d_pair,
            d_out=num_bins,
            disable_enhance_head=disable_enhance_head,
        )


class MaskedMSAHead(EnhancedHeadBase):
    """Masked-MSA logits head over the MSA representation.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, d_msa, d_out, disable_enhance_head, **kwargs):
        super().__init__(
            d_in=d_msa,
            d_out=d_out,
            disable_enhance_head=disable_enhance_head,
        )


class ExperimentallyResolvedHead(EnhancedHeadBase):
    """Experimentally-resolved logits head over the single representation.

    Extra config keys (e.g. ``enabled``) are accepted and ignored.
    """

    def __init__(self, d_single, d_out, disable_enhance_head, **kwargs):
        super().__init__(
            d_in=d_single,
            d_out=d_out,
            disable_enhance_head=disable_enhance_head,
        )