# auxillary_heads.py
from typing import Dict

import torch.nn as nn
from unicore.modules import LayerNorm

from .common import Linear
from .confidence import predicted_lddt, predicted_tm_score, predicted_aligned_error


class AuxiliaryHeads(nn.Module):
    """Collects the auxiliary prediction heads (pLDDT, distogram, masked MSA,
    and, when enabled, experimentally-resolved and predicted-aligned-error)
    and derives confidence metrics from their logits."""

    def __init__(self, config):
        super().__init__()

        self.plddt = PredictedLDDTHead(
            **config["plddt"],
        )

        self.distogram = DistogramHead(
            **config["distogram"],
        )

        self.masked_msa = MaskedMSAHead(
            **config["masked_msa"],
        )

        if config.experimentally_resolved.enabled:
            self.experimentally_resolved = ExperimentallyResolvedHead(
                **config["experimentally_resolved"],
            )

        if config.pae.enabled:
            self.pae = PredictedAlignedErrorHead(
                **config.pae,
            )

        self.config = config

    def forward(self, outputs: Dict) -> Dict:
        aux_out = {}
        plddt_logits = self.plddt(outputs["sm"]["single"])
        aux_out["plddt_logits"] = plddt_logits

        aux_out["plddt"] = predicted_lddt(plddt_logits.detach())

        distogram_logits = self.distogram(outputs["pair"])
        aux_out["distogram_logits"] = distogram_logits

        masked_msa_logits = self.masked_msa(outputs["msa"])
        aux_out["masked_msa_logits"] = masked_msa_logits

        if self.config.experimentally_resolved.enabled:
            exp_res_logits = self.experimentally_resolved(outputs["single"])
            aux_out["experimentally_resolved_logits"] = exp_res_logits

        if self.config.pae.enabled:
            pae_logits = self.pae(outputs["pair"])
            aux_out["pae_logits"] = pae_logits
            # Confidence metrics are computed on detached logits so they add
            # no gradient paths.
            pae_logits = pae_logits.detach()
            aux_out.update(
                predicted_aligned_error(
                    pae_logits,
                    **self.config.pae,
                )
            )
            aux_out["ptm"] = predicted_tm_score(
                pae_logits, interface=False, **self.config.pae
            )

            iptm_weight = self.config.pae.get("iptm_weight", 0.0)
            if iptm_weight > 0.0:
                aux_out["iptm"] = predicted_tm_score(
                    pae_logits,
                    interface=True,
                    asym_id=outputs["asym_id"],
                    **self.config.pae,
                )
                aux_out["iptm+ptm"] = (
                    iptm_weight * aux_out["iptm"] + (1.0 - iptm_weight) * aux_out["ptm"]
                )

        return aux_out
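

# A minimal usage sketch, not part of the original module. The dimensions, bin
# counts, and the ml_collections.ConfigDict below are illustrative assumptions:
# AuxiliaryHeads only needs a config that supports both mapping access
# (config["plddt"]) and attribute access (config.pae.enabled).
def _example_auxiliary_heads_usage():
    import torch
    from ml_collections import ConfigDict  # assumed dependency; any dual-access config works

    config = ConfigDict(
        {
            "plddt": {"num_bins": 50, "d_in": 384, "d_hid": 128},
            "distogram": {"d_pair": 128, "num_bins": 64, "disable_enhance_head": False},
            "masked_msa": {"d_msa": 256, "d_out": 23, "disable_enhance_head": False},
            "experimentally_resolved": {"enabled": False},
            "pae": {"enabled": False},
        }
    )
    heads = AuxiliaryHeads(config)
    outputs = {
        "sm": {"single": torch.randn(1, 16, 384)},  # structure-module single rep
        "pair": torch.randn(1, 16, 16, 128),  # pair representation
        "msa": torch.randn(1, 8, 16, 256),  # MSA representation
        "single": torch.randn(1, 16, 384),  # single rep (unused while disabled)
    }
    aux = heads(outputs)
    # Expected keys here: plddt_logits, plddt, distogram_logits, masked_msa_logits.
    return aux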


class PredictedLDDTHead(nn.Module):
    """Predicts per-residue lDDT bin logits from the structure module's single
    representation; predicted_lddt turns these into pLDDT scores."""

    def __init__(self, num_bins, d_in, d_hid):
        super().__init__()

        self.num_bins = num_bins
        self.d_in = d_in
        self.d_hid = d_hid

        self.layer_norm = LayerNorm(self.d_in)

        self.linear_1 = Linear(self.d_in, self.d_hid, init="relu")
        self.linear_2 = Linear(self.d_hid, self.d_hid, init="relu")
        self.act = nn.GELU()
        self.linear_3 = Linear(self.d_hid, self.num_bins, init="final")

    def forward(self, s):
        s = self.layer_norm(s)
        s = self.linear_1(s)
        s = self.act(s)
        s = self.linear_2(s)
        s = self.act(s)
        s = self.linear_3(s)
        return s
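

# A hedged shape sketch, not from the original source; the sizes below are
# illustrative, not canonical.
def _example_plddt_head():
    import torch

    head = PredictedLDDTHead(num_bins=50, d_in=384, d_hid=128)
    s = torch.randn(2, 100, 384)  # [batch, num_res, d_in] single representation
    logits = head(s)  # [batch, num_res, num_bins]
    assert logits.shape == (2, 100, 50)
    return logits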


class EnhancedHeadBase(nn.Module):
    """Shared base for the simple logit heads. By default it applies
    LayerNorm -> Linear -> GELU before the final projection; with
    disable_enhance_head=True (or after apply_alphafold_original_mode) only
    the final projection remains, matching the original AlphaFold heads."""

    def __init__(self, d_in, d_out, disable_enhance_head):
        super().__init__()
        if disable_enhance_head:
            self.layer_norm = None
            self.linear_in = None
        else:
            self.layer_norm = LayerNorm(d_in)
            self.linear_in = Linear(d_in, d_in, init="relu")
        self.act = nn.GELU()
        self.linear = Linear(d_in, d_out, init="final")

    def apply_alphafold_original_mode(self):
        self.layer_norm = None
        self.linear_in = None

    def forward(self, x):
        if self.layer_norm is not None:
            x = self.layer_norm(x)
            x = self.act(self.linear_in(x))
        logits = self.linear(x)
        return logits
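

# A hedged sketch with illustrative sizes: disable_enhance_head=True at
# construction and apply_alphafold_original_mode() afterwards both reduce the
# head to its final projection.
def _example_enhanced_head_base():
    import torch

    enhanced = EnhancedHeadBase(d_in=64, d_out=10, disable_enhance_head=False)
    plain = EnhancedHeadBase(d_in=64, d_out=10, disable_enhance_head=True)
    x = torch.randn(4, 64)
    assert enhanced(x).shape == plain(x).shape == (4, 10)
    enhanced.apply_alphafold_original_mode()
    assert enhanced.layer_norm is None and enhanced.linear_in is None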


class DistogramHead(EnhancedHeadBase):
    """Predicts inter-residue distance bin logits from the pair
    representation; the output is symmetrized over the residue pair."""

    def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs):
        super().__init__(
            d_in=d_pair,
            d_out=num_bins,
            disable_enhance_head=disable_enhance_head,
        )

    def forward(self, x):
        logits = super().forward(x)
        # Symmetrize: the predicted distance distribution for (i, j) must
        # equal that for (j, i).
        logits = logits + logits.transpose(-2, -3)
        return logits
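

# A hedged sketch with illustrative sizes: verifies the symmetrization above,
# i.e. logits[..., i, j, :] equals logits[..., j, i, :].
def _example_distogram_symmetry():
    import torch

    head = DistogramHead(d_pair=128, num_bins=64, disable_enhance_head=False)
    z = torch.randn(1, 16, 16, 128)
    logits = head(z)
    assert torch.allclose(logits, logits.transpose(-2, -3))
    return logits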


class PredictedAlignedErrorHead(EnhancedHeadBase):
    """Predicts aligned-error bin logits from the pair representation, used
    to compute PAE, pTM, and ipTM."""

    def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs):
        super().__init__(
            d_in=d_pair,
            d_out=num_bins,
            disable_enhance_head=disable_enhance_head,
        )


class MaskedMSAHead(EnhancedHeadBase):
    """Predicts the identities of masked MSA tokens (BERT-style auxiliary
    objective)."""

    def __init__(self, d_msa, d_out, disable_enhance_head, **kwargs):
        super().__init__(
            d_in=d_msa,
            d_out=d_out,
            disable_enhance_head=disable_enhance_head,
        )


class ExperimentallyResolvedHead(EnhancedHeadBase):
    """Predicts logits for whether atom positions are experimentally resolved,
    from the single representation."""

    def __init__(self, d_single, d_out, disable_enhance_head, **kwargs):
        super().__init__(
            d_in=d_single,
            d_out=d_out,
            disable_enhance_head=disable_enhance_head,
        )
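

# Hedged smoke test for the sketches above; runs only when the file is executed
# directly, never on import. _example_auxiliary_heads_usage additionally
# assumes ml_collections is installed.
if __name__ == "__main__":
    _example_plddt_head()
    _example_enhanced_head_base()
    _example_distogram_symmetry()
    _example_auxiliary_heads_usage()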