Commit 92f1932e authored by Gustaf Ahdritz

Initial commit

parent 3d9c2de3
MIT License

Copyright (c) 2021 AQ Laboratory

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List
class Dropout(nn.Module):
"""
Implementation of dropout with the ability to share the dropout mask
along a particular dimension.
If not in training mode, this module computes the identity function.
"""
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
"""
Args:
r:
Dropout rate
batch_dim:
Dimension(s) along which the dropout mask is shared
"""
super(Dropout, self).__init__()
self.r = r
if(type(batch_dim) == int):
batch_dim = [batch_dim]
self.batch_dim = batch_dim
self.dropout = nn.Dropout(self.r)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
Tensor to which dropout is applied. Can have any shape
compatible with self.batch_dim
"""
shape = list(x.shape)
if(self.batch_dim is not None):
for bd in self.batch_dim:
shape[bd] = 1
mask = x.new_ones(shape, requires_grad=False)
mask = self.dropout(mask)
x = x * mask
return x
class DropoutRowwise(Dropout):
"""
Convenience class for rowwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-3)
class DropoutColumnwise(Dropout):
"""
Convenience class for columnwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-2)
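# Illustrative sanity check: with batch_dim=-3 (i.e. DropoutRowwise), a single
# dropout mask is shared across dim -3, so every slice along that dimension is
# zeroed or rescaled identically. Shapes below are arbitrary.
if __name__ == "__main__":
    drop = Dropout(r=0.5, batch_dim=-3)
    drop.train()
    x = torch.ones(2, 4, 8, 16)
    out = drop(x)
    # The mask has shape [2, 1, 8, 16], so all slices along dim -3 agree
    for k in range(x.shape[-3]):
        assert(torch.equal(out[:, k, :, :], out[:, 0, :, :]))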
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Tuple
from alphafold.model.primitives import Linear
from alphafold.utils.tensor_utils import one_hot
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self,
ri: torch.Tensor
):
"""
Computes relative positional encodings
Implements Algorithm 4.
Args:
ri:
"residue_index" features of shape [*, N]
"""
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries)
return self.linear_relpos(oh)
def forward(self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb += self.relpos(ri)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (self.linear_tf_m(tf)
.unsqueeze(-3)
.expand((*(-1,) * len(tf.shape[:-2]), n_clust, -1, -1)))
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32.
"""
def __init__(self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair embedding channel dimension
min_bin:
Smallest distogram bin (Angstroms)
max_bin:
Largest distogram bin (Angstroms)
no_bins:
Number of distogram bins
"""
super(RecyclingEmbedder, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.min_bin = min_bin
self.max_bin = max_bin
self.no_bins = no_bins
self.inf = inf
self.bins = None
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = nn.LayerNorm(self.c_m)
self.layer_norm_z = nn.LayerNorm(self.c_z)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
m:
First row of the MSA embedding. [*, N_res, C_m]
z:
[*, N_res, N_res, C_z] pair embedding
x:
[*, N_res, 3] predicted C_beta coordinates
Returns:
m:
[*, N_res, C_m] MSA embedding update
z:
[*, N_res, N_res, C_z] pair embedding update
"""
if(self.bins is None):
self.bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
requires_grad=False,
device=x.device
)
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = self.bins ** 2
upper = torch.cat(
[
squared_bins[1:],
squared_bins.new_tensor([self.inf])
], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2,
dim=-1,
keepdims=True
)
# [*, N, N, no_bins]
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
# [*, N, N, C_z]
d = self.linear(d)
z_update = d + self.layer_norm_z(z)
return m_update, z_update
class TemplateAngleEmbedder(nn.Module):
"""
Embeds the "template_angle_feat" feature.
Implements Algorithm 2, line 7.
"""
def __init__(self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Final dimension of "template_angle_feat"
c_out:
Output channel dimension
"""
super(TemplateAngleEmbedder, self).__init__()
self.c_out = c_out
self.c_in = c_in
self.linear_1 = Linear(self.c_in, self.c_out, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, init="relu")
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
"""
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
Returns:
x: [*, N_templ, N_res, C_out] embedding
"""
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
class TemplatePairEmbedder(nn.Module):
"""
Embeds "template_pair_feat" features.
Implements Algorithm 2, line 9.
"""
def __init__(self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
            c_in:
                Input channel dimension
c_out:
Output channel dimension
"""
super(TemplatePairEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
# Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, init="relu")
def forward(self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
[*, C_in] input tensor
Returns:
[*, C_out] output tensor
"""
x = self.linear(x)
return x
class ExtraMSAEmbedder(nn.Module):
"""
Embeds unclustered MSA sequences.
Implements Algorithm 2, line 15
"""
def __init__(self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Input channel dimension
c_out:
Output channel dimension
"""
super(ExtraMSAEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.linear = Linear(self.c_in, self.c_out)
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
"""
Args:
x:
[*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
Returns:
[*, N_extra_seq, N_res, C_out] embedding
"""
x = self.linear(x)
return x
if __name__ == "__main__":
tf_dim = 21
msa_dim = 49
c_z = 128
c_m = 256
relpos_k = 32
b = 16
n_res = 200
n_clust = 10
tf = torch.rand((b, n_res, tf_dim))
ri = torch.rand((b, n_res))
msa = torch.rand((b, n_clust, n_res, msa_dim))
batch = {}
batch["target_feat"] = tf
batch["residue_index"] = ri
batch["msa_feat"] = msa
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
    msa_emb, pair_emb = ie(
        batch["target_feat"], batch["residue_index"], batch["msa_feat"]
    )
assert(msa_emb.shape == (b, n_clust, n_res, c_m))
assert(pair_emb.shape == (b, n_res, n_res, c_z))
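    # Additional illustrative check: the relative-position encoding alone maps
    # an [*, N_res] residue index tensor to an [*, N_res, N_res, c_z] bias
    # (same one_hot call as in InputEmbedder.forward).
    rp = ie.relpos(batch["residue_index"])
    assert(rp.shape == (b, n_res, n_res, c_z))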
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Tuple, Optional
from functools import partial
from alphafold.model.primitives import Linear
from alphafold.utils.deepspeed import checkpoint_blocks
from alphafold.model.dropout import DropoutRowwise, DropoutColumnwise
from alphafold.model.msa import (
MSARowAttentionWithPairBias,
MSAColumnAttention,
MSAColumnGlobalAttention,
)
from alphafold.model.outer_product_mean import OuterProductMean
from alphafold.model.pair_transition import PairTransition
from alphafold.model.triangular_attention import (
TriangleAttentionStartingNode,
TriangleAttentionEndingNode,
)
from alphafold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
from alphafold.utils.tensor_utils import chunk_layer
class MSATransition(nn.Module):
"""
Feed-forward network applied to MSA activations after attention.
Implements Algorithm 9
"""
def __init__(self, c_m, n, chunk_size):
"""
Args:
c_m:
MSA channel dimension
n:
Factor multiplied to c_m to obtain the hidden channel
dimension
"""
super(MSATransition, self).__init__()
self.c_m = c_m
self.n = n
self.chunk_size = chunk_size
self.layer_norm = nn.LayerNorm(self.c_m)
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
def _transition(self, m, mask):
m = self.linear_1(m)
m = self.relu(m)
m = self.linear_2(m) * mask
return m
def forward(self,
m: torch.Tensor,
mask: torch.Tensor = None,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA activation
mask:
                [*, N_seq, N_res] MSA mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA activation update
"""
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
if(mask is None):
mask = m.new_ones(m.shape[:-1])
mask = mask.unsqueeze(-1)
m = self.layer_norm(m)
inp = {"m": m, "mask": mask}
if(not self.training and self.chunk_size is not None):
m = chunk_layer(
self._transition,
inp,
chunk_size=self.chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
else:
m = self._transition(**inp)
return m
class EvoformerBlock(nn.Module):
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
chunk_size: int,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
):
super(EvoformerBlock, self).__init__()
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
if(_is_extra_msa_stack):
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
chunk_size=chunk_size,
inf=inf,
eps=eps,
)
else:
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
chunk_size=chunk_size,
inf=inf,
)
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
chunk_size=chunk_size,
)
self.outer_product_mean = OuterProductMean(
c_m,
c_z,
c_hidden_opm,
chunk_size=chunk_size,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
c_z,
c_hidden_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
c_z,
c_hidden_mul,
)
self.tri_att_start = TriangleAttentionStartingNode(
c_z,
c_hidden_pair_att,
no_heads_pair,
chunk_size=chunk_size,
inf=inf,
)
self.tri_att_end = TriangleAttentionEndingNode(
c_z,
c_hidden_pair_att,
no_heads_pair,
chunk_size=chunk_size,
inf=inf,
)
self.pair_transition = PairTransition(
c_z,
transition_n,
chunk_size=chunk_size,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
m = m + self.msa_dropout_layer(self.msa_att_row(m, z, mask=msa_mask))
m = m + self.msa_att_col(m, mask=msa_mask)
m = m + self.msa_transition(m, mask=msa_trans_mask)
z = z + self.outer_product_mean(m, mask=msa_mask)
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(self.tri_att_start(z, mask=pair_mask))
z = z + self.ps_dropout_col_layer(self.tri_att_end(z, mask=pair_mask))
z = z + self.pair_transition(z, mask=pair_trans_mask)
return m, z
class EvoformerStack(nn.Module):
"""
Main Evoformer trunk.
Implements Algorithm 6.
"""
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
c_s: int,
no_heads_msa: int,
no_heads_pair: int,
no_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
blocks_per_ckpt: int,
chunk_size: int,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair channel dimension
c_hidden_msa_att:
Hidden dimension in MSA attention
c_hidden_opm:
Hidden dimension in outer product mean module
c_hidden_mul:
Hidden dimension in multiplicative updates
c_hidden_pair_att:
Hidden dimension in triangular attention
c_s:
Channel dimension of the output "single" embedding
no_heads_msa:
Number of heads used for MSA attention
no_heads_pair:
Number of heads used for pair attention
no_blocks:
Number of Evoformer blocks in the stack
transition_n:
Factor by which to multiply c_m to obtain the MSATransition
hidden dimension
msa_dropout:
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
"""
super(EvoformerStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self._is_extra_msa_stack = _is_extra_msa_stack
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = EvoformerBlock(
c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
chunk_size=chunk_size,
inf=inf,
eps=eps,
_is_extra_msa_stack=_is_extra_msa_stack,
)
self.blocks.append(block)
if(not self._is_extra_msa_stack):
self.linear = Linear(c_m, c_s)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding
"""
m, z = checkpoint_blocks(
blocks=[
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=_mask_trans,
) for b in self.blocks
],
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt,
)
s = None
if(not self._is_extra_msa_stack):
seq_dim = -3
index = torch.tensor([0], device=m.device)
s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
s = s.squeeze(seq_dim)
return m, z, s
class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
no_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
blocks_per_ckpt: int,
chunk_size: int,
inf: float,
eps: float,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
c_s = None
self.stack = EvoformerStack(
c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
c_s=c_s,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
no_blocks=no_blocks,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
blocks_per_ckpt=blocks_per_ckpt,
chunk_size=chunk_size,
inf=inf,
eps=eps,
_is_extra_msa_stack=True,
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
_, z, _ = self.stack(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=_mask_trans
)
return z
if __name__ == "__main__":
batch_size = 2
s_t = 3
n_res = 100
c_m = 128
c_z = 64
c_hidden_att = 32
c_hidden_opm = 31
c_hidden_mul = 30
c_s = 29
no_heads_msa = 4
no_heads_pair = 8
no_blocks = 2
transition_n = 5
msa_dropout = 0.15
pair_dropout = 0.25
    # Smoke-test values for the remaining required arguments
    c_hidden_pair_att = 16
    es = EvoformerStack(
        c_m=c_m,
        c_z=c_z,
        c_hidden_msa_att=c_hidden_att,
        c_hidden_opm=c_hidden_opm,
        c_hidden_mul=c_hidden_mul,
        c_hidden_pair_att=c_hidden_pair_att,
        c_s=c_s,
        no_heads_msa=no_heads_msa,
        no_heads_pair=no_heads_pair,
        no_blocks=no_blocks,
        transition_n=transition_n,
        msa_dropout=msa_dropout,
        pair_dropout=pair_dropout,
        blocks_per_ckpt=1,
        chunk_size=4,
        inf=1e9,
        eps=1e-10,
    )
m = torch.rand((batch_size, s_t, n_res, c_m))
z = torch.rand((batch_size, n_res, n_res, c_z))
shape_m_before = m.shape
shape_z_before = z.shape
    msa_mask = torch.ones((batch_size, s_t, n_res))
    pair_mask = torch.ones((batch_size, n_res, n_res))
    m, z, s = es(m, z, msa_mask=msa_mask, pair_mask=pair_mask)
assert(m.shape == shape_m_before)
assert(z.shape == shape_z_before)
assert(s.shape == (batch_size, n_res, c_s))
batch_size = 2
s = 5
n_res = 100
c_m = 256
c = 32
c_z = 128
opm = OuterProductMean(c_m, c_z, c)
m = torch.rand((batch_size, s, n_res, c_m))
m = opm(m)
assert(m.shape == (batch_size, n_res, n_res, c_z))
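    # Additional illustrative check: MSATransition preserves the shape of the
    # MSA activations (chunking disabled here).
    mt = MSATransition(c_m=c_m, n=transition_n, chunk_size=None)
    m_t = torch.rand((batch_size, s, n_res, c_m))
    assert(mt(m_t).shape == m_t.shape)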
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from alphafold.model.primitives import Linear
from alphafold.utils.loss import compute_plddt
class AuxiliaryHeads(nn.Module):
def __init__(self, config):
super(AuxiliaryHeads, self).__init__()
self.plddt = PerResidueLDDTCaPredictor(
**config["lddt"],
)
self.distogram = DistogramHead(
**config["distogram"],
)
self.masked_msa = MaskedMSAHead(
**config["masked_msa"],
)
self.experimentally_resolved = ExperimentallyResolvedHead(
**config["experimentally_resolved"],
)
if(config.tm_score.enabled):
self.tm_score = TMScoreHead(
**config["tm_score"],
)
self.config = config
def forward(self, outputs):
aux_out = {}
lddt_logits = self.plddt(outputs["single"])
aux_out["lddt_logits"] = lddt_logits
# Required for relaxation later on
aux_out["plddt"] = compute_plddt(lddt_logits)
distogram_logits = self.distogram(outputs["pair"])
aux_out["distogram_logits"] = distogram_logits
masked_msa_logits = self.masked_msa(outputs["msa"])
aux_out["masked_msa_logits"] = masked_msa_logits
experimentally_resolved_logits = self.experimentally_resolved(
outputs["single"]
)
aux_out["experimentally_resolved_logits"] = (
experimentally_resolved_logits
)
if(self.config.tm_score.enabled):
tm_score_logits = self.tm_score(outputs["pair"])
aux_out["tm_score_logits"] = tm_score_logits
return aux_out
class PerResidueLDDTCaPredictor(nn.Module):
def __init__(self, no_bins, c_in, c_hidden):
super(PerResidueLDDTCaPredictor, self).__init__()
self.no_bins = no_bins
self.c_in = c_in
self.c_hidden = c_hidden
self.layer_norm = nn.LayerNorm(self.c_in)
self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")
self.relu = nn.ReLU()
def forward(self, s):
s = self.layer_norm(s)
s = self.linear_1(s)
s = self.relu(s)
s = self.linear_2(s)
s = self.relu(s)
s = self.linear_3(s)
return s
class DistogramHead(nn.Module):
"""
Computes a distogram probability distribution.
For use in computation of distogram loss, subsection 1.9.8
"""
def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
c_z:
Input channel dimension
no_bins:
Number of distogram bins
"""
super(DistogramHead, self).__init__()
self.c_z = c_z
self.no_bins = no_bins
self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self,
z # [*, N, N, C_z]
):
"""
Args:
z:
[*, N_res, N_res, C_z] pair embedding
Returns:
[*, N, N, no_bins] distogram probability distribution
"""
# [*, N, N, no_bins]
logits = self.linear(z)
logits = logits + logits.transpose(-2, -3)
return logits
class TMScoreHead(nn.Module):
"""
For use in computation of TM-score, subsection 1.9.7
"""
def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
c_z:
Input channel dimension
no_bins:
Number of bins
"""
super(TMScoreHead, self).__init__()
self.c_z = c_z
self.no_bins = no_bins
self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, z):
"""
Args:
z:
[*, N_res, N_res, C_z] pairwise embedding
Returns:
[*, N_res, N_res, no_bins] prediction
"""
# [*, N, N, no_bins]
logits = self.linear(z)
return logits
class MaskedMSAHead(nn.Module):
"""
For use in computation of masked MSA loss, subsection 1.9.9
"""
def __init__(self, c_m, c_out, **kwargs):
"""
Args:
c_m:
MSA channel dimension
c_out:
Output channel dimension
"""
super(MaskedMSAHead, self).__init__()
self.c_m = c_m
self.c_out = c_out
self.linear = Linear(self.c_m, self.c_out, init="final")
def forward(self, m):
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
Returns:
[*, N_seq, N_res, C_out] reconstruction
"""
# [*, N_seq, N_res, C_out]
logits = self.linear(m)
return logits
class ExperimentallyResolvedHead(nn.Module):
"""
For use in computation of "experimentally resolved" loss, subsection
1.9.10
"""
def __init__(self, c_s, c_out, **kwargs):
"""
Args:
c_s:
Input channel dimension
c_out:
                Output channel dimension
"""
super(ExperimentallyResolvedHead, self).__init__()
self.c_s = c_s
self.c_out = c_out
self.linear = Linear(self.c_s, self.c_out, init="final")
def forward(self, s):
"""
Args:
s:
[*, N_res, C_s] single embedding
Returns:
[*, N, C_out] logits
"""
# [*, N, C_out]
logits = self.linear(s)
return logits
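# Illustrative sanity check: DistogramHead symmetrizes its logits, so the
# output is invariant to swapping the two residue dimensions. Sizes below are
# arbitrary.
if __name__ == "__main__":
    head = DistogramHead(c_z=16, no_bins=64)
    z = torch.rand((1, 10, 10, 16))
    logits = head(z)
    assert(logits.shape == (1, 10, 10, 64))
    assert(torch.allclose(logits, logits.transpose(-2, -3)))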
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from alphafold.utils.feats import (
pseudo_beta_fn,
atom37_to_torsion_angles,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
)
from alphafold.model.embedders import (
InputEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
ExtraMSAEmbedder,
)
from alphafold.model.evoformer import EvoformerStack, ExtraMSAStack
from alphafold.model.heads import AuxiliaryHeads
import alphafold.np.residue_constants as residue_constants
from alphafold.model.structure_module import StructureModule
from alphafold.model.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from alphafold.utils.loss import (
compute_plddt,
)
from alphafold.utils.tensor_utils import (
tensor_tree_map,
)
class AlphaFold(nn.Module):
"""
    AlphaFold 2.
Implements Algorithm 2 (but with training).
"""
def __init__(self, config):
"""
Args:
config:
A dict-like config object (like the one in config.py)
"""
super(AlphaFold, self).__init__()
template_config = config.template
extra_msa_config = config.extra_msa
# Main trunk + structure module
self.input_embedder = InputEmbedder(
**config["input_embedder"],
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
)
self.template_angle_embedder = TemplateAngleEmbedder(
**template_config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**template_config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**template_config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**template_config["template_pointwise_attention"],
)
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
**extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack(
**config["evoformer_stack"],
)
self.structure_module = StructureModule(
**config["structure_module"],
)
self.aux_heads = AuxiliaryHeads(
config["heads"],
)
self.config = config
def embed_templates(self, batch, z, pair_mask):
# Build template angle feats
angle_feats = atom37_to_torsion_angles(
batch["template_aatype"],
batch["template_all_atom_positions"],
batch["template_all_atom_masks"],
eps=1e-8
)
# Stow this away for later
batch["torsion_angles_mask"] = angle_feats["torsion_angles_mask"]
template_angle_feat = build_template_angle_feat(
angle_feats,
batch["template_aatype"],
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
batch,
eps=self.config.template.eps,
**self.config.template.distogram
)
t = self.template_pair_embedder(t)
t = self.template_pair_stack(
t,
pair_mask.unsqueeze(-3),
_mask_trans=self.config._mask_trans
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"]
)
t *= torch.sum(batch["template_mask"]) > 0
return a, t
def forward(self, batch):
"""
Args:
batch:
Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the
supplement subsection 1.2.9.
The final dimension of each input must have length equal to
the number of recycling iterations.
Features (without the recycling dimension):
"aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue
indices is not one-hot.
"target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res])
Tensor whose final dimension consists of
consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res])
1-D sequence mask
"msa_mask" ([*, N_seq, N_res])
MSA mask
"pair_mask" ([*, N_res, N_res])
2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask
"template_mask" ([*, N_templ])
Template mask (on the level of templates, not
residues)
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms
                    (i.e. C_beta for all residues but glycine, for which
                    C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
# Recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
# Primary output dictionary
outputs = {}
# Main recycling loop
for cycle_no in range(self.config.no_cycles):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2]
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Prep some features
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
)
# Inject information from previous recycling iterations
if(self.config.no_cycles > 1):
# Initialize the recycling embeddings, if needs be
if(None in [m_1_prev, z_prev, x_prev]):
# [*, N, C_m]
m_1_prev = torch.zeros(
(*batch_dims, n, self.config.c_m),
requires_grad=False,
device=device,
)
# [*, N, N, C_z]
z_prev = torch.zeros(
(*batch_dims, n, n, self.config.c_z),
requires_grad=False,
device=device,
)
                    # [*, N, 37, 3]
x_prev = torch.zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
device=device,
)
x_prev = pseudo_beta_fn(
feats["aatype"],
x_prev,
None # TODO: figure this part out
)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
)
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z += z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled):
a, t = self.embed_templates(feats, z, pair_mask)
# [*, N, N, C_z]
z += t
if(self.config.template.embed_angles):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat([m, a], dim=-3)
# [*, S, N]
torsion_angles_mask = feats["torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
)
# Embed extra MSA features + merge with pairwise embeddings
if(self.config.extra_msa.enabled):
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
# [*, N, N, C_z]
z = self.extra_msa_stack(
a,
z,
msa_mask=feats["extra_msa_mask"],
pair_mask=pair_mask,
_mask_trans=self.config._mask_trans,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=self.config._mask_trans
)
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
# Predict 3D structure
outputs["sm"] = self.structure_module(
s, z, feats["aatype"], mask=feats["seq_mask"],
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
# Save embeddings for use during the next recycling iteration
# [*, N, C_m]
m_1_prev = m[..., 0, :, :]
# [* N, N, C_z]
z_prev = z
            # [*, N, 37, 3]
x_prev = outputs["final_atom_positions"]
outputs.update(self.aux_heads(outputs))
return outputs
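# Illustrative sketch of the expected input layout (hypothetical sizes): every
# feature carries a trailing recycling dimension that tensor_tree_map slices
# away one cycle at a time, e.g.
#
#     no_cycles = config.no_cycles
#     batch = {
#         "target_feat":   torch.rand(1, n_res, tf_dim, no_cycles),
#         "residue_index": residue_index[..., None].expand(-1, -1, no_cycles),
#         "msa_feat":      torch.rand(1, n_seq, n_res, msa_dim, no_cycles),
#         # ... remaining features listed in the forward() docstring ...
#     }
#     outputs = AlphaFold(config)(batch)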
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
from alphafold.model.primitives import Linear, Attention
from alphafold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
class MSAAttention(nn.Module):
def __init__(self,
c_in,
c_hidden,
no_heads,
pair_bias=False,
c_z=None,
chunk_size=4,
inf=1e9,
):
"""
Args:
c_in:
Input channel dimension
c_hidden:
Per-head hidden channel dimension
no_heads:
Number of attention heads
pair_bias:
Whether to use pair embedding bias
c_z:
Pair embedding channel dimension. Ignored unless pair_bias
is true
inf:
A large number to be used in computing the attention mask
"""
super(MSAAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.pair_bias = pair_bias
self.c_z = c_z
self.chunk_size = chunk_size
self.inf = inf
self.layer_norm_m = nn.LayerNorm(self.c_in)
if(self.pair_bias):
self.layer_norm_z = nn.LayerNorm(self.c_z)
self.linear_z = Linear(
self.c_z, self.no_heads, bias=False, init="normal"
)
self.mha = Attention(
self.c_in, self.c_in, self.c_in,
self.c_hidden,
self.no_heads
)
def forward(self, m, z=None, mask=None):
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding. Required only if
pair_bias is True
mask:
[*, N_seq, N_res] MSA mask
"""
# [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m)
n_seq, n_res = m.shape[-3:-1]
if(mask is None):
# [*, N_seq, N_res]
mask = torch.ones(
(*m.shape[:-3], n_seq, n_res),
device=m.device,
requires_grad=False
)
# [*, N_seq, 1, 1, N_res]
bias = (self.inf * (mask - 1))[..., :, None, None, :]
# [*, N_seq, no_heads, N_res, N_res]
bias = bias.expand(
(*((-1,) * len(bias.shape[:-4])), -1, self.no_heads, n_res, -1)
)
if(self.pair_bias):
# [*, N_res, N_res, C_z]
z = self.layer_norm_z(z)
# [*, N_res, N_res, no_heads]
z = self.linear_z(z)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, 2, 0, 1).unsqueeze(-4)
# [*, N_seq, no_heads, N_res, N_res]
bias = bias + z
mha_inputs = {
"q_x": m,
"k_x": m,
"v_x": m,
"biases": [bias]
}
if(not self.training and self.chunk_size is not None):
m = chunk_layer(
self.mha,
mha_inputs,
chunk_size=self.chunk_size,
no_batch_dims=len(m.shape[:-2])
)
else:
m = self.mha(**mha_inputs)
return m
class MSARowAttentionWithPairBias(MSAAttention):
"""
Implements Algorithm 7.
"""
def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
"""
Args:
c_m:
Input channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Per-head hidden channel dimension
no_heads:
Number of attention heads
inf:
Large number used to construct attention masks
"""
super(MSARowAttentionWithPairBias, self).__init__(
c_m,
c_hidden,
no_heads,
pair_bias=True,
c_z=c_z,
inf=inf,
)
class MSAColumnAttention(MSAAttention):
"""
Implements Algorithm 8.
"""
def __init__(self, c_m, c_hidden, no_heads, chunk_size=4, inf=1e9):
"""
Args:
c_m:
MSA channel dimension
c_hidden:
Per-head hidden channel dimension
no_heads:
Number of attention heads
inf:
Large number used to construct attention masks
"""
super(MSAColumnAttention, self).__init__(
c_in=c_m,
c_hidden=c_hidden,
no_heads=no_heads,
pair_bias=False,
c_z=None,
chunk_size=chunk_size,
inf=inf,
)
def forward(self, m, mask=None):
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
mask:
[*, N_seq, N_res] MSA mask
"""
# [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3)
if(mask is not None):
mask = mask.transpose(-1, -2)
m = super().forward(m, mask=mask)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
if(mask is not None):
mask = mask.transpose(-1, -2)
return m
class MSAColumnGlobalAttention(nn.Module):
def __init__(self,
c_in,
c_hidden,
no_heads,
chunk_size=4,
inf=1e9,
eps=1e-10
):
super(MSAColumnGlobalAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.chunk_size = chunk_size
self.inf = inf
self.eps = eps
self.layer_norm_m = nn.LayerNorm(self.c_in)
self.linear_q = Linear(
self.c_in, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
C_hidden = self.c_hidden
self.linear_k = Linear(
self.c_in, C_hidden, bias=False, init="glorot",
)
self.linear_v = Linear(
self.c_in, C_hidden, bias=False, init="glorot",
)
self.linear_g = Linear(self.c_in, self.c_hidden * self.no_heads, init="gating")
self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_in, init="final")
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def global_attention(self, m, mask):
# [*, N_res, C_in]
q = (torch.sum(m * mask.unsqueeze(-1), dim=-2) /
(torch.sum(mask, dim=-1)[..., None] + self.eps))
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
q *= self.c_hidden ** (-0.5)
# [*, N_res, H, C_hidden]
q = q.view(*q.shape[:-1], self.no_heads, -1)
# [*, N_res, N_seq, C_hidden]
k = self.linear_k(m)
v = self.linear_v(m)
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = self.softmax(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
a,
v,
)
        # [*, N_res, N_seq, H * C_hidden]
g = self.sigmoid(self.linear_g(m))
# [*, N_res, N_seq, H, C_hidden]
g = g.view(*g.shape[:-1], self.no_heads, -1)
# [*, N_res, N_seq, H, C_hidden]
o = o.unsqueeze(-3) * g
# [*, N_res, N_seq, H * C_hidden]
o = o.reshape(*o.shape[:-2], -1)
# [*, N_res, N_seq, C_in]
m = self.linear_o(o)
return m
def forward(self, m, mask=None):
n_seq, n_res, c_in = m.shape[-3:]
if(mask is None):
# [*, N_seq, N_res]
mask = m.new_ones(m.shape[:-1], requires_grad=False)
# [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3)
mask = mask.transpose(-1, -2)
# [*, N_res, N_seq, C_in]
m = self.layer_norm_m(m)
mha_input = {
"m": m,
"mask": mask,
}
if(not self.training and self.chunk_size is not None):
m = chunk_layer(
self.global_attention,
mha_input,
chunk_size=self.chunk_size,
no_batch_dims=len(m.shape[:-2])
)
else:
m = self.global_attention(**mha_input)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
return m
if __name__ == "__main__":
batch_size = 2
s_t = 3
n = 100
c_in = 128
c = 32
no_heads = 4
msaca = MSAColumnAttention(c_in, c, no_heads)
x = torch.rand((batch_size, s_t, n, c_in))
shape_before = x.shape
x = msaca(x)
shape_after = x.shape
assert(shape_before == shape_after)
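    # Additional illustrative check: global column attention also preserves
    # the shape of the MSA activations.
    msagca = MSAColumnGlobalAttention(c_in, c, no_heads)
    x = torch.rand((batch_size, s_t, n, c_in))
    assert(msagca(x).shape == x.shape)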
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
import torch.nn as nn
from alphafold.model.primitives import Linear
from alphafold.utils.tensor_utils import chunk_layer
class OuterProductMean(nn.Module):
"""
Implements Algorithm 10.
"""
def __init__(self, c_m, c_z, c_hidden, chunk_size=4, eps=1e-3):
"""
Args:
c_m:
MSA embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
"""
super(OuterProductMean, self).__init__()
self.c_z = c_z
self.c_hidden = c_hidden
self.chunk_size = chunk_size
self.eps = eps
self.layer_norm = nn.LayerNorm(c_m)
self.linear_1 = Linear(c_m, c_hidden)
self.linear_2 = Linear(c_m, c_hidden)
self.linear_out = Linear(c_hidden**2, c_z, init="final")
def _opm(self, a, b):
# [*, N_res, N_res, C, C]
outer = torch.einsum("...bac,...dae->...bdce", a, b)
# [*, N_res, N_res, C * C]
outer = outer.reshape(*outer.shape[:-2], -1)
# [*, N_res, N_res, C_z]
outer = self.linear_out(outer)
return outer
def forward(self, m, mask=None):
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
mask:
[*, N_seq, N_res] MSA mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if(mask is None):
mask = m.new_ones(m.shape[:-1], requires_grad=False)
# [*, N_seq, N_res, C_m]
m = self.layer_norm(m)
# [*, N_seq, N_res, C]
mask = mask.unsqueeze(-1)
a = self.linear_1(m) * mask
b = self.linear_2(m) * mask
a = a.transpose(-2, -3)
b = b.transpose(-2, -3)
if(not self.training and self.chunk_size is not None):
# Since the "batch dim" in this case is not a true batch dimension
# (in that the shape of the output depends on it), we need to
# iterate over it ourselves
a_reshape = a.reshape(-1, *a.shape[-3:])
b_reshape = b.reshape(-1, *b.shape[-3:])
out = []
for a_prime, b_prime in zip(a_reshape, b_reshape):
outer = chunk_layer(
partial(self._opm, b=b_prime),
{"a": a_prime},
chunk_size=self.chunk_size,
no_batch_dims=1,
)
out.append(outer)
outer = torch.stack(out, dim=0)
outer = outer.reshape(*a.shape[:-3], *outer.shape[1:])
else:
outer = self._opm(a, b)
# [*, N_res, N_res, 1]
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
# [*, N_res, N_res, C_z]
outer /= self.eps + norm
return outer
if __name__ == "__main__":
batch_size = 2
s = 5
n_res = 100
c_m = 256
c = 32
c_z = 128
opm = OuterProductMean(c_m, c_z, c)
m = torch.rand((batch_size, s, n_res, c_m))
m = opm(m)
assert(m.shape == (batch_size, n_res, n_res, c_z))
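    # Additional illustrative check: passing an explicit all-ones mask matches
    # the default (mask=None) path, since the mean is taken over unmasked rows.
    m_2 = torch.rand((batch_size, s, n_res, c_m))
    mask = torch.ones((batch_size, s, n_res))
    assert(torch.allclose(opm(m_2, mask=mask), opm(m_2)))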
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from alphafold.model.primitives import Linear
from alphafold.utils.tensor_utils import chunk_layer
class PairTransition(nn.Module):
"""
Implements Algorithm 15.
"""
def __init__(self, c_z, n, chunk_size=4):
"""
Args:
c_z:
Pair transition channel dimension
n:
Factor by which c_z is multiplied to obtain hidden channel
dimension
"""
super(PairTransition, self).__init__()
self.c_z = c_z
self.n = n
self.chunk_size = chunk_size
self.layer_norm = nn.LayerNorm(self.c_z)
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
def _transition(self, z, mask):
# [*, N_res, N_res, C_hidden]
z = self.linear_1(z)
z = self.relu(z)
# [*, N_res, N_res, C_z]
z = self.linear_2(z) * mask
return z
def forward(self, z, mask=None):
"""
Args:
z:
[*, N_res, N_res, C_z] pair embedding
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
if(mask is None):
mask = z.new_ones(z.shape[:-1], requires_grad=False)
# [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1)
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
inp = {"z": z, "mask": mask}
if(not self.training and self.chunk_size is not None):
z = chunk_layer(
self._transition,
inp,
chunk_size=self.chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
else:
z = self._transition(**inp)
return z
if __name__ == "__main__":
n = 4
c_in = 128
    pt = PairTransition(c_in, n)
batch_size = 4
n_res = 256
z = torch.rand((batch_size, n_res, n_res, c_in))
shape_before = z.shape
z = pt(z)
shape_after = z.shape
assert(shape_before == shape_after)
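    # Additional illustrative check: an explicit all-ones mask matches the
    # default (mask=None) path.
    z_small = torch.rand((1, 8, 8, c_in))
    mask = torch.ones((1, 8, 8))
    assert(torch.allclose(pt(z_small, mask=mask), pt(z_small)))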
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Callable
import numpy as np
import torch
import torch.nn as nn
from scipy.stats import truncnorm
from alphafold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
)
def _calculate_fan(shape, fan="fan_in"):
    # nn.Linear weights have shape [out_dim, in_dim]
    fan_out = shape[0]
    fan_in = shape[1]
    if(fan == "fan_in"):
        f = fan_in
    elif(fan == "fan_out"):
        f = fan_out
    elif(fan == "fan_avg"):
        f = (fan_in + fan_out) / 2
    else:
        raise ValueError("Invalid fan option")
    return f
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
shape = weights.shape
f = _calculate_fan(shape, fan)
scale = scale / max(1, f)
a = -2
b = 2
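# Note (added comment): the standard normal truncated to [-2, 2] has std < 1,
# so it is divided out below so that the drawn samples end up with the target
# std of sqrt(scale).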
std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
size = math.prod(shape)
samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
samples = np.reshape(samples, shape)
with torch.no_grad():
weights.copy_(torch.tensor(samples, device=weights.device))
def lecun_normal_init_(weights):
trunc_normal_init_(weights, scale=1.0)
def he_normal_init_(weights):
trunc_normal_init_(weights, scale=2.0)
def glorot_uniform_init_(weights):
nn.init.xavier_uniform_(weights, gain=1)
def final_init_(weights):
with torch.no_grad():
weights.fill_(0.)
def gating_init_(weights):
with torch.no_grad():
weights.fill_(0.)
def normal_init_(weights):
torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
def ipa_point_weights_init_(weights):
with torch.no_grad():
softplus_inverse_1 = 0.541324854612918
weights.fill_(softplus_inverse_1)
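# Sanity check (illustrative, not part of the module): the constant above is
# softplus^-1(1) = ln(e - 1), so the softplus of the freshly filled weights is 1:
# torch.nn.functional.softplus(torch.tensor(0.541324854612918)) ~= 1.0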
class Linear(nn.Linear):
"""
A Linear layer with built-in nonstandard initializations. Called just
like torch.nn.Linear.
Implements the initializers in 1.11.4, plus some additional ones found
in the code.
"""
def __init__(self,
in_dim: int,
out_dim: int,
bias: bool = True,
init: str = "default",
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
):
"""
Args:
in_dim:
The final dimension of inputs to the layer
out_dim:
The final dimension of layer outputs
bias:
Whether to learn an additive bias. True by default
init:
The initializer to use. Choose from:
"default": LeCun fan-in truncated normal initialization
"relu": He initialization w/ truncated normal distribution
"glorot": Fan-average Glorot uniform initialization
"gating": Weights=0, Bias=1
"normal": Normal initialization with std=1/sqrt(fan_in)
"final": Weights=0, Bias=0
Overridden by init_fn if the latter is not None.
init_fn:
A custom initializer taking weight and bias as inputs.
Overrides init if not None.
"""
super(Linear, self).__init__(in_dim, out_dim, bias=bias)
if(bias):
with torch.no_grad():
self.bias.fill_(0)
if(init_fn is not None):
init_fn(self.weight, self.bias)
else:
if(init == "default"):
lecun_normal_init_(self.weight)
elif(init == "relu"):
he_normal_init_(self.weight)
elif(init == "glorot"):
glorot_uniform_init_(self.weight)
elif(init == "gating"):
gating_init_(self.weight)
if(bias):
with torch.no_grad():
self.bias.fill_(1.)
elif(init == "normal"):
normal_init_(self.weight)
elif(init == "final"):
final_init_(self.weight)
else:
raise ValueError("Invalid init string.")
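# Illustrative usage (not part of the original module): the init strings map
# onto the roles the layers play elsewhere in the model, e.g.
#   gate = Linear(c, c, init="gating")  # weights 0, bias 1: sigmoid gates start mostly open
#   out = Linear(c, c, init="final")    # weights 0, bias 0: residual branches start as a no-op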
class Attention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
initialization.
"""
def __init__(self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
"""
super(Attention, self).__init__()
self.c_q = c_q
self.c_k = c_k
self.c_v = c_v
self.c_hidden = c_hidden
self.no_heads = no_heads
self.gating = gating
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
# stated in the supplement, but the overall channel dimension
self.linear_q = Linear(
self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_k = Linear(
self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_v = Linear(
self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
self.linear_o = Linear(
self.c_hidden * self.no_heads, self.c_q, init="final"
)
if(self.gating):
self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating")
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Args:
q_x:
[*, Q, C_q] query data
k_x:
[*, K, C_k] key data
v_x:
[*, V, C_v] value data
biases:
Optional list of bias tensors broadcastable to [*, H, Q, K]
Returns:
[*, Q, C_q] attention update
"""
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
v = self.linear_v(v_x)
# [*, Q/K, H, C_hidden]
q = q.view(*q.shape[:-1], self.no_heads, -1)
k = k.view(*k.shape[:-1], self.no_heads, -1)
v = v.view(*v.shape[:-1], self.no_heads, -1)
# [*, H, Q, K]
a = torch.matmul(
permute_final_dims(q, 1, 0, 2), # [*, H, Q, C_hidden]
permute_final_dims(k, 1, 2, 0), # [*, H, C_hidden, K]
)
norm = 1 / math.sqrt(self.c_hidden) # [1]
a *= norm
if(biases is not None):
for b in biases:
a += b
a = self.softmax(a)
# [*, H, Q, C_hidden]
o = torch.matmul(
a,
permute_final_dims(v, 1, 0, 2), # [*, H, V, C_hidden]
)
# [*, Q, H, C_hidden]
o = o.transpose(-2, -3)
if(self.gating):
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(*g.shape[:-1], self.no_heads, -1)
o = o * g
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, Q, C_q]
o = self.linear_o(o)
return o
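if __name__ == "__main__":
# Illustrative smoke test (not part of the original file): checks that the
# attention update preserves the query shape. All dimensions are arbitrary.
batch_size = 2
n_q = 16
n_kv = 24
c_q = 32
c_kv = 48
c_hidden = 8
no_heads = 4
attn = Attention(c_q, c_kv, c_kv, c_hidden, no_heads)
q_x = torch.rand((batch_size, n_q, c_q))
kv_x = torch.rand((batch_size, n_kv, c_kv))
o = attn(q_x, kv_x, kv_x)
assert(o.shape == (batch_size, n_q, c_q))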
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
import torch
import torch.nn as nn
from alphafold.model.primitives import Linear, Attention
from alphafold.utils.deepspeed import checkpoint_blocks
from alphafold.model.dropout import (
DropoutRowwise,
DropoutColumnwise,
)
from alphafold.model.pair_transition import PairTransition
from alphafold.model.triangular_attention import (
TriangleAttentionStartingNode,
TriangleAttentionEndingNode,
)
from alphafold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
from alphafold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
class TemplatePointwiseAttention(nn.Module):
"""
Implements Algorithm 17.
"""
def __init__(self,
c_t,
c_z,
c_hidden,
no_heads,
chunk_size,
**kwargs
):
"""
Args:
c_t:
Template embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
no_heads:
Number of attention heads
chunk_size:
Size of subbatches used during inference-time chunking
"""
super(TemplatePointwiseAttention, self).__init__()
self.c_t = c_t
self.c_z = c_z
self.c_hidden = c_hidden
self.no_heads = no_heads
self.chunk_size = chunk_size
self.mha = Attention(
self.c_z, self.c_t, self.c_t,
self.c_hidden, self.no_heads,
gating=False,
)
def forward(self, t, z, template_mask=None):
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
z:
[*, N_res, N_res, C_z] pair embedding
template_mask:
[*, N_templ] template mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if(template_mask is None):
# NOTE: This is not the "template_mask" from the supplement, but a
# [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
# but not sure enough to remove it. It's nice to have, I guess.
template_mask = torch.ones(t.shape[:-3], device=t.device)
bias = (1e9 * (template_mask[..., None, None, None, None, :] - 1))
# [*, N_res, N_res, 1, C_z]
z = z.unsqueeze(-2)
# [*, N_res, N_res, N_temp, C_t]
t = permute_final_dims(t, 1, 2, 0, 3)
# [*, N_res, N_res, 1, C_z]
mha_inputs = {
"q_x": z,
"k_x": t,
"v_x": t,
"biases": [bias],
}
if(not self.training and self.chunk_size is not None):
z = chunk_layer(
self.mha,
mha_inputs,
chunk_size=self.chunk_size,
no_batch_dims=len(z.shape[:-2])
)
else:
z = self.mha(**mha_inputs)
# [*, N_res, N_res, C_z]
z = z.squeeze(-2)
return z
class TemplatePairStackBlock(nn.Module):
def __init__(self,
c_t,
c_hidden_tri_att,
c_hidden_tri_mul,
no_heads,
pair_transition_n,
dropout_rate,
chunk_size,
):
super(TemplatePairStackBlock, self).__init__()
self.c_t = c_t
self.c_hidden_tri_att = c_hidden_tri_att
self.c_hidden_tri_mul = c_hidden_tri_mul
self.no_heads = no_heads
self.pair_transition_n = pair_transition_n
self.dropout_rate = dropout_rate
self.chunk_size = chunk_size
self.dropout_row = DropoutRowwise(self.dropout_rate)
self.dropout_col = DropoutColumnwise(self.dropout_rate)
self.tri_att_start = TriangleAttentionStartingNode(
self.c_t,
self.c_hidden_tri_att,
self.no_heads,
chunk_size=chunk_size,
)
self.tri_att_end = TriangleAttentionEndingNode(
self.c_t,
self.c_hidden_tri_att,
self.no_heads,
chunk_size=chunk_size,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
self.c_t,
self.c_hidden_tri_mul,
)
self.pair_transition = PairTransition(
self.c_t,
self.pair_transition_n,
chunk_size=chunk_size,
)
def forward(self, z, mask, _mask_trans=True):
z = z + self.dropout_row(self.tri_att_start(z, mask=mask))
z = z + self.dropout_col(self.tri_att_end(z, mask=mask))
z = z + self.dropout_row(self.tri_mul_out(z, mask=mask))
z = z + self.dropout_row(self.tri_mul_in(z, mask=mask))
z = z + self.pair_transition(z, mask=mask if _mask_trans else None)
return z
class TemplatePairStack(nn.Module):
"""
Implements Algorithm 16.
"""
def __init__(self,
c_t,
c_hidden_tri_att,
c_hidden_tri_mul,
no_blocks,
no_heads,
pair_transition_n,
dropout_rate,
blocks_per_ckpt,
chunk_size,
**kwargs,
):
"""
Args:
c_t:
Template embedding channel dimension
c_hidden_tri_att:
Per-head hidden dimension for triangular attention
c_hidden_tri_mul:
Hidden dimension for triangular multiplication
no_blocks:
Number of blocks in the stack
no_heads:
Number of triangular attention heads
pair_transition_n:
Scale of pair transition (Alg. 15) hidden dimension
dropout_rate:
Dropout rate used throughout the stack
blocks_per_ckpt:
Number of blocks per activation checkpoint. None disables
activation checkpointing
chunk_size:
Size of subbatches. A higher value increases throughput at
the cost of memory
"""
super(TemplatePairStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.blocks = nn.ModuleList()
for i in range(no_blocks):
block = TemplatePairStackBlock(
c_t=c_t,
c_hidden_tri_att=c_hidden_tri_att,
c_hidden_tri_mul=c_hidden_tri_mul,
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
chunk_size=chunk_size,
)
self.blocks.append(block)
self.layer_norm = nn.LayerNorm(c_t)
def forward(self,
t: torch.Tensor,
mask: torch.Tensor,
_mask_trans: bool = True,
):
"""
Args:
t:
[*, N_res, N_res, C_t] template embedding
mask:
[*, N_res, N_res] mask
Returns:
[*, N_res, N_res, C_t] template embedding update
"""
t, = checkpoint_blocks(
blocks=[
partial(
b,
mask=mask,
_mask_trans=_mask_trans,
) for b in self.blocks
],
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt,
)
t = self.layer_norm(t)
return t
if __name__ == "__main__":
template_angle_dim = 51
c_m = 256
batch_size = 4
n_templ = 4
n_res = 256
tae = TemplateAngleEmbedder(
template_angle_dim,
c_m,
)
x = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
x = tae(x)
assert(x.shape == (batch_size, n_templ, n_res, c_m))
batch_size = 2
s_t = 4
c_t = 64
c_z = 128
c = 32
no_heads = 3
n = 100
tpa = TemplatePointwiseAttention(c_t, c_z, c, no_heads, chunk_size=None)
t = torch.rand((batch_size, s_t, n, n, c_t))
z = torch.rand((batch_size, n, n, c_z))
z_update = tpa(t, z)
assert(z_update.shape == z.shape)
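# Illustrative TemplatePairStack smoke test (not in the original file); the
# hyperparameter values below are arbitrary assumptions, kept small so the
# triangular attention tensors stay cheap.
n_small = 16
tps = TemplatePairStack(
c_t=c_t,
c_hidden_tri_att=16,
c_hidden_tri_mul=16,
no_blocks=2,
no_heads=4,
pair_transition_n=2,
dropout_rate=0.25,
blocks_per_ckpt=None,
chunk_size=None,
)
t_small = torch.rand((batch_size, s_t, n_small, n_small, c_t))
t_mask = torch.ones((batch_size, s_t, n_small, n_small))
t_small = tps(t_small, t_mask)
assert(t_small.shape == (batch_size, s_t, n_small, n_small, c_t))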
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod
import math
import torch
import torch.nn as nn
from alphafold.model.primitives import Linear, Attention
from alphafold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
class TriangleAttention(nn.Module):
def __init__(self,
c_in,
c_hidden,
no_heads,
starting,
chunk_size=4,
inf=1e9
):
"""
Args:
c_in:
Input channel dimension
c_hidden:
Overall hidden channel dimension (not per-head)
no_heads:
Number of attention heads
starting:
True for Algorithm 13 (starting node), False for Algorithm 14 (ending node)
chunk_size:
Size of subbatches used during inference-time chunking
inf:
Large value used to mask out attention logits
"""
super(TriangleAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.starting = starting
self.chunk_size = chunk_size
self.inf = inf
self.layer_norm = nn.LayerNorm(self.c_in)
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
self.mha = Attention(
self.c_in, self.c_in, self.c_in,
self.c_hidden,
self.no_heads
)
def forward(self, x, mask=None):
"""
Args:
x:
[*, I, J, C_in] input tensor (e.g. the pair representation)
Returns:
[*, I, J, C_in] output tensor
"""
if(mask is None):
# [*, I, J]
mask = torch.ones(
x.shape[:-1],
device=x.device,
requires_grad=False,
)
# Shape annotations assume self.starting. Else, I and J are flipped
if(not self.starting):
x = x.transpose(-2, -3)
mask = mask.transpose(-1, -2)
# [*, I, J, C_in]
x = self.layer_norm(x)
# [*, I, 1, 1, J]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# [*, H, I, J]
triangle_bias = permute_final_dims(self.linear(x), 2, 0, 1)
# [*, 1, H, I, J]
triangle_bias = triangle_bias.unsqueeze(-4)
# Broadcasting and chunking don't really work together yet (TODO)
# [*, I, H, I, J]
i = x.shape[-3]
triangle_bias = triangle_bias.expand(
(*((-1,) * len(triangle_bias.shape[:-4])), i, -1, -1, -1)
)
mha_inputs = {
"q_x": x,
"k_x": x,
"v_x": x,
"biases": [mask_bias, triangle_bias],
}
if(not self.training and self.chunk_size is not None):
x = chunk_layer(
self.mha,
mha_inputs,
chunk_size=self.chunk_size,
no_batch_dims=len(x.shape[:-2])
)
else:
x = self.mha(**mha_inputs)
if(not self.starting):
x = x.transpose(-2, -3)
return x
class TriangleAttentionStartingNode(TriangleAttention):
"""
Implements Algorithm 13.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
class TriangleAttentionEndingNode(TriangleAttention):
"""
Implements Algorithm 14.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
if __name__ == "__main__":
c_in = 256
c = 32
no_heads = 4
starting = True
tan = TriangleAttention(
c_in,
c,
no_heads,
starting
)
batch_size = 16
n_res = 256
x = torch.rand((batch_size, n_res, n_res, c_in))
shape_before = x.shape
x = tan(x)
shape_after = x.shape
assert(shape_before == shape_after)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod
import torch
import torch.nn as nn
from alphafold.model.primitives import Linear
from alphafold.utils.tensor_utils import permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
"""
Args:
c_z:
Input channel dimension
c_hidden:
Hidden channel dimension
_outgoing:
True for Algorithm 11 ("outgoing" edges), False for Algorithm 12
"""
super(TriangleMultiplicativeUpdate, self).__init__()
self.c_z = c_z
self.c_hidden = c_hidden
self._outgoing = _outgoing
self.linear_a_p = Linear(self.c_z, self.c_hidden)
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_b_p = Linear(self.c_z, self.c_hidden)
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
self.layer_norm_in = nn.LayerNorm(self.c_z)
self.layer_norm_out = nn.LayerNorm(self.c_hidden)
self.sigmoid = nn.Sigmoid()
cp = self._outgoing_matmul if self._outgoing else self._incoming_matmul
self.combine_projections = cp
def _outgoing_matmul(self,
a: torch.Tensor, # [*, N_i, N_k, C]
b: torch.Tensor, # [*, N_j, N_k, C]
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, 2, 0, 1),
permute_final_dims(b, 2, 1, 0),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, 1, 2, 0)
def _incoming_matmul(self,
a: torch.Tensor, # [*, N_k, N_i, C]
b: torch.Tensor, # [*, N_k, N_j, C]
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, 2, 1, 0),
permute_final_dims(b, 2, 0, 1),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, 1, 2, 0)
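# Note (added comment): the two helpers above are equivalent to
#   torch.einsum("...ikc,...jkc->...ijc", a, b)  (outgoing, Algorithm 11)
#   torch.einsum("...kic,...kjc->...ijc", a, b)  (incoming, Algorithm 12)
# The permute/matmul form simply makes the intermediate shapes explicit.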
def forward(self, z, mask=None):
"""
Args:
z:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if(mask is None):
mask = z.new_ones(z.shape[:-1], requires_grad=False)
mask = mask.unsqueeze(-1)
z = self.layer_norm_in(z)
a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
a = a * mask
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
b = b * mask
x = self.combine_projections(a, b)
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.sigmoid(self.linear_g(z))
z = x * g
return z
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 11.
"""
__init__ = partialmethod(
TriangleMultiplicativeUpdate.__init__, _outgoing=True,
)
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 12.
"""
__init__ = partialmethod(
TriangleMultiplicativeUpdate.__init__, _outgoing=False,
)
if __name__ == "__main__":
c_in = 256 # doubled to make shape changes more apparent
c = 128
outgoing = True
tm = TriangleMultiplicativeUpdate(
c_in,
c,
_outgoing=outgoing,
)
n_res = 300
batch_size = 16
x = torch.rand((batch_size, n_res, n_res, c_in))
shape_before = x.shape
x = tm(x)
shape_after = x.shape
assert(shape_before == shape_after)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
from alphafold.np import residue_constants
from Bio.PDB import PDBParser
import numpy as np
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
Args:
pdb_str: The contents of the pdb file
chain_id: If None, then the pdb file must contain a single chain (which
will be parsed). If chain_id is specified (e.g. A), then only that chain
is parsed.
Returns:
A new `Protein` parsed from the pdb contents.
"""
pdb_fh = io.StringIO(pdb_str)
parser = PDBParser(QUIET=True)
structure = parser.get_structure('none', pdb_fh)
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(
f'Only single model PDBs are supported. Found {len(models)} models.')
model = models[0]
if chain_id is not None:
chain = model[chain_id]
else:
chains = list(model.get_chains())
if len(chains) != 1:
raise ValueError(
'Only single chain PDBs are supported when chain_id not specified. '
f'Found {len(chains)} chains.')
else:
chain = chains[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
b_factors = []
for res in chain:
if res.id[2] != ' ':
raise ValueError(
f'PDB contains an insertion code at chain {chain.id} and residue '
f'index {res.id[1]}. These are not supported.')
res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
b_factors.append(res_b_factors)
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
b_factors=np.array(b_factors))
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ['X']
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError('Invalid aatypes.')
pdb_lines.append('MODEL     1')
atom_index = 1
chain_id = 'A'
# Add all atom sites.
for i in range(aatype.shape[0]):
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
if mask < 0.5:
continue
record_type = 'ATOM'
name = atom_name if len(atom_name) == 4 else f' {atom_name}'
alt_loc = ''
insertion_code = ''
occupancy = 1.00
element = atom_name[0] # Protein supports only C, N, O, S, this works.
charge = ''
# PDB is a columnar format, every space matters here!
atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
f'{res_name_3:>3} {chain_id:>1}'
f'{residue_index[i]:>4}{insertion_code:>1}   '
f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
f'{occupancy:>6.2f}{b_factor:>6.2f}          '
f'{element:>2}{charge:>2}')
pdb_lines.append(atom_line)
atom_index += 1
# Close the chain.
chain_end = 'TER'
chain_termination_line = (
f'{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[-1]):>3} '
f'{chain_id:>1}{residue_index[-1]:>4}')
pdb_lines.append(chain_termination_line)
pdb_lines.append('ENDMDL')
pdb_lines.append('END')
pdb_lines.append('')
return '\n'.join(pdb_lines)
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(features: FeatureDict, result: ModelOutput,
b_factors: Optional[np.ndarray] = None) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
Returns:
A protein instance.
"""
if b_factors is None:
b_factors = np.zeros_like(result['final_atom_mask'])
return Protein(
aatype=features['aatype'],
atom_positions=result['final_atom_positions'],
atom_mask=result['final_atom_mask'],
residue_index=features['residue_index'] + 1,
b_factors=b_factors
)
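if __name__ == "__main__":
# Illustrative round-trip sketch (not part of the original module): parse a
# single-chain PDB file and write it back out. "example.pdb" is a placeholder path.
with open("example.pdb", "r") as fp:
pdb_str = fp.read()
prot = from_pdb_string(pdb_str)
print(to_pdb(prot))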