# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import torch
import torch.nn as nn

from openfold.utils.feats import (
    pseudo_beta_fn,
    build_extra_msa_feat,
    build_template_angle_feat,
    build_template_pair_feat,
    atom14_to_atom37,
)
from openfold.model.embedders import (
    InputEmbedder,
    RecyclingEmbedder,
    TemplateAngleEmbedder,
    TemplatePairEmbedder,
    ExtraMSAEmbedder,
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule
from openfold.model.template import (
    TemplatePairStack,
    TemplatePointwiseAttention,
)
from openfold.utils.loss import (
    compute_plddt,
)
from openfold.utils.tensor_utils import (
    dict_multimap,
    tensor_tree_map,
)

class AlphaFold(nn.Module):
    """
    AlphaFold 2.

    Implements Algorithm 2 (but with training).
    """

    def __init__(self, config):
        """
        Args:
            config:
                A dict-like config object (like the one in config.py)
        """
        super(AlphaFold, self).__init__()

        self.globals = config.globals
        config = config.model
        template_config = config.template
        extra_msa_config = config.extra_msa

        # Main trunk + structure module
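        # (Rough dataflow, per iteration() below: the input and recycling
        # embedders initialize the MSA/pair representations, the template
        # and extra-MSA branches update them, the Evoformer trunk refines
        # them, and the structure module plus auxiliary heads produce the
        # final outputs.)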
        self.input_embedder = InputEmbedder(
            **config["input_embedder"],
        )
        self.recycling_embedder = RecyclingEmbedder(
            **config["recycling_embedder"],
        )
        self.template_angle_embedder = TemplateAngleEmbedder(
            **template_config["template_angle_embedder"],
        )
        self.template_pair_embedder = TemplatePairEmbedder(
            **template_config["template_pair_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **template_config["template_pair_stack"],
        )
        self.template_pointwise_att = TemplatePointwiseAttention(
            **template_config["template_pointwise_attention"],
        )
        self.extra_msa_embedder = ExtraMSAEmbedder(
            **extra_msa_config["extra_msa_embedder"],
        )
        self.extra_msa_stack = ExtraMSAStack(
            **extra_msa_config["extra_msa_stack"],
        )
        self.evoformer = EvoformerStack(
            **config["evoformer_stack"],
        )
        self.structure_module = StructureModule(
            **config["structure_module"],
        )

        self.aux_heads = AuxiliaryHeads(
            config["heads"],
        )

        self.config = config

    def embed_templates(self, batch, z, pair_mask, templ_dim):
        # Embed the templates one at a time (with a poor man's vmap)
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )
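            # (The 0-dim index behaves like a length-1 index tensor here,
            # so each slice keeps a singleton template dimension and the
            # per-template embeddings can be re-concatenated along
            # templ_dim below.)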

            single_template_embeds = {}
            if self.config.template.embed_angles:
                template_angle_feat = build_template_angle_feat(
                    single_template_feats,
                )

                # [*, S_t, N, C_m]
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            # [*, S_t, N, N, C_t]
            t = build_template_pair_feat(
                single_template_feats,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                **self.config.template.distogram,
            ).to(z.dtype)
            t = self.template_pair_embedder(t)

            single_template_embeds.update({"pair": t})

            template_embeds.append(single_template_embeds)

        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # [*, S_t, N, N, C_z]
        t = self.template_pair_stack(
            template_embeds["pair"],
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )

        # [*, N, N, C_z]
        t = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
        )
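        # If every template is masked out, zero the update so the pair
        # representation is left unchanged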
        t = t * (torch.sum(batch["template_mask"]) > 0)

        ret = {}
        if self.config.template.embed_angles:
            ret["template_angle_embedding"] = template_embeds["angle"]

        ret.update({"template_pair_embedding": t})

        return ret

    def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
        # Primary output dictionary
        outputs = {}

        # This needs to be done manually for DeepSpeed's sake
        dtype = next(self.parameters()).dtype
        for k in feats:
            if(feats[k].dtype == torch.float32):
                feats[k] = feats[k].to(dtype=dtype)
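        # (Features arrive from the data pipeline as float32; when the
        # parameters are half precision, e.g. under DeepSpeed FP16
        # training, the inputs are cast to the parameter dtype by hand
        # rather than relying on autocast.)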

        # Grab some data about the input
        batch_dims = feats["target_feat"].shape[:-2]
        no_batch_dims = len(batch_dims)
        n = feats["target_feat"].shape[-2]
        n_seq = feats["msa_feat"].shape[-3]
        device = feats["target_feat"].device

        # Prep some features
        seq_mask = feats["seq_mask"]
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        msa_mask = feats["msa_mask"]

        # Initialize the MSA and pair representations

        # m: [*, S_c, N, C_m]
        # z: [*, N, N, C_z]
        m, z = self.input_embedder(
            feats["target_feat"],
            feats["residue_index"],
            feats["msa_feat"],
        )

        # Initialize the recycling embeddings, if needs be
        if None in [m_1_prev, z_prev, x_prev]:
            # [*, N, C_m]
            m_1_prev = m.new_zeros(
                (*batch_dims, n, self.config.input_embedder.c_m),
                requires_grad=False,
            )

            # [*, N, N, C_z]
            z_prev = z.new_zeros(
                (*batch_dims, n, n, self.config.input_embedder.c_z),
                requires_grad=False,
            )

            # [*, N, 37, 3]
            x_prev = z.new_zeros(
                (*batch_dims, n, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )

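        # Reduce the recycled atom37 coordinates to per-residue
        # pseudo-beta positions (C-beta, or C-alpha for glycine):
        # [*, N, 37, 3] -> [*, N, 3]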
        x_prev = pseudo_beta_fn(
            feats["aatype"], x_prev, None
        ).to(dtype=z.dtype)

        # m_1_prev_emb: [*, N, C_m]
        # z_prev_emb: [*, N, N, C_z]
        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
            x_prev,
        )

        # If the number of recycling iterations is 0, skip recycling
        # altogether. We zero them this way instead of computing them
        # conditionally to avoid leaving parameters unused, which has annoying
        # implications for DDP training.
        if(not _recycle):
            m_1_prev_emb *= 0
            z_prev_emb *= 0

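        # (The first MSA row carries the target sequence; it alone
        # receives the recycled single representation.)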
        # [*, S_c, N, C_m]
        m[..., 0, :, :] += m_1_prev_emb

        # [*, N, N, C_z]
        z += z_prev_emb

        # Possibly prevents memory fragmentation
        del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

        # Embed the templates + merge with MSA/pair embeddings
        if self.config.template.enabled:
            template_feats = {
                k: v for k, v in feats.items() if k.startswith("template_")
            }
            template_embeds = self.embed_templates(
                template_feats,
                z,
                pair_mask.to(dtype=z.dtype),
                no_batch_dims,
            )

            # [*, N, N, C_z]
            z = z + template_embeds["template_pair_embedding"]

            if self.config.template.embed_angles:
                # [*, S = S_c + S_t, N, C_m]
                m = torch.cat(
                    [m, template_embeds["template_angle_embedding"]], 
                    dim=-3
                )

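                # (torsion_angles_mask has shape [*, S_t, N, 7]; index 2
                # selects the psi-angle mask, which serves as the row
                # mask for the template angle rows appended above.)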
                # [*, S, N]
                torsion_angles_mask = feats["template_torsion_angles_mask"]
                msa_mask = torch.cat(
                    [feats["msa_mask"], torsion_angles_mask[..., 2]], 
                    dim=-2
                )

        # Embed extra MSA features + merge with pairwise embeddings
        if self.config.extra_msa.enabled:
            # [*, S_e, N, C_e]
            a = self.extra_msa_embedder(build_extra_msa_feat(feats))

            # [*, N, N, C_z]
            z = self.extra_msa_stack(
                a,
                z,
                msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
                chunk_size=self.globals.chunk_size,
                pair_mask=pair_mask.to(dtype=z.dtype),
                _mask_trans=self.config._mask_trans,
            )

        # Run MSA + pair embeddings through the trunk of the network
        # m: [*, S, N, C_m]
        # z: [*, N, N, C_z]
        # s: [*, N, C_s]
        m, z, s = self.evoformer(
            m,
            z,
            msa_mask=msa_mask.to(dtype=m.dtype),
            pair_mask=pair_mask.to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )

        outputs["msa"] = m[..., :n_seq, :, :]
        outputs["pair"] = z
        outputs["single"] = s

        # Predict 3D structure
        outputs["sm"] = self.structure_module(
            s,
            z,
            feats["aatype"],
            mask=feats["seq_mask"].to(dtype=s.dtype),
        )
        outputs["final_atom_positions"] = atom14_to_atom37(
            outputs["sm"]["positions"][-1], feats
        )
        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]

        # Save embeddings for use during the next recycling iteration

        # [*, N, C_m]
        m_1_prev = m[..., 0, :, :]

        # [*, N, N, C_z]
        z_prev = z

        # [*, N, 37, 3]
        x_prev = outputs["final_atom_positions"]

        return outputs, m_1_prev, z_prev, x_prev

    def _disable_activation_checkpointing(self):
        self.template_pair_stack.blocks_per_ckpt = None
        self.evoformer.blocks_per_ckpt = None
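        # (Setting blocks_per_ckpt to None turns activation checkpointing
        # off in these stacks; the extra-MSA blocks are toggled via a
        # per-block ckpt flag below.)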

        for b in self.extra_msa_stack.blocks:
            b.ckpt = False

    def _enable_activation_checkpointing(self):
        self.template_pair_stack.blocks_per_ckpt = (
            self.config.template.template_pair_stack.blocks_per_ckpt
        )
        self.evoformer.blocks_per_ckpt = (
            self.config.evoformer_stack.blocks_per_ckpt
        )

        for b in self.extra_msa_stack.blocks:
            b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt

    def forward(self, batch):
        """
        Args:
            batch:
                Dictionary of arguments outlined in Algorithm 2. Keys must
                include the official names of the features in the
                supplement subsection 1.2.9.

                The final dimension of each input must have length equal to
                the number of recycling iterations.

                Features (without the recycling dimension):

                    "aatype" ([*, N_res]):
                        Contrary to the supplement, this tensor of residue
                        indices is not one-hot.
                    "target_feat" ([*, N_res, C_tf])
                        One-hot encoding of the target sequence. C_tf is
                        config.model.input_embedder.tf_dim.
                    "residue_index" ([*, N_res])
                        Tensor whose final dimension consists of
                        consecutive indices from 0 to N_res.
                    "msa_feat" ([*, N_seq, N_res, C_msa])
                        MSA features, constructed as in the supplement.
                        C_msa is config.model.input_embedder.msa_dim.
                    "seq_mask" ([*, N_res])
                        1-D sequence mask
                    "msa_mask" ([*, N_seq, N_res])
                        MSA mask
                    "pair_mask" ([*, N_res, N_res])
                        2-D pair mask
                    "extra_msa_mask" ([*, N_extra, N_res])
                        Extra MSA mask
                    "template_mask" ([*, N_templ])
                        Template mask (on the level of templates, not
                        residues)
                    "template_aatype" ([*, N_templ, N_res])
                        Tensor of template residue indices (indices greater
                        than 19 are clamped to 20 (Unknown))
                    "template_all_atom_positions"
                        ([*, N_templ, N_res, 37, 3])
                        Template atom coordinates in atom37 format
                    "template_all_atom_mask" ([*, N_templ, N_res, 37])
                        Template atom coordinate mask
                    "template_pseudo_beta" ([*, N_templ, N_res, 3])
                        Positions of template carbon "pseudo-beta" atoms
                        (i.e. C_beta for all residues but glycine, for
                        which C_alpha is used instead)
                    "template_pseudo_beta_mask" ([*, N_templ, N_res])
                        Pseudo-beta mask
        """
        # Initialize recycling embeddings
        m_1_prev, z_prev, x_prev = None, None, None

        # Disable activation checkpointing for the first few recycling iters
        is_grad_enabled = torch.is_grad_enabled()
        self._disable_activation_checkpointing()

        # Main recycling loop
        num_iters = batch["aatype"].shape[-1]
        for cycle_no in range(num_iters):
            # Select the features for the current recycling cycle
            fetch_cur_batch = lambda t: t[..., cycle_no]
            feats = tensor_tree_map(fetch_cur_batch, batch)
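            # (Each feature's trailing dimension indexes recycling
            # iterations, e.g. "aatype": [*, N_res, R] -> [*, N_res]
            # for this cycle.)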

            # Enable grad iff we're training and it's the final recycling layer
            is_final_iter = cycle_no == (num_iters - 1)
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter:
                    self._enable_activation_checkpointing()
                    # Sidestep AMP bug (PyTorch issue #65766)
                    if torch.is_autocast_enabled():
                        torch.clear_autocast_cache()

                # Run the next iteration of the model
                outputs, m_1_prev, z_prev, x_prev = self.iteration(
                    feats,
                    m_1_prev,
                    z_prev,
                    x_prev,
                    _recycle=(num_iters > 1)
                )

        # Run auxiliary heads
        outputs.update(self.aux_heads(outputs))

        return outputs