# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import torch
import torch.nn as nn

from openfold.utils.feats import (
    pseudo_beta_fn,
    build_extra_msa_feat,
    build_template_angle_feat,
    build_template_pair_feat,
    atom14_to_atom37,
)
from openfold.model.embedders import (
    InputEmbedder,
    RecyclingEmbedder,
    TemplateAngleEmbedder,
    TemplatePairEmbedder,
    ExtraMSAEmbedder,
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule
from openfold.model.template import (
    TemplatePairStack,
    TemplatePointwiseAttention,
)
from openfold.utils.loss import (
    compute_plddt,
)
from openfold.utils.tensor_utils import (
    dict_multimap,
    tensor_tree_map,
)


class AlphaFold(nn.Module):
    """
    AlphaFold 2.

    Implements Algorithm 2 (but with training).
    """

    def __init__(self, config):
        """
        Args:
            config:
                A dict-like config object (like the one in config.py)
        """
        super(AlphaFold, self).__init__()

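        # Keep the global settings; everything below reads from the `model`
        # sub-config, which is what `self.config` refers to elsewhere.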
        self.globals = config.globals
        config = config.model
        template_config = config.template
        extra_msa_config = config.extra_msa

        # Main trunk + structure module
        self.input_embedder = InputEmbedder(
            **config["input_embedder"],
        )
        self.recycling_embedder = RecyclingEmbedder(
            **config["recycling_embedder"],
        )
        self.template_angle_embedder = TemplateAngleEmbedder(
            **template_config["template_angle_embedder"],
        )
        self.template_pair_embedder = TemplatePairEmbedder(
            **template_config["template_pair_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **template_config["template_pair_stack"],
        )
        self.template_pointwise_att = TemplatePointwiseAttention(
            **template_config["template_pointwise_attention"],
        )
        self.extra_msa_embedder = ExtraMSAEmbedder(
            **extra_msa_config["extra_msa_embedder"],
        )
        self.extra_msa_stack = ExtraMSAStack(
            **extra_msa_config["extra_msa_stack"],
        )
        self.evoformer = EvoformerStack(
            **config["evoformer_stack"],
        )
        self.structure_module = StructureModule(
            **config["structure_module"],
        )

        self.aux_heads = AuxiliaryHeads(
            config["heads"],
        )

        self.config = config

    def embed_templates(self, batch, z, pair_mask, templ_dim, chunk_size):
        # Embed the templates one at a time (with a poor man's vmap)
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}
            if self.config.template.embed_angles:
                template_angle_feat = build_template_angle_feat(
                    single_template_feats,
                )

                # [*, S_t, N, C_m]
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            # [*, S_t, N, N, C_t]
            t = build_template_pair_feat(
                single_template_feats,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                **self.config.template.distogram,
            )
            t = self.template_pair_embedder(t)
            t = self.template_pair_stack(
                t, 
                pair_mask.unsqueeze(-3), 
                chunk_size=chunk_size,
                _mask_trans=self.config._mask_trans,
            )

            single_template_embeds.update({"pair": t})

            template_embeds.append(single_template_embeds)

        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # [*, N, N, C_z]
        t = self.template_pointwise_att(
            template_embeds["pair"], 
            z, 
            template_mask=batch["template_mask"],
            chunk_size=chunk_size,
        )
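        # Zero the pooled template embedding when the batch contains no
        # templates at all (i.e. template_mask sums to zero).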
        t = t * (torch.sum(batch["template_mask"]) > 0)

        ret = {}
        if self.config.template.embed_angles:
            ret["template_angle_embedding"] = template_embeds["angle"]

        ret.update({"template_pair_embedding": t})

        return ret

    def iteration(self, feats, m_1_prev, z_prev, x_prev):
        # Establish constants. chunk_size controls how finely the attention
        # stacks chunk their computation (a memory/speed trade-off); training
        # and evaluation use separate values.
        chunk_size = (
            self.globals.train_chunk_size 
            if self.training else self.globals.eval_chunk_size
        )

        # Primary output dictionary
        outputs = {}

        # Grab some data about the input
        batch_dims = feats["target_feat"].shape[:-2]
        no_batch_dims = len(batch_dims)
        n = feats["target_feat"].shape[-2]
        n_seq = feats["msa_feat"].shape[-3]
        device = feats["target_feat"].device

        # Prep some features
        seq_mask = feats["seq_mask"]
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        msa_mask = feats["msa_mask"]

        # Initialize the MSA and pair representations

        # m: [*, S_c, N, C_m]
        # z: [*, N, N, C_z]
        m, z = self.input_embedder(
            feats["target_feat"],
            feats["residue_index"],
            feats["msa_feat"],
        )

        # Inject information from previous recycling iterations
        if self.config.num_recycle > 0:
            # Initialize the recycling embeddings, if need be
            if None in [m_1_prev, z_prev, x_prev]:
                # [*, N, C_m]
                m_1_prev = m.new_zeros(
                    (*batch_dims, n, self.config.input_embedder.c_m),
                )

                # [*, N, N, C_z]
                z_prev = z.new_zeros(
                    (*batch_dims, n, n, self.config.input_embedder.c_z),
                )

                # [*, N, 37, 3]
                x_prev = z.new_zeros(
                    (*batch_dims, n, residue_constants.atom_type_num, 3),
                )

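            # The recycling embedder takes pseudo-beta (C-beta, or C-alpha for
            # glycine) coordinates of shape [*, N, 3], not atom37 positions.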
            x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)

            # m_1_prev_emb: [*, N, C_m]
            # z_prev_emb: [*, N, N, C_z]
            m_1_prev_emb, z_prev_emb = self.recycling_embedder(
                m_1_prev,
                z_prev,
                x_prev,
            )

            # [*, S_c, N, C_m]
            m[..., 0, :, :] = m[..., 0, :, :] + m_1_prev_emb

            # [*, N, N, C_z]
            z = z + z_prev_emb

            # This can matter during inference when N_res is very large
            del m_1_prev_emb, z_prev_emb

        # Embed the templates + merge with MSA/pair embeddings
        if self.config.template.enabled:
            template_feats = {
                k: v for k, v in feats.items() if k.startswith("template_")
            }
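            # The template dimension sits immediately after the batch
            # dimensions, so its index equals the number of batch dimensions.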
            template_embeds = self.embed_templates(
                template_feats,
                z,
                pair_mask,
                no_batch_dims,
                chunk_size,
            )

            # [*, N, N, C_z]
            z = z + template_embeds["template_pair_embedding"]

            if self.config.template.embed_angles:
                # [*, S = S_c + S_t, N, C_m]
                m = torch.cat(
                    [m, template_embeds["template_angle_embedding"]], dim=-3
                )

                # [*, S, N]
                torsion_angles_mask = feats["template_torsion_angles_mask"]
                msa_mask = torch.cat(
                    [feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
                )

        # Embed extra MSA features + merge with pairwise embeddings
        if self.config.extra_msa.enabled:
            # [*, S_e, N, C_e]
            a = self.extra_msa_embedder(build_extra_msa_feat(feats))

            # [*, N, N, C_z]
            z = self.extra_msa_stack(
                a,
                z,
                msa_mask=feats["extra_msa_mask"],
                chunk_size=chunk_size,
                pair_mask=pair_mask,
                _mask_trans=self.config._mask_trans,
            )

        # Run MSA + pair embeddings through the trunk of the network
        # m: [*, S, N, C_m]
        # z: [*, N, N, C_z]
        # s: [*, N, C_s]
        m, z, s = self.evoformer(
            m,
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            chunk_size=chunk_size,
            _mask_trans=self.config._mask_trans,
        )

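        # Only the first n_seq rows of the MSA embedding correspond to the
        # input MSA; any template angle rows appended above are dropped here.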
        outputs["msa"] = m[..., :n_seq, :, :]
        outputs["pair"] = z
        outputs["single"] = s

        # Predict 3D structure
        outputs["sm"] = self.structure_module(
            s,
            z,
            feats["aatype"],
            mask=feats["seq_mask"],
        )
        outputs["final_atom_positions"] = atom14_to_atom37(
            outputs["sm"]["positions"][-1], feats
        )
        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]

        # Save embeddings for use during the next recycling iteration

        # [*, N, C_m]
        m_1_prev = m[..., 0, :, :]

        # [*, N, N, C_z]
        z_prev = z

        # [*, N, 37, 3]
        x_prev = outputs["final_atom_positions"]

        return outputs, m_1_prev, z_prev, x_prev

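    # Activation checkpointing only pays off when gradients are computed.
    # forward() switches it off for the no-grad recycling iterations and
    # re-enables it just before the final, grad-enabled pass.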
    def _disable_activation_checkpointing(self):
        self.template_pair_stack.blocks_per_ckpt = None
        self.evoformer.blocks_per_ckpt = None
        self.extra_msa_stack.stack.blocks_per_ckpt = None

    def _enable_activation_checkpointing(self):
        self.template_pair_stack.blocks_per_ckpt = (
            self.config.template.template_pair_stack.blocks_per_ckpt
        )
        self.evoformer.blocks_per_ckpt = (
            self.config.evoformer_stack.blocks_per_ckpt
        )
        self.extra_msa_stack.stack.blocks_per_ckpt = (
            self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
        )

    def forward(self, batch):
        """
        Args:
            batch:
                Dictionary of arguments outlined in Algorithm 2. Keys must
                include the official names of the features in the
                supplement subsection 1.2.9.

                The final dimension of each input must have length equal to
                the number of iterations through the trunk, i.e.
                config.model.num_recycle + 1.

                Features (without the recycling dimension):

                    "aatype" ([*, N_res]):
                        Contrary to the supplement, this tensor of residue
                        indices is not one-hot.
                    "target_feat" ([*, N_res, C_tf])
                        One-hot encoding of the target sequence. C_tf is
                        config.model.input_embedder.tf_dim.
                    "residue_index" ([*, N_res])
                        Tensor whose final dimension consists of
                        consecutive indices from 0 to N_res.
                    "msa_feat" ([*, N_seq, N_res, C_msa])
                        MSA features, constructed as in the supplement.
                        C_msa is config.model.input_embedder.msa_dim.
                    "seq_mask" ([*, N_res])
                        1-D sequence mask
                    "msa_mask" ([*, N_seq, N_res])
                        MSA mask
                    "pair_mask" ([*, N_res, N_res])
                        2-D pair mask
                    "extra_msa_mask" ([*, N_extra, N_res])
                        Extra MSA mask
                    "template_mask" ([*, N_templ])
                        Template mask (on the level of templates, not
                        residues)
                    "template_aatype" ([*, N_templ, N_res])
                        Tensor of template residue indices (indices greater
                        than 19 are clamped to 20 (Unknown))
                    "template_all_atom_positions"
                        ([*, N_templ, N_res, 37, 3])
                        Template atom coordinates in atom37 format
                    "template_all_atom_mask" ([*, N_templ, N_res, 37])
                        Template atom coordinate mask
                    "template_pseudo_beta" ([*, N_templ, N_res, 3])
                        Positions of template carbon "pseudo-beta" atoms
                        (i.e. C_beta for all residues but glycine, for
                        which C_alpha is used instead)
                    "template_pseudo_beta_mask" ([*, N_templ, N_res])
                        Pseudo-beta mask
        """
        # Initialize recycling embeddings
        m_1_prev, z_prev, x_prev = None, None, None

        is_grad_enabled = torch.is_grad_enabled()
        self._disable_activation_checkpointing()

        # Main recycling loop
        for cycle_no in range(self.config.num_recycle + 1):
            # Select the features for the current recycling cycle
            fetch_cur_batch = lambda t: t[..., cycle_no]
            feats = tensor_tree_map(fetch_cur_batch, batch)

            # Enable grad iff we're training and it's the final recycling layer
            is_final_iter = cycle_no == self.config.num_recycle
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                # Sidestep AMP bug discussed in pytorch issue #65766
                if is_final_iter:
                    self._enable_activation_checkpointing()
                    if torch.is_autocast_enabled():
                        torch.clear_autocast_cache()
                # Run the next iteration of the model
                outputs, m_1_prev, z_prev, x_prev = self.iteration(
                    feats,
                    m_1_prev,
                    z_prev,
                    x_prev,
                )

        # Run auxiliary heads
        outputs.update(self.aux_heads(outputs))

        return outputs
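
# Example usage (a minimal sketch, not part of the original module). The
# config constructor below is an assumption about the surrounding project
# layout (OpenFold builds its config object in openfold/config.py), and
# `batch` is assumed to come from the feature pipeline with a trailing
# recycling dimension, as described in AlphaFold.forward:
#
#     from openfold.config import model_config
#
#     config = model_config("model_1")
#     model = AlphaFold(config).eval()
#     with torch.no_grad():
#         out = model(batch)
#
#     # [*, N_res, 37, 3] predicted coordinates in atom37 format
#     coords = out["final_atom_positions"]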