ModelZoo / alphafold2_jax, commit 2f0d89e7
Authored Aug 24, 2023 by zhuwenwen

    remove duplicated code

Parent: a1597f3f
Changes: 103
Showing 20 changed files with 0 additions and 3076 deletions (+0 -3076):

    alphafold/model/r3.py                          +0 -320
    alphafold/model/tf/__init__.py                 +0 -14
    alphafold/model/tf/data_transforms.py          +0 -625
    alphafold/model/tf/input_pipeline.py           +0 -166
    alphafold/model/tf/protein_features.py         +0 -129
    alphafold/model/tf/protein_features_test.py    +0 -51
    alphafold/model/tf/proteins_dataset.py         +0 -166
    alphafold/model/tf/shape_helpers.py            +0 -47
    alphafold/model/tf/shape_helpers_test.py       +0 -39
    alphafold/model/tf/shape_placeholders.py       +0 -20
    alphafold/model/tf/utils.py                    +0 -47
    alphafold/model/utils.py                       +0 -131
    alphafold/notebooks/__init__.py                +0 -14
    alphafold/notebooks/notebook_utils.py          +0 -182
    alphafold/notebooks/notebook_utils_test.py     +0 -203
    alphafold/relax/__init__.py                    +0 -14
    alphafold/relax/amber_minimize.py              +0 -511
    alphafold/relax/amber_minimize_test.py         +0 -133
    alphafold/relax/cleanup.py                     +0 -127
    alphafold/relax/cleanup_test.py                +0 -137
alphafold/model/r3.py  (deleted, mode 100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformations for 3D coordinates.
This Module contains objects for representing Vectors (Vecs), Rotation Matrices
(Rots) and proper Rigid transformation (Rigids). These are represented as
named tuples with arrays for each entry, for example a set of
[N, M] points would be represented as a Vecs object with arrays of shape [N, M]
for x, y and z.
This is being done to improve readability by making it very clear what objects
are geometric objects rather than relying on comments and array shapes.
Another reason for this is to avoid using matrix
multiplication primitives like matmul or einsum, on modern accelerator hardware
these can end up on specialized cores such as tensor cores on GPU or the MXU on
cloud TPUs, this often involves lower computational precision which can be
problematic for coordinate geometry. Also these cores are typically optimized
for larger matrices than 3 dimensional, this code is written to avoid any
unintended use of these cores on both GPUs and TPUs.
"""
import collections
from typing import List

from alphafold.model import quat_affine
import jax.numpy as jnp
import tree

# Array of 3-component vectors, stored as individual array for
# each component.
Vecs = collections.namedtuple('Vecs', ['x', 'y', 'z'])

# Array of 3x3 rotation matrices, stored as individual array for
# each component.
Rots = collections.namedtuple('Rots', ['xx', 'xy', 'xz',
                                       'yx', 'yy', 'yz',
                                       'zx', 'zy', 'zz'])

# Array of rigid 3D transformations, stored as array of rotations and
# array of translations.
Rigids = collections.namedtuple('Rigids', ['rot', 'trans'])


def squared_difference(x, y):
  return jnp.square(x - y)


def invert_rigids(r: Rigids) -> Rigids:
  """Computes group inverse of rigid transformations 'r'."""
  inv_rots = invert_rots(r.rot)
  t = rots_mul_vecs(inv_rots, r.trans)
  inv_trans = Vecs(-t.x, -t.y, -t.z)
  return Rigids(inv_rots, inv_trans)


def invert_rots(m: Rots) -> Rots:
  """Computes inverse of rotations 'm'."""
  return Rots(m.xx, m.yx, m.zx,
              m.xy, m.yy, m.zy,
              m.xz, m.yz, m.zz)


def rigids_from_3_points(
    point_on_neg_x_axis: Vecs,  # shape (...)
    origin: Vecs,  # shape (...)
    point_on_xy_plane: Vecs,  # shape (...)
) -> Rigids:  # shape (...)
  """Create Rigids from 3 points.

  Jumper et al. (2021) Suppl. Alg. 21 "rigidFrom3Points"
  This creates a set of rigid transformations from 3 points by Gram Schmidt
  orthogonalization.

  Args:
    point_on_neg_x_axis: Vecs corresponding to points on the negative x axis
    origin: Origin of resulting rigid transformations
    point_on_xy_plane: Vecs corresponding to points in the xy plane
  Returns:
    Rigid transformations from global frame to local frames derived from
    the input points.
  """
  m = rots_from_two_vecs(
      e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis),
      e1_unnormalized=vecs_sub(point_on_xy_plane, origin))

  return Rigids(rot=m, trans=origin)


def rigids_from_list(l: List[jnp.ndarray]) -> Rigids:
  """Converts flat list of arrays to rigid transformations."""
  assert len(l) == 12
  return Rigids(Rots(*(l[:9])), Vecs(*(l[9:])))


def rigids_from_quataffine(a: quat_affine.QuatAffine) -> Rigids:
  """Converts QuatAffine object to the corresponding Rigids object."""
  return Rigids(Rots(*tree.flatten(a.rotation)),
                Vecs(*a.translation))


def rigids_from_tensor4x4(
    m: jnp.ndarray  # shape (..., 4, 4)
) -> Rigids:  # shape (...)
  """Construct Rigids object from an 4x4 array.

  Here the 4x4 is representing the transformation in homogeneous coordinates.

  Args:
    m: Array representing transformations in homogeneous coordinates.
  Returns:
    Rigids object corresponding to transformations m
  """
  assert m.shape[-1] == 4
  assert m.shape[-2] == 4
  return Rigids(
      Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2],
           m[..., 1, 0], m[..., 1, 1], m[..., 1, 2],
           m[..., 2, 0], m[..., 2, 1], m[..., 2, 2]),
      Vecs(m[..., 0, 3], m[..., 1, 3], m[..., 2, 3]))


def rigids_from_tensor_flat9(
    m: jnp.ndarray  # shape (..., 9)
) -> Rigids:  # shape (...)
  """Flat9 encoding: first two columns of rotation matrix + translation."""
  assert m.shape[-1] == 9
  e0 = Vecs(m[..., 0], m[..., 1], m[..., 2])
  e1 = Vecs(m[..., 3], m[..., 4], m[..., 5])
  trans = Vecs(m[..., 6], m[..., 7], m[..., 8])
  return Rigids(rot=rots_from_two_vecs(e0, e1),
                trans=trans)


def rigids_from_tensor_flat12(
    m: jnp.ndarray  # shape (..., 12)
) -> Rigids:  # shape (...)
  """Flat12 encoding: rotation matrix (9 floats) + translation (3 floats)."""
  assert m.shape[-1] == 12
  x = jnp.moveaxis(m, -1, 0)  # Unstack
  return Rigids(Rots(*x[:9]), Vecs(*x[9:]))


def rigids_mul_rigids(a: Rigids, b: Rigids) -> Rigids:
  """Group composition of Rigids 'a' and 'b'."""
  return Rigids(
      rots_mul_rots(a.rot, b.rot),
      vecs_add(a.trans, rots_mul_vecs(a.rot, b.trans)))


def rigids_mul_rots(r: Rigids, m: Rots) -> Rigids:
  """Compose rigid transformations 'r' with rotations 'm'."""
  return Rigids(rots_mul_rots(r.rot, m), r.trans)


def rigids_mul_vecs(r: Rigids, v: Vecs) -> Vecs:
  """Apply rigid transforms 'r' to points 'v'."""
  return vecs_add(rots_mul_vecs(r.rot, v), r.trans)


def rigids_to_list(r: Rigids) -> List[jnp.ndarray]:
  """Turn Rigids into flat list, inverse of 'rigids_from_list'."""
  return list(r.rot) + list(r.trans)


def rigids_to_quataffine(r: Rigids) -> quat_affine.QuatAffine:
  """Convert Rigids r into QuatAffine, inverse of 'rigids_from_quataffine'."""
  return quat_affine.QuatAffine(
      quaternion=None,
      rotation=[[r.rot.xx, r.rot.xy, r.rot.xz],
                [r.rot.yx, r.rot.yy, r.rot.yz],
                [r.rot.zx, r.rot.zy, r.rot.zz]],
      translation=[r.trans.x, r.trans.y, r.trans.z])


def rigids_to_tensor_flat9(
    r: Rigids  # shape (...)
) -> jnp.ndarray:  # shape (..., 9)
  """Flat9 encoding: first two columns of rotation matrix + translation."""
  return jnp.stack(
      [r.rot.xx, r.rot.yx, r.rot.zx, r.rot.xy, r.rot.yy, r.rot.zy]
      + list(r.trans), axis=-1)


def rigids_to_tensor_flat12(
    r: Rigids  # shape (...)
) -> jnp.ndarray:  # shape (..., 12)
  """Flat12 encoding: rotation matrix (9 floats) + translation (3 floats)."""
  return jnp.stack(list(r.rot) + list(r.trans), axis=-1)


def rots_from_tensor3x3(
    m: jnp.ndarray,  # shape (..., 3, 3)
) -> Rots:  # shape (...)
  """Convert rotations represented as (3, 3) array to Rots."""
  assert m.shape[-1] == 3
  assert m.shape[-2] == 3
  return Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2],
              m[..., 1, 0], m[..., 1, 1], m[..., 1, 2],
              m[..., 2, 0], m[..., 2, 1], m[..., 2, 2])


def rots_from_two_vecs(e0_unnormalized: Vecs, e1_unnormalized: Vecs) -> Rots:
  """Create rotation matrices from unnormalized vectors for the x and y-axes.

  This creates a rotation matrix from two vectors using Gram-Schmidt
  orthogonalization.

  Args:
    e0_unnormalized: vectors lying along x-axis of resulting rotation
    e1_unnormalized: vectors lying in xy-plane of resulting rotation
  Returns:
    Rotations resulting from Gram-Schmidt procedure.
  """
  # Normalize the unit vector for the x-axis, e0.
  e0 = vecs_robust_normalize(e0_unnormalized)

  # make e1 perpendicular to e0.
  c = vecs_dot_vecs(e1_unnormalized, e0)
  e1 = Vecs(e1_unnormalized.x - c * e0.x,
            e1_unnormalized.y - c * e0.y,
            e1_unnormalized.z - c * e0.z)
  e1 = vecs_robust_normalize(e1)

  # Compute e2 as cross product of e0 and e1.
  e2 = vecs_cross_vecs(e0, e1)

  return Rots(e0.x, e1.x, e2.x,
              e0.y, e1.y, e2.y,
              e0.z, e1.z, e2.z)


def rots_mul_rots(a: Rots, b: Rots) -> Rots:
  """Composition of rotations 'a' and 'b'."""
  c0 = rots_mul_vecs(a, Vecs(b.xx, b.yx, b.zx))
  c1 = rots_mul_vecs(a, Vecs(b.xy, b.yy, b.zy))
  c2 = rots_mul_vecs(a, Vecs(b.xz, b.yz, b.zz))
  return Rots(c0.x, c1.x, c2.x,
              c0.y, c1.y, c2.y,
              c0.z, c1.z, c2.z)


def rots_mul_vecs(m: Rots, v: Vecs) -> Vecs:
  """Apply rotations 'm' to vectors 'v'."""
  return Vecs(m.xx * v.x + m.xy * v.y + m.xz * v.z,
              m.yx * v.x + m.yy * v.y + m.yz * v.z,
              m.zx * v.x + m.zy * v.y + m.zz * v.z)


def vecs_add(v1: Vecs, v2: Vecs) -> Vecs:
  """Add two vectors 'v1' and 'v2'."""
  return Vecs(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z)


def vecs_dot_vecs(v1: Vecs, v2: Vecs) -> jnp.ndarray:
  """Dot product of vectors 'v1' and 'v2'."""
  return v1.x * v2.x + v1.y * v2.y + v1.z * v2.z


def vecs_cross_vecs(v1: Vecs, v2: Vecs) -> Vecs:
  """Cross product of vectors 'v1' and 'v2'."""
  return Vecs(v1.y * v2.z - v1.z * v2.y,
              v1.z * v2.x - v1.x * v2.z,
              v1.x * v2.y - v1.y * v2.x)


def vecs_from_tensor(
    x: jnp.ndarray  # shape (..., 3)
) -> Vecs:  # shape (...)
  """Converts from tensor of shape (3,) to Vecs."""
  num_components = x.shape[-1]
  assert num_components == 3
  return Vecs(x[..., 0], x[..., 1], x[..., 2])


def vecs_robust_normalize(v: Vecs, epsilon: float = 1e-8) -> Vecs:
  """Normalizes vectors 'v'.

  Args:
    v: vectors to be normalized.
    epsilon: small regularizer added to squared norm before taking square root.
  Returns:
    normalized vectors
  """
  norms = vecs_robust_norm(v, epsilon)
  return Vecs(v.x / norms, v.y / norms, v.z / norms)


def vecs_robust_norm(v: Vecs, epsilon: float = 1e-8) -> jnp.ndarray:
  """Computes norm of vectors 'v'.

  Args:
    v: vectors to be normalized.
    epsilon: small regularizer added to squared norm before taking square root.
  Returns:
    norm of 'v'
  """
  return jnp.sqrt(jnp.square(v.x) + jnp.square(v.y) + jnp.square(v.z)
                  + epsilon)


def vecs_sub(v1: Vecs, v2: Vecs) -> Vecs:
  """Computes v1 - v2."""
  return Vecs(v1.x - v2.x, v1.y - v2.y, v1.z - v2.z)


def vecs_squared_distance(v1: Vecs, v2: Vecs) -> jnp.ndarray:
  """Computes squared euclidean difference between 'v1' and 'v2'."""
  return (squared_difference(v1.x, v2.x) +
          squared_difference(v1.y, v2.y) +
          squared_difference(v1.z, v2.z))


def vecs_to_tensor(
    v: Vecs  # shape (...)
) -> jnp.ndarray:  # shape(..., 3)
  """Converts 'v' to tensor with shape 3, inverse of 'vecs_from_tensor'."""
  return jnp.stack([v.x, v.y, v.z], axis=-1)
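
To make the conventions above concrete, a short usage sketch (not part of the diff; the coordinate values and variable names are illustrative only):

# Sketch: build a frame from three points per Suppl. Alg. 21, then map a
# local point into the global frame. Every step is component-wise arithmetic
# on the named-tuple fields; matmul/einsum are never invoked.
import jax.numpy as jnp

n = Vecs(jnp.array(-1.46), jnp.array(0.0), jnp.array(0.0))   # e.g. backbone N
ca = Vecs(jnp.array(0.0), jnp.array(0.0), jnp.array(0.0))    # e.g. backbone CA
c = Vecs(jnp.array(0.53), jnp.array(1.42), jnp.array(0.0))   # e.g. backbone C

frames = rigids_from_3_points(point_on_neg_x_axis=n, origin=ca,
                              point_on_xy_plane=c)
local_point = Vecs(jnp.array(1.0), jnp.array(0.0), jnp.array(0.0))
global_point = rigids_mul_vecs(frames, local_point)
# invert_rigids gives the group inverse, so this recovers local_point:
round_trip = rigids_mul_vecs(invert_rigids(frames), global_point)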
alphafold/model/tf/__init__.py  (deleted, mode 100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alphafold model TensorFlow code."""
alphafold/model/tf/data_transforms.py  (deleted, mode 100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data for AlphaFold."""
from alphafold.common import residue_constants
from alphafold.model.tf import shape_helpers
from alphafold.model.tf import shape_placeholders
from alphafold.model.tf import utils
import numpy as np
import tensorflow.compat.v1 as tf

# Pylint gets confused by the curry1 decorator because it changes the number
# of arguments to the function.
# pylint:disable=no-value-for-parameter

NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES


def cast_64bit_ints(protein):
  for k, v in protein.items():
    if v.dtype == tf.int64:
      protein[k] = tf.cast(v, tf.int32)
  return protein


_MSA_FEATURE_NAMES = ['msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask',
                      'bert_mask', 'true_msa']


def make_seq_mask(protein):
  protein['seq_mask'] = tf.ones(shape_helpers.shape_list(protein['aatype']),
                                dtype=tf.float32)
  return protein


def make_template_mask(protein):
  protein['template_mask'] = tf.ones(
      shape_helpers.shape_list(protein['template_domain_names']),
      dtype=tf.float32)
  return protein


def curry1(f):
  """Supply all arguments but the first."""
  def fc(*args, **kwargs):
    return lambda x: f(x, *args, **kwargs)
  return fc


@curry1
def add_distillation_flag(protein, distillation):
  protein['is_distillation'] = tf.constant(float(distillation),
                                           shape=[],
                                           dtype=tf.float32)
  return protein


def make_all_atom_aatype(protein):
  protein['all_atom_aatype'] = protein['aatype']
  return protein


def fix_templates_aatype(protein):
  """Fixes aatype encoding of templates."""
  # Map one-hot to indices.
  protein['template_aatype'] = tf.argmax(
      protein['template_aatype'], output_type=tf.int32, axis=-1)
  # Map hhsearch-aatype to our aatype.
  new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
  new_order = tf.constant(new_order_list, dtype=tf.int32)
  protein['template_aatype'] = tf.gather(params=new_order,
                                         indices=protein['template_aatype'])
  return protein


def correct_msa_restypes(protein):
  """Correct MSA restype to have the same order as residue_constants."""
  new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
  new_order = tf.constant(new_order_list, dtype=protein['msa'].dtype)
  protein['msa'] = tf.gather(new_order, protein['msa'], axis=0)

  perm_matrix = np.zeros((22, 22), dtype=np.float32)
  perm_matrix[range(len(new_order_list)), new_order_list] = 1.

  for k in protein:
    if 'profile' in k:  # Include both hhblits and psiblast profiles
      num_dim = protein[k].shape.as_list()[-1]
      assert num_dim in [20, 21, 22], (
          'num_dim for %s out of expected range: %s' % (k, num_dim))
      protein[k] = tf.tensordot(protein[k], perm_matrix[:num_dim, :num_dim], 1)
  return protein


def squeeze_features(protein):
  """Remove singleton and repeated dimensions in protein features."""
  protein['aatype'] = tf.argmax(
      protein['aatype'], axis=-1, output_type=tf.int32)
  for k in [
      'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
      'superfamily', 'deletion_matrix', 'resolution',
      'between_segment_residues', 'residue_index', 'template_all_atom_masks']:
    if k in protein:
      final_dim = shape_helpers.shape_list(protein[k])[-1]
      if isinstance(final_dim, int) and final_dim == 1:
        protein[k] = tf.squeeze(protein[k], axis=-1)

  for k in ['seq_length', 'num_alignments']:
    if k in protein:
      protein[k] = protein[k][0]  # Remove fake sequence dimension
  return protein


def make_random_crop_to_size_seed(protein):
  """Random seed for cropping residues and templates."""
  protein['random_crop_to_size_seed'] = utils.make_random_seed()
  return protein


@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
  """Replace a proportion of the MSA with 'X'."""
  msa_mask = (tf.random.uniform(shape_helpers.shape_list(protein['msa'])) <
              replace_proportion)
  x_idx = 20
  gap_idx = 21
  msa_mask = tf.logical_and(msa_mask, protein['msa'] != gap_idx)
  protein['msa'] = tf.where(msa_mask,
                            tf.ones_like(protein['msa']) * x_idx,
                            protein['msa'])
  aatype_mask = (
      tf.random.uniform(shape_helpers.shape_list(protein['aatype'])) <
      replace_proportion)

  protein['aatype'] = tf.where(aatype_mask,
                               tf.ones_like(protein['aatype']) * x_idx,
                               protein['aatype'])
  return protein


@curry1
def sample_msa(protein, max_seq, keep_extra):
  """Sample MSA randomly, remaining sequences are stored as `extra_*`.

  Args:
    protein: batch to sample msa from.
    max_seq: number of sequences to sample.
    keep_extra: When True sequences not sampled are put into fields starting
      with 'extra_*'.

  Returns:
    Protein with sampled msa.
  """
  num_seq = tf.shape(protein['msa'])[0]
  shuffled = tf.random_shuffle(tf.range(1, num_seq))
  index_order = tf.concat([[0], shuffled], axis=0)
  num_sel = tf.minimum(max_seq, num_seq)

  sel_seq, not_sel_seq = tf.split(index_order, [num_sel, num_seq - num_sel])

  for k in _MSA_FEATURE_NAMES:
    if k in protein:
      if keep_extra:
        protein['extra_' + k] = tf.gather(protein[k], not_sel_seq)
      protein[k] = tf.gather(protein[k], sel_seq)

  return protein


@curry1
def crop_extra_msa(protein, max_extra_msa):
  """MSA features are cropped so only `max_extra_msa` sequences are kept."""
  num_seq = tf.shape(protein['extra_msa'])[0]
  num_sel = tf.minimum(max_extra_msa, num_seq)
  select_indices = tf.random_shuffle(tf.range(0, num_seq))[:num_sel]
  for k in _MSA_FEATURE_NAMES:
    if 'extra_' + k in protein:
      protein['extra_' + k] = tf.gather(protein['extra_' + k], select_indices)

  return protein


def delete_extra_msa(protein):
  for k in _MSA_FEATURE_NAMES:
    if 'extra_' + k in protein:
      del protein['extra_' + k]
  return protein


@curry1
def block_delete_msa(protein, config):
  """Sample MSA by deleting contiguous blocks.

  Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"

  Arguments:
    protein: batch dict containing the msa
    config: ConfigDict with parameters

  Returns:
    updated protein
  """
  num_seq = shape_helpers.shape_list(protein['msa'])[0]
  block_num_seq = tf.cast(
      tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block),
      tf.int32)

  if config.randomize_num_blocks:
    nb = tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32)
  else:
    nb = config.num_blocks

  del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32)
  del_blocks = del_block_starts[:, None] + tf.range(block_num_seq)
  del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1)
  del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0]

  # Make sure we keep the original sequence
  sparse_diff = tf.sets.difference(tf.range(1, num_seq)[None],
                                   del_indices[None])
  keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
  keep_indices = tf.concat([[0], keep_indices], axis=0)

  for k in _MSA_FEATURE_NAMES:
    if k in protein:
      protein[k] = tf.gather(protein[k], keep_indices)

  return protein


@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
  """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""

  # Determine how much weight we assign to each agreement. In theory, we could
  # use a full blosum matrix here, but right now let's just down-weight gap
  # agreement because it could be spurious.
  # Never put weight on agreeing on BERT mask
  weights = tf.concat([
      tf.ones(21),
      gap_agreement_weight * tf.ones(1),
      np.zeros(1)], 0)

  # Make agreement score as weighted Hamming distance
  sample_one_hot = (protein['msa_mask'][:, :, None] *
                    tf.one_hot(protein['msa'], 23))
  extra_one_hot = (protein['extra_msa_mask'][:, :, None] *
                   tf.one_hot(protein['extra_msa'], 23))

  num_seq, num_res, _ = shape_helpers.shape_list(sample_one_hot)
  extra_num_seq, _, _ = shape_helpers.shape_list(extra_one_hot)

  # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
  # in an optimized fashion to avoid possible memory or computation blowup.
  agreement = tf.matmul(
      tf.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
      tf.reshape(sample_one_hot * weights, [num_seq, num_res * 23]),
      transpose_b=True)

  # Assign each sequence in the extra sequences to the closest MSA sample
  protein['extra_cluster_assignment'] = tf.argmax(
      agreement, axis=1, output_type=tf.int32)

  return protein


@curry1
def summarize_clusters(protein):
  """Produce profile and deletion_matrix_mean within each cluster."""
  num_seq = shape_helpers.shape_list(protein['msa'])[0]

  def csum(x):
    return tf.math.unsorted_segment_sum(
        x, protein['extra_cluster_assignment'], num_seq)

  mask = protein['extra_msa_mask']
  mask_counts = 1e-6 + protein['msa_mask'] + csum(mask)  # Include center

  msa_sum = csum(mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23))
  msa_sum += tf.one_hot(protein['msa'], 23)  # Original sequence
  protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]

  del msa_sum

  del_sum = csum(mask * protein['extra_deletion_matrix'])
  del_sum += protein['deletion_matrix']  # Original sequence
  protein['cluster_deletion_mean'] = del_sum / mask_counts

  del del_sum

  return protein


def make_msa_mask(protein):
  """Mask features are all ones, but will later be zero-padded."""
  protein['msa_mask'] = tf.ones(
      shape_helpers.shape_list(protein['msa']), dtype=tf.float32)
  protein['msa_row_mask'] = tf.ones(
      shape_helpers.shape_list(protein['msa'])[0], dtype=tf.float32)
  return protein


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
  """Create pseudo beta features."""
  is_gly = tf.equal(aatype, residue_constants.restype_order['G'])
  ca_idx = residue_constants.atom_order['CA']
  cb_idx = residue_constants.atom_order['CB']
  pseudo_beta = tf.where(
      tf.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
      all_atom_positions[..., ca_idx, :],
      all_atom_positions[..., cb_idx, :])

  if all_atom_masks is not None:
    pseudo_beta_mask = tf.where(is_gly,
                                all_atom_masks[..., ca_idx],
                                all_atom_masks[..., cb_idx])
    pseudo_beta_mask = tf.cast(pseudo_beta_mask, tf.float32)
    return pseudo_beta, pseudo_beta_mask
  else:
    return pseudo_beta


@curry1
def make_pseudo_beta(protein, prefix=''):
  """Create pseudo-beta (alpha for glycine) position and mask."""
  assert prefix in ['', 'template_']
  protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = (
      pseudo_beta_fn(
          protein['template_aatype' if prefix else 'all_atom_aatype'],
          protein[prefix + 'all_atom_positions'],
          protein['template_all_atom_masks' if prefix else 'all_atom_mask']))
  return protein


@curry1
def add_constant_field(protein, key, value):
  protein[key] = tf.convert_to_tensor(value)
  return protein


def shaped_categorical(probs, epsilon=1e-10):
  ds = shape_helpers.shape_list(probs)
  num_classes = ds[-1]
  counts = tf.random.categorical(
      tf.reshape(tf.log(probs + epsilon), [-1, num_classes]),
      1, dtype=tf.int32)
  return tf.reshape(counts, ds[:-1])


def make_hhblits_profile(protein):
  """Compute the HHblits MSA profile if not already present."""
  if 'hhblits_profile' in protein:
    return protein

  # Compute the profile for every residue (over all MSA sequences).
  protein['hhblits_profile'] = tf.reduce_mean(
      tf.one_hot(protein['msa'], 22), axis=0)
  return protein


@curry1
def make_masked_msa(protein, config, replace_fraction):
  """Create data for BERT on raw MSA."""
  # Add a random amino acid uniformly
  random_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32)

  categorical_probs = (
      config.uniform_prob * random_aa +
      config.profile_prob * protein['hhblits_profile'] +
      config.same_prob * tf.one_hot(protein['msa'], 22))

  # Put all remaining probability on [MASK] which is a new column
  pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
  pad_shapes[-1][1] = 1
  mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
  assert mask_prob >= 0.
  categorical_probs = tf.pad(
      categorical_probs, pad_shapes, constant_values=mask_prob)

  sh = shape_helpers.shape_list(protein['msa'])
  mask_position = tf.random.uniform(sh) < replace_fraction

  bert_msa = shaped_categorical(categorical_probs)
  bert_msa = tf.where(mask_position, bert_msa, protein['msa'])

  # Mix real and masked MSA
  protein['bert_mask'] = tf.cast(mask_position, tf.float32)
  protein['true_msa'] = protein['msa']
  protein['msa'] = bert_msa

  return protein


@curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
                    num_res, num_templates=0):
  """Guess at the MSA and sequence dimensions to make fixed size."""

  pad_size_map = {
      NUM_RES: num_res,
      NUM_MSA_SEQ: msa_cluster_size,
      NUM_EXTRA_SEQ: extra_msa_size,
      NUM_TEMPLATES: num_templates,
  }

  for k, v in protein.items():
    # Don't transfer this to the accelerator.
    if k == 'extra_cluster_assignment':
      continue
    shape = v.shape.as_list()
    schema = shape_schema[k]
    assert len(shape) == len(schema), (
        f'Rank mismatch between shape and shape schema for {k}: '
        f'{shape} vs {schema}')
    pad_size = [
        pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
    ]
    padding = [(0, p - tf.shape(v)[i]) for i, p in enumerate(pad_size)]
    if padding:
      protein[k] = tf.pad(
          v, padding, name=f'pad_to_fixed_{k}')
      protein[k].set_shape(pad_size)

  return protein


@curry1
def make_msa_feat(protein):
  """Create and concatenate MSA features."""
  # Whether there is a domain break. Always zero for chains, but keeping
  # for compatibility with domain datasets.
  has_break = tf.clip_by_value(
      tf.cast(protein['between_segment_residues'], tf.float32),
      0, 1)
  aatype_1hot = tf.one_hot(protein['aatype'], 21, axis=-1)

  target_feat = [
      tf.expand_dims(has_break, axis=-1),
      aatype_1hot,  # Everyone gets the original sequence.
  ]

  msa_1hot = tf.one_hot(protein['msa'], 23, axis=-1)
  has_deletion = tf.clip_by_value(protein['deletion_matrix'], 0., 1.)
  deletion_value = tf.atan(protein['deletion_matrix'] / 3.) * (2. / np.pi)

  msa_feat = [
      msa_1hot,
      tf.expand_dims(has_deletion, axis=-1),
      tf.expand_dims(deletion_value, axis=-1),
  ]

  if 'cluster_profile' in protein:
    deletion_mean_value = (
        tf.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi))
    msa_feat.extend([
        protein['cluster_profile'],
        tf.expand_dims(deletion_mean_value, axis=-1),
    ])

  if 'extra_deletion_matrix' in protein:
    protein['extra_has_deletion'] = tf.clip_by_value(
        protein['extra_deletion_matrix'], 0., 1.)
    protein['extra_deletion_value'] = tf.atan(
        protein['extra_deletion_matrix'] / 3.) * (2. / np.pi)

  protein['msa_feat'] = tf.concat(msa_feat, axis=-1)
  protein['target_feat'] = tf.concat(target_feat, axis=-1)
  return protein


@curry1
def select_feat(protein, feature_list):
  return {k: v for k, v in protein.items() if k in feature_list}


@curry1
def crop_templates(protein, max_templates):
  for k, v in protein.items():
    if k.startswith('template_'):
      protein[k] = v[:max_templates]
  return protein


@curry1
def random_crop_to_size(protein, crop_size, max_templates, shape_schema,
                        subsample_templates=False):
  """Crop randomly to `crop_size`, or keep as is if shorter than that."""
  seq_length = protein['seq_length']
  if 'template_mask' in protein:
    num_templates = tf.cast(
        shape_helpers.shape_list(protein['template_mask'])[0], tf.int32)
  else:
    num_templates = tf.constant(0, dtype=tf.int32)
  num_res_crop_size = tf.math.minimum(seq_length, crop_size)

  # Ensures that the cropping of residues and templates happens in the same way
  # across ensembling iterations.
  # Do not use for randomness that should vary in ensembling.
  seed_maker = utils.SeedMaker(
      initial_seed=protein['random_crop_to_size_seed'])

  if subsample_templates:
    templates_crop_start = tf.random.stateless_uniform(
        shape=(), minval=0, maxval=num_templates + 1, dtype=tf.int32,
        seed=seed_maker())
  else:
    templates_crop_start = 0

  num_templates_crop_size = tf.math.minimum(
      num_templates - templates_crop_start, max_templates)

  num_res_crop_start = tf.random.stateless_uniform(
      shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1,
      dtype=tf.int32, seed=seed_maker())

  templates_select_indices = tf.argsort(tf.random.stateless_uniform(
      [num_templates], seed=seed_maker()))

  for k, v in protein.items():
    if k not in shape_schema or (
        'template' not in k and NUM_RES not in shape_schema[k]):
      continue

    # randomly permute the templates before cropping them.
    if k.startswith('template') and subsample_templates:
      v = tf.gather(v, templates_select_indices)

    crop_sizes = []
    crop_starts = []
    for i, (dim_size, dim) in enumerate(zip(shape_schema[k],
                                            shape_helpers.shape_list(v))):
      is_num_res = (dim_size == NUM_RES)
      if i == 0 and k.startswith('template'):
        crop_size = num_templates_crop_size
        crop_start = templates_crop_start
      else:
        crop_start = num_res_crop_start if is_num_res else 0
        crop_size = (num_res_crop_size if is_num_res else
                     (-1 if dim is None else dim))
      crop_sizes.append(crop_size)
      crop_starts.append(crop_start)
    protein[k] = tf.slice(v, crop_starts, crop_sizes)

  protein['seq_length'] = num_res_crop_size
  return protein


def make_atom14_masks(protein):
  """Construct denser atom positions (14 dimensions instead of 37)."""
  restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
  restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
  restype_atom14_mask = []

  for rt in residue_constants.restypes:
    atom_names = residue_constants.restype_name_to_atom14_names[
        residue_constants.restype_1to3[rt]]

    restype_atom14_to_atom37.append([
        (residue_constants.atom_order[name] if name else 0)
        for name in atom_names
    ])

    atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
    restype_atom37_to_atom14.append([
        (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
        for name in residue_constants.atom_types
    ])

    restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])

  # Add dummy mapping for restype 'UNK'
  restype_atom14_to_atom37.append([0] * 14)
  restype_atom37_to_atom14.append([0] * 37)
  restype_atom14_mask.append([0.] * 14)

  restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
  restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
  restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)

  # create the mapping for (residx, atom14) --> atom37, i.e. an array
  # with shape (num_res, 14) containing the atom37 indices for this protein
  residx_atom14_to_atom37 = tf.gather(restype_atom14_to_atom37,
                                      protein['aatype'])
  residx_atom14_mask = tf.gather(restype_atom14_mask,
                                 protein['aatype'])

  protein['atom14_atom_exists'] = residx_atom14_mask
  protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37

  # create the gather indices for mapping back
  residx_atom37_to_atom14 = tf.gather(restype_atom37_to_atom14,
                                      protein['aatype'])
  protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14

  # create the corresponding mask
  restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
  for restype, restype_letter in enumerate(residue_constants.restypes):
    restype_name = residue_constants.restype_1to3[restype_letter]
    atom_names = residue_constants.residue_atoms[restype_name]
    for atom_name in atom_names:
      atom_type = residue_constants.atom_order[atom_name]
      restype_atom37_mask[restype, atom_type] = 1

  residx_atom37_mask = tf.gather(restype_atom37_mask,
                                 protein['aatype'])
  protein['atom37_atom_exists'] = residx_atom37_mask

  return protein
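
A short sketch (not from the diff) of how the curry1-decorated transforms above are meant to be used: a call such as sample_msa(128, keep_extra=True) binds every argument except the feature dict and returns a one-argument function, so a pipeline is plain function composition (input_pipeline.compose in the next file does exactly this):

# Sketch only: chaining curry1-style transforms over a feature dict.
transforms = [
    cast_64bit_ints,                    # undecorated: protein -> protein
    make_seq_mask,
    make_hhblits_profile,
    sample_msa(128, keep_extra=True),   # curried: args bound, protein pending
    crop_extra_msa(1024),
]

def run(protein, fns=transforms):
  for fn in fns:
    protein = fn(protein)
  return protein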
alphafold/model/tf/input_pipeline.py  (deleted, mode 100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature pre-processing input pipeline for AlphaFold."""
from alphafold.model.tf import data_transforms
from alphafold.model.tf import shape_placeholders
import tensorflow.compat.v1 as tf
import tree

# Pylint gets confused by the curry1 decorator because it changes the number
# of arguments to the function.
# pylint:disable=no-value-for-parameter

NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES


def nonensembled_map_fns(data_config):
  """Input pipeline functions which are not ensembled."""
  common_cfg = data_config.common

  map_fns = [
      data_transforms.correct_msa_restypes,
      data_transforms.add_distillation_flag(False),
      data_transforms.cast_64bit_ints,
      data_transforms.squeeze_features,
      # Keep to not disrupt RNG.
      data_transforms.randomly_replace_msa_with_unknown(0.0),
      data_transforms.make_seq_mask,
      data_transforms.make_msa_mask,
      # Compute the HHblits profile if it's not set. This has to be run before
      # sampling the MSA.
      data_transforms.make_hhblits_profile,
      data_transforms.make_random_crop_to_size_seed,
  ]
  if common_cfg.use_templates:
    map_fns.extend([
        data_transforms.fix_templates_aatype,
        data_transforms.make_template_mask,
        data_transforms.make_pseudo_beta('template_')
    ])
  map_fns.extend([
      data_transforms.make_atom14_masks,
  ])

  return map_fns


def ensembled_map_fns(data_config):
  """Input pipeline functions that can be ensembled and averaged."""
  common_cfg = data_config.common
  eval_cfg = data_config.eval

  map_fns = []

  if common_cfg.reduce_msa_clusters_by_max_templates:
    pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
  else:
    pad_msa_clusters = eval_cfg.max_msa_clusters

  max_msa_clusters = pad_msa_clusters
  max_extra_msa = common_cfg.max_extra_msa

  map_fns.append(
      data_transforms.sample_msa(
          max_msa_clusters,
          keep_extra=True))

  if 'masked_msa' in common_cfg:
    # Masked MSA should come *before* MSA clustering so that
    # the clustering and full MSA profile do not leak information about
    # the masked locations and secret corrupted locations.
    map_fns.append(
        data_transforms.make_masked_msa(common_cfg.masked_msa,
                                        eval_cfg.masked_msa_replace_fraction))

  if common_cfg.msa_cluster_features:
    map_fns.append(data_transforms.nearest_neighbor_clusters())
    map_fns.append(data_transforms.summarize_clusters())

  # Crop after creating the cluster profiles.
  if max_extra_msa:
    map_fns.append(data_transforms.crop_extra_msa(max_extra_msa))
  else:
    map_fns.append(data_transforms.delete_extra_msa)

  map_fns.append(data_transforms.make_msa_feat())

  crop_feats = dict(eval_cfg.feat)

  if eval_cfg.fixed_size:
    map_fns.append(data_transforms.select_feat(list(crop_feats)))
    map_fns.append(data_transforms.random_crop_to_size(
        eval_cfg.crop_size,
        eval_cfg.max_templates,
        crop_feats,
        eval_cfg.subsample_templates))
    map_fns.append(data_transforms.make_fixed_size(
        crop_feats,
        pad_msa_clusters,
        common_cfg.max_extra_msa,
        eval_cfg.crop_size,
        eval_cfg.max_templates))
  else:
    map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates))

  return map_fns


def process_tensors_from_config(tensors, data_config):
  """Apply filters and maps to an existing dataset, based on the config."""

  def wrap_ensemble_fn(data, i):
    """Function to be mapped over the ensemble dimension."""
    d = data.copy()
    fns = ensembled_map_fns(data_config)
    fn = compose(fns)
    d['ensemble_index'] = i
    return fn(d)

  eval_cfg = data_config.eval
  tensors = compose(
      nonensembled_map_fns(
          data_config))(
              tensors)

  tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0))
  num_ensemble = eval_cfg.num_ensemble
  if data_config.common.resample_msa_in_recycling:
    # Separate batch per ensembling & recycling step.
    num_ensemble *= data_config.common.num_recycle + 1

  if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
    fn_output_signature = tree.map_structure(
        tf.TensorSpec.from_tensor, tensors_0)
    tensors = tf.map_fn(
        lambda x: wrap_ensemble_fn(tensors, x),
        tf.range(num_ensemble),
        parallel_iterations=1,
        fn_output_signature=fn_output_signature)
  else:
    tensors = tree.map_structure(lambda x: x[None],
                                 tensors_0)
  return tensors


@data_transforms.curry1
def compose(x, fs):
  for f in fs:
    x = f(x)
  return x
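
The pipeline reads its knobs from a two-part config. A hypothetical minimal sketch of the fields accessed above (the field names come from the code; the values are invented for illustration, and ml_collections is assumed as the container, as elsewhere in AlphaFold):

import ml_collections

data_config = ml_collections.ConfigDict({
    'common': {
        'use_templates': False,
        'reduce_msa_clusters_by_max_templates': False,
        'max_extra_msa': 1024,
        'msa_cluster_features': True,
        'resample_msa_in_recycling': True,
        'num_recycle': 3,
    },
    'eval': {
        'max_msa_clusters': 128,
        'max_templates': 4,
        'num_ensemble': 1,
        'fixed_size': True,
        'crop_size': 256,
        'subsample_templates': False,
        'masked_msa_replace_fraction': 0.15,
        'feat': {},  # per-feature shape schema (placeholder names per dim)
    },
})
# tensors = process_tensors_from_config(tensors, data_config)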
alphafold/model/tf/protein_features.py  (deleted, mode 100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains descriptions of various protein features."""
import enum
from typing import Dict, Optional, Sequence, Tuple, Union

from alphafold.common import residue_constants
import tensorflow.compat.v1 as tf

# Type aliases.
FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]


class FeatureType(enum.Enum):
  ZERO_DIM = 0  # Shape [x]
  ONE_DIM = 1  # Shape [num_res, x]
  TWO_DIM = 2  # Shape [num_res, num_res, x]
  MSA = 3  # Shape [msa_length, num_res, x]


# Placeholder values that will be replaced with their true value at runtime.
NUM_RES = "num residues placeholder"
NUM_SEQ = "length msa placeholder"
NUM_TEMPLATES = "num templates placeholder"

# Sizes of the protein features, NUM_RES and NUM_SEQ are allowed as
# placeholders to be replaced with the number of residues and the number of
# sequences in the multiple sequence alignment, respectively.
FEATURES = {
    #### Static features of a protein sequence ####
    "aatype": (tf.float32, [NUM_RES, 21]),
    "between_segment_residues": (tf.int64, [NUM_RES, 1]),
    "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]),
    "domain_name": (tf.string, [1]),
    "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]),
    "num_alignments": (tf.int64, [NUM_RES, 1]),
    "residue_index": (tf.int64, [NUM_RES, 1]),
    "seq_length": (tf.int64, [NUM_RES, 1]),
    "sequence": (tf.string, [1]),
    "all_atom_positions": (tf.float32,
                           [NUM_RES, residue_constants.atom_type_num, 3]),
    "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]),
    "resolution": (tf.float32, [1]),
    "template_domain_names": (tf.string, [NUM_TEMPLATES]),
    "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]),
    "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]),
    "template_all_atom_positions": (tf.float32, [
        NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3]),
    "template_all_atom_masks": (tf.float32, [
        NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1]),
}

FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()}
FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()}


def register_feature(name: str,
                     type_: tf.dtypes.DType,
                     shape_: Tuple[Union[str, int]]):
  """Register extra features used in custom datasets."""
  FEATURES[name] = (type_, shape_)
  FEATURE_TYPES[name] = type_
  FEATURE_SIZES[name] = shape_


def shape(feature_name: str,
          num_residues: int,
          msa_length: int,
          num_templates: Optional[int] = None,
          features: Optional[FeaturesMetadata] = None):
  """Get the shape for the given feature name.

  This is near identical to _get_tf_shape_no_placeholders() but with 2
  differences:
  * This method does not calculate a single placeholder from the total number
    of elements (eg given <NUM_RES, 3> and size := 12, this won't deduce
    NUM_RES must be 4)
  * This method will work with tensors

  Args:
    feature_name: String identifier for the feature. If the feature name ends
      with "_unnormalized", this suffix is stripped off.
    num_residues: The number of residues in the current domain - some elements
      of the shape can be dynamic and will be replaced by this value.
    msa_length: The number of sequences in the multiple sequence alignment,
      some elements of the shape can be dynamic and will be replaced by this
      value. If the number of alignments is unknown / not read, please pass
      None for msa_length.
    num_templates (optional): The number of templates in this tfexample.
    features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.

  Returns:
    List of ints representation the tensor size.

  Raises:
    ValueError: If a feature is requested but no concrete placeholder value is
      given.
  """
  features = features or FEATURES
  if feature_name.endswith("_unnormalized"):
    feature_name = feature_name[:-13]

  unused_dtype, raw_sizes = features[feature_name]
  replacements = {NUM_RES: num_residues,
                  NUM_SEQ: msa_length}

  if num_templates is not None:
    replacements[NUM_TEMPLATES] = num_templates

  sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
  for dimension in sizes:
    if isinstance(dimension, str):
      raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
          feature_name, raw_sizes, replacements))
  return sizes
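
For instance (a quick illustration, values made up), resolving the placeholder dimensions for the "msa" feature:

# shape() substitutes the concrete sizes for the placeholder strings:
sizes = shape("msa", num_residues=120, msa_length=512, num_templates=4)
# FEATURES["msa"] is (tf.int64, [NUM_SEQ, NUM_RES, 1]), so sizes == [512, 120, 1]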
alphafold/model/tf/protein_features_test.py  (deleted, mode 100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for protein_features."""
import uuid

from absl.testing import absltest
from absl.testing import parameterized
from alphafold.model.tf import protein_features
import tensorflow.compat.v1 as tf


def _random_bytes():
  return str(uuid.uuid4()).encode('utf-8')


class FeaturesTest(parameterized.TestCase, tf.test.TestCase):

  def testFeatureNames(self):
    self.assertEqual(len(protein_features.FEATURE_SIZES),
                     len(protein_features.FEATURE_TYPES))
    sorted_size_names = sorted(protein_features.FEATURE_SIZES.keys())
    sorted_type_names = sorted(protein_features.FEATURE_TYPES.keys())
    for i, size_name in enumerate(sorted_size_names):
      self.assertEqual(size_name, sorted_type_names[i])

  def testReplacement(self):
    for name in protein_features.FEATURE_SIZES.keys():
      sizes = protein_features.shape(name,
                                     num_residues=12,
                                     msa_length=24,
                                     num_templates=3)
      for x in sizes:
        self.assertEqual(type(x), int)
        self.assertGreater(x, 0)


if __name__ == '__main__':
  tf.disable_v2_behavior()
  absltest.main()
alphafold/model/tf/proteins_dataset.py  (deleted, mode 100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datasets consisting of proteins."""
from typing import Dict, Mapping, Optional, Sequence
from alphafold.model.tf import protein_features
import numpy as np
import tensorflow.compat.v1 as tf

TensorDict = Dict[str, tf.Tensor]


def parse_tfexample(
    raw_data: bytes,
    features: protein_features.FeaturesMetadata,
    key: Optional[str] = None) -> Dict[str, tf.train.Feature]:
  """Read a single TF Example proto and return a subset of its features.

  Args:
    raw_data: A serialized tf.Example proto.
    features: A dictionary of features, mapping string feature names to a tuple
      (dtype, shape). This dictionary should be a subset of
      protein_features.FEATURES (or the dictionary itself for all features).
    key: Optional string with the SSTable key of that tf.Example. This will be
      added into features as a 'key' but only if requested in features.

  Returns:
    A dictionary of features mapping feature names to features. Only the given
    features are returned, all other ones are filtered out.
  """
  feature_map = {
      k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0],
                                       allow_missing=True)
      for k, v in features.items()
  }
  parsed_features = tf.io.parse_single_example(raw_data, feature_map)
  reshaped_features = parse_reshape_logic(parsed_features, features, key=key)

  return reshaped_features


def _first(tensor: tf.Tensor) -> tf.Tensor:
  """Returns the 1st element - the input can be a tensor or a scalar."""
  return tf.reshape(tensor, shape=(-1,))[0]


def parse_reshape_logic(
    parsed_features: TensorDict,
    features: protein_features.FeaturesMetadata,
    key: Optional[str] = None) -> TensorDict:
  """Transforms parsed serial features to the correct shape."""
  # Find out what is the number of sequences and the number of alignments.
  num_residues = tf.cast(_first(parsed_features["seq_length"]),
                         dtype=tf.int32)

  if "num_alignments" in parsed_features:
    num_msa = tf.cast(_first(parsed_features["num_alignments"]),
                      dtype=tf.int32)
  else:
    num_msa = 0

  if "template_domain_names" in parsed_features:
    num_templates = tf.cast(
        tf.shape(parsed_features["template_domain_names"])[0],
        dtype=tf.int32)
  else:
    num_templates = 0

  if key is not None and "key" in features:
    parsed_features["key"] = [key]  # Expand dims from () to (1,).

  # Reshape the tensors according to the sequence length and num alignments.
  for k, v in parsed_features.items():
    new_shape = protein_features.shape(
        feature_name=k,
        num_residues=num_residues,
        msa_length=num_msa,
        num_templates=num_templates,
        features=features)
    new_shape_size = tf.constant(1, dtype=tf.int32)
    for dim in new_shape:
      new_shape_size *= tf.cast(dim, tf.int32)

    assert_equal = tf.assert_equal(
        tf.size(v), new_shape_size,
        name="assert_%s_shape_correct" % k,
        message="The size of feature %s (%s) could not be reshaped "
        "into %s" % (k, tf.size(v), new_shape))
    if "template" not in k:
      # Make sure the feature we are reshaping is not empty.
      assert_non_empty = tf.assert_greater(
          tf.size(v), 0, name="assert_%s_non_empty" % k,
          message="The feature %s is not set in the tf.Example. Either do not "
          "request the feature or use a tf.Example that has the "
          "feature set." % k)
      with tf.control_dependencies([assert_non_empty, assert_equal]):
        parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
    else:
      with tf.control_dependencies([assert_equal]):
        parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)

  return parsed_features


def _make_features_metadata(
    feature_names: Sequence[str]) -> protein_features.FeaturesMetadata:
  """Makes a feature name to type and shape mapping from a list of names."""
  # Make sure these features are always read.
  required_features = ["aatype", "sequence", "seq_length"]
  feature_names = list(set(feature_names) | set(required_features))

  features_metadata = {name: protein_features.FEATURES[name]
                       for name in feature_names}
  return features_metadata


def create_tensor_dict(
    raw_data: bytes,
    features: Sequence[str],
    key: Optional[str] = None,
    ) -> TensorDict:
  """Creates a dictionary of tensor features.

  Args:
    raw_data: A serialized tf.Example proto.
    features: A list of strings of feature names to be returned in the
      dataset.
    key: Optional string with the SSTable key of that tf.Example. This will be
      added into features as a 'key' but only if requested in features.

  Returns:
    A dictionary of features mapping feature names to features. Only the given
    features are returned, all other ones are filtered out.
  """
  features_metadata = _make_features_metadata(features)
  return parse_tfexample(raw_data, features_metadata, key)


def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
    ) -> TensorDict:
  """Creates dict of tensors from a dict of NumPy arrays.

  Args:
    np_example: A dict of NumPy feature arrays.
    features: A list of strings of feature names to be returned in the
      dataset.

  Returns:
    A dictionary of features mapping feature names to features. Only the given
    features are returned, all other ones are filtered out.
  """
  features_metadata = _make_features_metadata(features)
  tensor_dict = {k: tf.constant(v) for k, v in np_example.items()
                 if k in features_metadata}

  # Ensures shapes are as expected. Needed for setting size of empty features
  # e.g. when no template hits were found.
  tensor_dict = parse_reshape_logic(tensor_dict, features_metadata)
  return tensor_dict
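
A brief sketch (not from the diff; the arrays are fabricated) of the NumPy entry point. Note that 'sequence' and 'seq_length' are processed even when not requested, because _make_features_metadata always folds in the required features:

import numpy as np

np_example = {
    'aatype': np.zeros((4, 21), dtype=np.float32),       # [NUM_RES, 21]
    'sequence': np.array([b'MKTA'], dtype=np.object_),   # [1]
    'seq_length': np.full((4, 1), 4, dtype=np.int64),    # [NUM_RES, 1]
}
tensors = np_to_tensor_dict(np_example, features=['aatype'])
# tensors['aatype'] now holds the reshaped, shape-checked tf.Tensor.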
alphafold/model/tf/shape_helpers.py  (deleted, mode 100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with shapes of TensorFlow tensors."""
import tensorflow.compat.v1 as tf


def shape_list(x):
  """Return list of dimensions of a tensor, statically where possible.

  Like `x.shape.as_list()` but with tensors instead of `None`s.

  Args:
    x: A tensor.

  Returns:
    A list with length equal to the rank of the tensor. The n-th element of
    the list is an integer when that dimension is statically known otherwise
    it is the n-th element of `tf.shape(x)`.
  """
  x = tf.convert_to_tensor(x)

  # If unknown rank, return dynamic shape
  if x.get_shape().dims is None:
    return tf.shape(x)

  static = x.get_shape().as_list()
  shape = tf.shape(x)

  ret = []
  for i in range(len(static)):
    dim = static[i]
    if dim is None:
      dim = shape[i]
    ret.append(dim)
  return ret
alphafold/model/tf/shape_helpers_test.py  (deleted, mode 100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for shape_helpers."""
from alphafold.model.tf import shape_helpers
import numpy as np
import tensorflow.compat.v1 as tf


class ShapeTest(tf.test.TestCase):

  def test_shape_list(self):
    """Test that shape_list can allow for reshaping to dynamic shapes."""
    a = tf.zeros([10, 4, 4, 2])
    p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4])
    shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4]

    b = tf.reshape(a, shape_dyn)
    with self.session() as sess:
      out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))})

    self.assertAllEqual(out.shape, (20, 1, 4, 4))


if __name__ == '__main__':
  tf.disable_v2_behavior()
  tf.test.main()
alphafold/model/tf/shape_placeholders.py  (deleted, mode 100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Placeholder values for run-time varying dimension sizes."""
NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
NUM_EXTRA_SEQ = 'extra msa placeholder'
NUM_TEMPLATES = 'num templates placeholder'
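For context (not part of the commit): these constants act as symbolic dimension names in the feature shape schemas used elsewhere in alphafold/model/tf (e.g. protein_features.py in this same commit), where they are later substituted with concrete sizes. A minimal sketch of that idea, with a hypothetical FEATURE_SHAPES schema:

from alphafold.model.tf import shape_placeholders

# Hypothetical schema: each feature maps to a list of symbolic or fixed dims.
FEATURE_SHAPES = {
    'aatype': [shape_placeholders.NUM_RES, 21],
    'msa': [shape_placeholders.NUM_MSA_SEQ, shape_placeholders.NUM_RES],
}

# At run time the placeholders are resolved to actual sizes.
sizes = {shape_placeholders.NUM_RES: 128, shape_placeholders.NUM_MSA_SEQ: 512}
resolved = {k: [sizes.get(d, d) for d in v] for k, v in FEATURE_SHAPES.items()}
assert resolved['msa'] == [512, 128]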
alphafold/model/tf/utils.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for various components."""
import tensorflow.compat.v1 as tf


def tf_combine_mask(*masks):
  """Take the intersection of float-valued masks."""
  ret = 1
  for m in masks:
    ret *= m
  return ret


class SeedMaker(object):
  """Return unique seeds."""

  def __init__(self, initial_seed=0):
    self.next_seed = initial_seed

  def __call__(self):
    i = self.next_seed
    self.next_seed += 1
    return i

seed_maker = SeedMaker()


def make_random_seed():
  return tf.random.uniform([2],
                           tf.int32.min,
                           tf.int32.max,
                           tf.int32,
                           seed=seed_maker())
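A quick illustrative note (not part of the commit): seed_maker is a process-global counter, so every call to make_random_seed builds its tf.random.uniform op with a fresh but reproducible op-level seed. A minimal sketch, assuming a TF1 graph context:

# Each SeedMaker instance hands out consecutive, deterministic seeds.
sm = SeedMaker(initial_seed=0)
assert (sm(), sm(), sm()) == (0, 1, 2)

# Each call below creates a new [2]-shaped int32 tensor of RNG seed parts,
# built with the next op-level seed from the global seed_maker.
seed_a = make_random_seed()
seed_b = make_random_seed()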
alphafold/model/utils.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of JAX utility functions for use in protein folding."""
import collections
import functools
import numbers
from typing import Mapping

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


def final_init(config):
  if config.zero_init:
    return 'zeros'
  else:
    return 'linear'


def batched_gather(params, indices, axis=0, batch_dims=0):
  """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`."""
  take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode="clip")
  for _ in range(batch_dims):
    take_fn = jax.vmap(take_fn)
  return take_fn(params, indices)


def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
  """Masked mean."""
  if drop_mask_channel:
    mask = mask[..., 0]

  mask_shape = mask.shape
  value_shape = value.shape

  assert len(mask_shape) == len(value_shape)

  if isinstance(axis, numbers.Integral):
    axis = [axis]
  elif axis is None:
    axis = list(range(len(mask_shape)))
  assert isinstance(axis, collections.Iterable), (
      'axis needs to be either an iterable, integer or "None"')

  broadcast_factor = 1.
  for axis_ in axis:
    value_size = value_shape[axis_]
    mask_size = mask_shape[axis_]
    if mask_size == 1:
      broadcast_factor *= value_size
    else:
      assert mask_size == value_size

  return (jnp.sum(mask * value, axis=axis) /
          (jnp.sum(mask, axis=axis) * broadcast_factor + eps))


def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params:
  """Convert a dictionary of NumPy arrays to Haiku parameters."""
  hk_params = {}
  for path, array in params.items():
    scope, name = path.split('//')
    if scope not in hk_params:
      hk_params[scope] = {}
    hk_params[scope][name] = jnp.array(array)
  return hk_params


def padding_consistent_rng(f):
  """Modify any element-wise random function to be consistent with padding.

  Normally if you take a function like jax.random.normal and generate an array,
  say of size (10,10), you will get a different set of random numbers to if you
  add padding and take the first (10,10) sub-array.

  This function makes a random function that is consistent regardless of the
  amount of padding added.

  Note: The padding-consistent function is likely to be slower to compile and
  run than the function it is wrapping, but these slowdowns are likely to be
  negligible in a large network.

  Args:
    f: Any element-wise function that takes (PRNG key, shape) as the first 2
      arguments.

  Returns:
    An equivalent function to f, that is now consistent for different amounts of
    padding.
  """
  def grid_keys(key, shape):
    """Generate a grid of rng keys that is consistent with different padding.

    Generate random keys such that the keys will be identical, regardless of
    how much padding is added to any dimension.

    Args:
      key: A PRNG key.
      shape: The shape of the output array of keys that will be generated.

    Returns:
      An array of shape `shape` consisting of random keys.
    """
    if not shape:
      return key
    new_keys = jax.vmap(functools.partial(jax.random.fold_in, key))(
        jnp.arange(shape[0]))
    return jax.vmap(functools.partial(grid_keys, shape=shape[1:]))(new_keys)

  def inner(key, shape, **kwargs):
    return jnp.vectorize(
        lambda key: f(key, shape=(), **kwargs),
        signature='(2)->()')(grid_keys(key, shape))
  return inner
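For illustration only (not part of the commit): a minimal usage sketch of padding_consistent_rng, assuming jax is available and old-style uint32 PRNG keys. Because the key grid is built by recursively folding in each index, a padded draw agrees with an unpadded draw on the overlapping sub-array:

import jax
import jax.numpy as jnp

# Wrap jax.random.uniform so that padding the requested shape does not
# change the values in the overlapping region.
consistent_uniform = padding_consistent_rng(jax.random.uniform)

key = jax.random.PRNGKey(0)
small = consistent_uniform(key, (10, 10))
padded = consistent_uniform(key, (12, 16))

# The first (10, 10) block of the padded draw matches the unpadded draw.
assert jnp.allclose(small, padded[:10, :10])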
alphafold/notebooks/__init__.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AlphaFold Colab notebook."""
alphafold/notebooks/notebook_utils.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for the AlphaFold Colab notebook."""
import enum
import json
from typing import Any, Mapping, Optional, Sequence, Tuple

from alphafold.common import residue_constants
from alphafold.data import parsers
from matplotlib import pyplot as plt
import numpy as np


@enum.unique
class ModelType(enum.Enum):
  MONOMER = 0
  MULTIMER = 1


def clean_and_validate_sequence(
    input_sequence: str, min_length: int, max_length: int) -> str:
  """Checks that the input sequence is ok and returns a clean version of it."""
  # Remove all whitespaces, tabs and end lines; upper-case.
  clean_sequence = input_sequence.translate(
      str.maketrans('', '', ' \n\t')).upper()
  aatypes = set(residue_constants.restypes)  # 20 standard aatypes.
  if not set(clean_sequence).issubset(aatypes):
    raise ValueError(
        f'Input sequence contains non-amino acid letters: '
        f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard '
        'amino acids as inputs.')
  if len(clean_sequence) < min_length:
    raise ValueError(
        f'Input sequence is too short: {len(clean_sequence)} amino acids, '
        f'while the minimum is {min_length}')
  if len(clean_sequence) > max_length:
    raise ValueError(
        f'Input sequence is too long: {len(clean_sequence)} amino acids, while '
        f'the maximum is {max_length}. You may be able to run it with the full '
        f'AlphaFold system depending on your resources (system memory, '
        f'GPU memory).')
  return clean_sequence


def validate_input(
    input_sequences: Sequence[str],
    min_length: int,
    max_length: int,
    max_multimer_length: int) -> Tuple[Sequence[str], ModelType]:
  """Validates and cleans input sequences and determines which model to use."""
  sequences = []

  for input_sequence in input_sequences:
    if input_sequence.strip():
      input_sequence = clean_and_validate_sequence(
          input_sequence=input_sequence,
          min_length=min_length,
          max_length=max_length)
      sequences.append(input_sequence)

  if len(sequences) == 1:
    print('Using the single-chain model.')
    return sequences, ModelType.MONOMER

  elif len(sequences) > 1:
    total_multimer_length = sum([len(seq) for seq in sequences])
    if total_multimer_length > max_multimer_length:
      raise ValueError(
          f'The total length of multimer sequences is too long: '
          f'{total_multimer_length}, while the maximum is '
          f'{max_multimer_length}. Please use the full AlphaFold '
          f'system for long multimers.')
    elif total_multimer_length > 1536:
      print('WARNING: The accuracy of the system has not been fully validated '
            'above 1536 residues, and you may experience long running times or '
            f'run out of memory for your complex with {total_multimer_length} '
            'residues.')
    print(f'Using the multimer model with {len(sequences)} sequences.')
    return sequences, ModelType.MULTIMER

  else:
    raise ValueError('No input amino acid sequence provided, please provide at '
                     'least one sequence.')


def merge_chunked_msa(
    results: Sequence[Mapping[str, Any]],
    max_hits: Optional[int] = None) -> parsers.Msa:
  """Merges chunked database hits together into hits for the full database."""
  unsorted_results = []
  for chunk_index, chunk in enumerate(results):
    msa = parsers.parse_stockholm(chunk['sto'])
    e_values_dict = parsers.parse_e_values_from_tblout(chunk['tbl'])
    # Jackhmmer lists sequences as <sequence name>/<residue from>-<residue to>.
    e_values = [e_values_dict[t.partition('/')[0]] for t in msa.descriptions]
    chunk_results = zip(
        msa.sequences, msa.deletion_matrix, msa.descriptions, e_values)
    if chunk_index != 0:
      next(chunk_results)  # Only take query (first hit) from the first chunk.
    unsorted_results.extend(chunk_results)

  sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[-1])
  merged_sequences, merged_deletion_matrix, merged_descriptions, _ = zip(
      *sorted_by_evalue)
  merged_msa = parsers.Msa(sequences=merged_sequences,
                           deletion_matrix=merged_deletion_matrix,
                           descriptions=merged_descriptions)
  if max_hits is not None:
    merged_msa = merged_msa.truncate(max_seqs=max_hits)

  return merged_msa


def show_msa_info(
    single_chain_msas: Sequence[parsers.Msa],
    sequence_index: int):
  """Prints info and shows a plot of the deduplicated single chain MSA."""
  full_single_chain_msa = []
  for single_chain_msa in single_chain_msas:
    full_single_chain_msa.extend(single_chain_msa.sequences)

  # Deduplicate but preserve order (hence can't use set).
  deduped_full_single_chain_msa = list(dict.fromkeys(full_single_chain_msa))
  total_msa_size = len(deduped_full_single_chain_msa)
  print(f'\n{total_msa_size} unique sequences found in total for sequence '
        f'{sequence_index}\n')

  aa_map = {res: i for i, res in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}
  msa_arr = np.array(
      [[aa_map[aa] for aa in seq] for seq in deduped_full_single_chain_msa])

  plt.figure(figsize=(12, 3))
  plt.title(f'Per-Residue Count of Non-Gap Amino Acids in the MSA for Sequence '
            f'{sequence_index}')
  plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black')
  plt.ylabel('Non-Gap Count')
  plt.yticks(range(0, total_msa_size + 1, max(1, int(total_msa_size / 3))))
  plt.show()


def empty_placeholder_template_features(
    num_templates: int, num_res: int) -> Mapping[str, np.ndarray]:
  return {
      'template_aatype': np.zeros(
          (num_templates, num_res,
           len(residue_constants.restypes_with_x_and_gap)), dtype=np.float32),
      'template_all_atom_masks': np.zeros(
          (num_templates, num_res, residue_constants.atom_type_num),
          dtype=np.float32),
      'template_all_atom_positions': np.zeros(
          (num_templates, num_res, residue_constants.atom_type_num, 3),
          dtype=np.float32),
      'template_domain_names': np.zeros([num_templates], dtype=np.object),
      'template_sequence': np.zeros([num_templates], dtype=np.object),
      'template_sum_probs': np.zeros([num_templates], dtype=np.float32),
  }


def get_pae_json(pae: np.ndarray, max_pae: float) -> str:
  """Returns the PAE in the same format as is used in the AFDB."""
  rounded_errors = np.round(pae.astype(np.float64), decimals=1)
  indices = np.indices((len(rounded_errors), len(rounded_errors))) + 1
  indices_1 = indices[0].flatten().tolist()
  indices_2 = indices[1].flatten().tolist()
  return json.dumps(
      [{'residue1': indices_1,
        'residue2': indices_2,
        'distance': rounded_errors.flatten().tolist(),
        'max_predicted_aligned_error': max_pae}],
      indent=None, separators=(',', ':'))
alphafold/notebooks/notebook_utils_test.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for notebook_utils."""
import io

from absl.testing import absltest
from absl.testing import parameterized
from alphafold.data import parsers
from alphafold.data import templates
from alphafold.notebooks import notebook_utils

import mock
import numpy as np


ONLY_QUERY_HIT = {
    'sto': ('# STOCKHOLM 1.0\n'
            '#=GF ID query-i1\n'
            'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEH\n'
            '//\n'),
    'tbl': '',
    'stderr': b'',
    'n_iter': 1,
    'e_value': 0.0001}

# pylint: disable=line-too-long
MULTI_SEQUENCE_HIT_1 = {
    'sto': ('# STOCKHOLM 1.0\n'
            '#=GF ID query-i1\n'
            '#=GS ERR1700680_4602609/41-109 DE [subseq from] ERR1700680_4602609\n'
            '#=GS ERR1019366_5760491/40-105 DE [subseq from] ERR1019366_5760491\n'
            '#=GS SRR5580704_12853319/61-125 DE [subseq from] SRR5580704_12853319\n'
            'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH\n'
            'ERR1700680_4602609/41-109 --INKGAEYHKKAAEHHELAAKHHREAAKHHEAGSHEKAAHHSEIAAGHGLTAVHHTEEATK-HHPEEHTEK--\n'
            'ERR1019366_5760491/40-105 ---RSGAQHHDAAAQHYEEAARHHRMAAKQYQASHHEKAAHYAQLAYAHHMYAEQHAAEAAK-AHAKNHG----\n'
            'SRR5580704_12853319/61-125 ----PAADHHMKAAEHHEEAAKHHRAAAEHHTAGDHQKAGHHAHVANGHHVNAVHHAEEASK-HHATDHS----\n'
            '//\n'),
    'tbl': ('ERR1700680_4602609 - query - 7.7e-09 47.7 33.8 1.1e-08 47.2 33.8 1.2 1 0 0 1 1 1 1 -\n'
            'ERR1019366_5760491 - query - 1.7e-08 46.6 33.1 2.5e-08 46.1 33.1 1.3 1 0 0 1 1 1 1 -\n'
            'SRR5580704_12853319 - query - 1.1e-07 44.0 41.6 2e-07 43.1 41.6 1.4 1 0 0 1 1 1 1 -\n'),
    'stderr': b'',
    'n_iter': 1,
    'e_value': 0.0001}
MULTI_SEQUENCE_HIT_2 = {
    'sto': ('# STOCKHOLM 1.0\n'
            '#=GF ID query-i1\n'
            '#=GS ERR1700719_3476944/70-137 DE [subseq from] ERR1700719_3476944\n'
            '#=GS ERR1700761_4254522/72-138 DE [subseq from] ERR1700761_4254522\n'
            '#=GS SRR5438477_9761204/64-132 DE [subseq from] SRR5438477_9761204\n'
            'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH\n'
            'ERR1700719_3476944/70-137 ---KQAAEHHHQAAEHHEHAARHHREAAKHHEAGDHESAAHHAHTAQGHLHQATHHASEAAKLHVEHHGQK--\n'
            'ERR1700761_4254522/72-138 ----QASEHHNLAAEHHEHAARHHRDAAKHHKAGDHEKAAHHAHVAHGHHLHATHHATEAAKHHVEAHGEK--\n'
            'SRR5438477_9761204/64-132 MPKHEGAEHHKKAAEHNEHAARHHKEAARHHEEGSHEKVGHHAHIAHGHHLHATHHAEEAAKTHSNQHE----\n'
            '//\n'),
    'tbl': ('ERR1700719_3476944 - query - 2e-07 43.2 47.5 3.5e-07 42.4 47.5 1.4 1 0 0 1 1 1 1 -\n'
            'ERR1700761_4254522 - query - 6.1e-07 41.6 48.1 8.1e-07 41.3 48.1 1.2 1 0 0 1 1 1 1 -\n'
            'SRR5438477_9761204 - query - 1.8e-06 40.2 46.9 2.3e-06 39.8 46.9 1.2 1 0 0 1 1 1 1 -\n'),
    'stderr': b'',
    'n_iter': 1,
    'e_value': 0.0001}
# pylint: enable=line-too-long


class NotebookUtilsTest(parameterized.TestCase):

  @parameterized.parameters(
      ('DeepMind', 'DEEPMIND'), ('A ', 'A'), ('\tA', 'A'), (' A\t\n', 'A'),
      ('ACDEFGHIKLMNPQRSTVWY', 'ACDEFGHIKLMNPQRSTVWY'))
  def test_clean_and_validate_sequence_ok(self, sequence, exp_clean):
    clean = notebook_utils.clean_and_validate_sequence(
        sequence, min_length=1, max_length=100)
    self.assertEqual(clean, exp_clean)

  @parameterized.named_parameters(
      ('too_short', 'AA', 'too short'),
      ('too_long', 'AAAAAAAAAA', 'too long'),
      ('bad_amino_acids_B', 'BBBB', 'non-amino acid'),
      ('bad_amino_acids_J', 'JJJJ', 'non-amino acid'),
      ('bad_amino_acids_O', 'OOOO', 'non-amino acid'),
      ('bad_amino_acids_U', 'UUUU', 'non-amino acid'),
      ('bad_amino_acids_X', 'XXXX', 'non-amino acid'),
      ('bad_amino_acids_Z', 'ZZZZ', 'non-amino acid'))
  def test_clean_and_validate_sequence_bad(self, sequence, exp_error):
    with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
      notebook_utils.clean_and_validate_sequence(
          sequence, min_length=4, max_length=8)

  @parameterized.parameters(
      (['A', '', '', ' ', '\t', '\t\n', '', ''], ['A'],
       notebook_utils.ModelType.MONOMER),
      (['', 'A'], ['A'], notebook_utils.ModelType.MONOMER),
      (['A', 'C ', ''], ['A', 'C'], notebook_utils.ModelType.MULTIMER),
      (['', 'A', '', 'C '], ['A', 'C'], notebook_utils.ModelType.MULTIMER))
  def test_validate_input_ok(
      self, input_sequences, exp_sequences, exp_model_type):
    sequences, model_type = notebook_utils.validate_input(
        input_sequences=input_sequences,
        min_length=1, max_length=100, max_multimer_length=100)
    self.assertSequenceEqual(sequences, exp_sequences)
    self.assertEqual(model_type, exp_model_type)

  @parameterized.named_parameters(
      ('no_input_sequence', ['', '\t', '\n'], 'No input amino acid sequence'),
      ('too_long_single', ['AAAAAAAAA', 'AAAA'], 'Input sequence is too long'),
      ('too_long_multimer', ['AAAA', 'AAAAA'], 'The total length of multimer'))
  def test_validate_input_bad(self, input_sequences, exp_error):
    with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'):
      notebook_utils.validate_input(
          input_sequences=input_sequences,
          min_length=4, max_length=8, max_multimer_length=6)

  def test_merge_chunked_msa_no_hits(self):
    results = [ONLY_QUERY_HIT, ONLY_QUERY_HIT]
    merged_msa = notebook_utils.merge_chunked_msa(results=results)
    self.assertSequenceEqual(
        merged_msa.sequences,
        ('MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEH',))
    self.assertSequenceEqual(merged_msa.deletion_matrix, ([0] * 56,))

  def test_merge_chunked_msa(self):
    results = [MULTI_SEQUENCE_HIT_1, MULTI_SEQUENCE_HIT_2]
    merged_msa = notebook_utils.merge_chunked_msa(results=results)
    self.assertLen(merged_msa.sequences, 7)
    # The 1st one is the query.
    self.assertEqual(
        merged_msa.sequences[0],
        'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAP'
        'KPH')
    # The 2nd one is the one with the lowest e-value: ERR1700680_4602609.
    self.assertEqual(
        merged_msa.sequences[1],
        '--INKGAEYHKKAAEHHELAAKHHREAAKHHEAGSHEKAAHHSEIAAGHGLTAVHHTEEATK-HHPEEHT'
        'EK-')
    # The last one is the one with the largest e-value: SRR5438477_9761204.
    self.assertEqual(
        merged_msa.sequences[-1],
        'MPKHEGAEHHKKAAEHNEHAARHHKEAARHHEEGSHEKVGHHAHIAHGHHLHATHHAEEAAKTHSNQHE-'
        '---')
    self.assertLen(merged_msa.deletion_matrix, 7)

  @mock.patch('sys.stdout', new_callable=io.StringIO)
  def test_show_msa_info(self, mocked_stdout):
    single_chain_msas = [
        parsers.Msa(sequences=['A', 'B', 'C', 'C'],
                    deletion_matrix=[None] * 4,
                    descriptions=[''] * 4),
        parsers.Msa(sequences=['A', 'A', 'A', 'D'],
                    deletion_matrix=[None] * 4,
                    descriptions=[''] * 4)]
    notebook_utils.show_msa_info(
        single_chain_msas=single_chain_msas, sequence_index=1)
    self.assertEqual(mocked_stdout.getvalue(),
                     '\n4 unique sequences found in total for sequence 1\n\n')

  @parameterized.named_parameters(
      ('some_templates', 4), ('no_templates', 0))
  def test_empty_placeholder_template_features(self, num_templates):
    template_features = notebook_utils.empty_placeholder_template_features(
        num_templates=num_templates, num_res=16)
    self.assertCountEqual(template_features.keys(),
                          templates.TEMPLATE_FEATURES.keys())
    self.assertSameElements(
        [v.shape[0] for v in template_features.values()], [num_templates])
    self.assertSequenceEqual(
        [t.dtype for t in template_features.values()],
        [np.array([], dtype=templates.TEMPLATE_FEATURES[feat_name]).dtype
         for feat_name in template_features])

  def test_get_pae_json(self):
    pae = np.array([[0.01, 13.12345], [20.0987, 0.0]])
    pae_json = notebook_utils.get_pae_json(pae=pae, max_pae=31.75)
    self.assertEqual(
        pae_json,
        '[{"residue1":[1,1,2,2],"residue2":[1,2,1,2],"distance":'
        '[0.0,13.1,20.1,0.0],"max_predicted_aligned_error":31.75}]')


if __name__ == '__main__':
  absltest.main()
alphafold/relax/__init__.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Amber relaxation."""
alphafold/relax/amber_minimize.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Restrained Amber Minimization of a structure."""
import io
import time
from typing import Collection, Optional, Sequence

from absl import logging
from alphafold.common import protein
from alphafold.common import residue_constants
from alphafold.model import folding
from alphafold.relax import cleanup
from alphafold.relax import utils
import ml_collections
import numpy as np

try:
  # openmm >= 7.6
  import openmm
  from openmm import unit
  from openmm import app as openmm_app
  from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
  # openmm < 7.6 (requires DeepMind patch)
  from simtk import openmm
  from simtk import unit
  from simtk.openmm import app as openmm_app
  from simtk.openmm.app.internal.pdbstructure import PdbStructure


ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms


def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
  """Returns True if the atom will be restrained by the given restraint set."""

  if rset == "non_hydrogen":
    return atom.element.name != "hydrogen"
  elif rset == "c_alpha":
    return atom.name == "CA"


def _add_restraints(
    system: openmm.System,
    reference_pdb: openmm_app.PDBFile,
    stiffness: unit.Unit,
    rset: str,
    exclude_residues: Sequence[int]):
  """Adds a harmonic potential that restrains the system to a structure."""
  assert rset in ["non_hydrogen", "c_alpha"]

  force = openmm.CustomExternalForce(
      "0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
  force.addGlobalParameter("k", stiffness)
  for p in ["x0", "y0", "z0"]:
    force.addPerParticleParameter(p)

  for i, atom in enumerate(reference_pdb.topology.atoms()):
    if atom.residue.index in exclude_residues:
      continue
    if will_restrain(atom, rset):
      force.addParticle(i, reference_pdb.positions[i])
  logging.info("Restraining %d / %d particles.",
               force.getNumParticles(), system.getNumParticles())
  system.addForce(force)


def _openmm_minimize(
    pdb_str: str,
    max_iterations: int,
    tolerance: unit.Unit,
    stiffness: unit.Unit,
    restraint_set: str,
    exclude_residues: Sequence[int],
    use_gpu: bool):
  """Minimize energy via openmm."""
  pdb_file = io.StringIO(pdb_str)
  pdb = openmm_app.PDBFile(pdb_file)

  force_field = openmm_app.ForceField("amber99sb.xml")
  constraints = openmm_app.HBonds
  system = force_field.createSystem(pdb.topology, constraints=constraints)
  if stiffness > 0 * ENERGY / (LENGTH**2):
    _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)

  integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
  platform = openmm.Platform.getPlatformByName("HIP" if use_gpu else "CPU")
  simulation = openmm_app.Simulation(
      pdb.topology, system, integrator, platform)
  simulation.context.setPositions(pdb.positions)

  ret = {}
  state = simulation.context.getState(getEnergy=True, getPositions=True)
  ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
  ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
  simulation.minimizeEnergy(maxIterations=max_iterations,
                            tolerance=tolerance)
  state = simulation.context.getState(getEnergy=True, getPositions=True)
  ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
  ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
  ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions())
  return ret


def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
  """Returns a pdb string provided OpenMM topology and positions."""
  with io.StringIO() as f:
    openmm_app.PDBFile.writeFile(topology, positions, f)
    return f.getvalue()


def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
  """Checks that no atom positions have been altered by cleaning."""
  cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
  reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))

  cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
  ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))

  for ref_res, cl_res in zip(reference.topology.residues(),
                             cleaned.topology.residues()):
    assert ref_res.name == cl_res.name
    for rat in ref_res.atoms():
      for cat in cl_res.atoms():
        if cat.name == rat.name:
          if not np.array_equal(cl_xyz[cat.index], ref_xyz[rat.index]):
            raise ValueError(f"Coordinates of cleaned atom {cat} do not match "
                             f"coordinates of reference atom {rat}.")


def _check_residues_are_well_defined(prot: protein.Protein):
  """Checks that all residues contain non-empty atom sets."""
  if (prot.atom_mask.sum(axis=-1) == 0).any():
    raise ValueError("Amber minimization can only be performed on proteins with"
                     " well-defined residues. This protein contains at least"
                     " one residue with no atoms.")


def _check_atom_mask_is_ideal(prot):
  """Sanity-check the atom mask is ideal, up to a possible OXT."""
  atom_mask = prot.atom_mask
  ideal_atom_mask = protein.ideal_atom_mask(prot)
  utils.assert_equal_nonterminal_atom_types(atom_mask, ideal_atom_mask)


def clean_protein(prot: protein.Protein, checks: bool = True):
  """Adds missing atoms to Protein instance.

  Args:
    prot: A `protein.Protein` instance.
    checks: A `bool` specifying whether to add additional checks to the cleaning
      process.

  Returns:
    pdb_string: A string of the cleaned protein.
  """
  _check_atom_mask_is_ideal(prot)

  # Clean pdb.
  prot_pdb_string = protein.to_pdb(prot)
  pdb_file = io.StringIO(prot_pdb_string)
  alterations_info = {}
  fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info)
  fixed_pdb_file = io.StringIO(fixed_pdb)
  pdb_structure = PdbStructure(fixed_pdb_file)
  cleanup.clean_structure(pdb_structure, alterations_info)

  logging.info("alterations info: %s", alterations_info)

  # Write pdb file of cleaned structure.
  as_file = openmm_app.PDBFile(pdb_structure)
  pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
  if checks:
    _check_cleaned_atoms(pdb_string, prot_pdb_string)
  return pdb_string


def make_atom14_positions(prot):
  """Constructs denser atom positions (14 dimensions instead of 37)."""
  restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
  restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
  restype_atom14_mask = []

  for rt in residue_constants.restypes:
    atom_names = residue_constants.restype_name_to_atom14_names[
        residue_constants.restype_1to3[rt]]

    restype_atom14_to_atom37.append([
        (residue_constants.atom_order[name] if name else 0)
        for name in atom_names
    ])

    atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
    restype_atom37_to_atom14.append([
        (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
        for name in residue_constants.atom_types
    ])

    restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])

  # Add dummy mapping for restype 'UNK'.
  restype_atom14_to_atom37.append([0] * 14)
  restype_atom37_to_atom14.append([0] * 37)
  restype_atom14_mask.append([0.] * 14)

  restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
  restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
  restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)

  # Create the mapping for (residx, atom14) --> atom37, i.e. an array
  # with shape (num_res, 14) containing the atom37 indices for this protein.
  residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
  residx_atom14_mask = restype_atom14_mask[prot["aatype"]]

  # Create a mask for known ground truth positions.
  residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
      prot["all_atom_mask"], residx_atom14_to_atom37, axis=1).astype(np.float32)

  # Gather the ground truth positions.
  residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
      np.take_along_axis(prot["all_atom_positions"],
                         residx_atom14_to_atom37[..., None],
                         axis=1))

  prot["atom14_atom_exists"] = residx_atom14_mask
  prot["atom14_gt_exists"] = residx_atom14_gt_mask
  prot["atom14_gt_positions"] = residx_atom14_gt_positions

  prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37

  # Create the gather indices for mapping back.
  residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
  prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14

  # Create the corresponding mask.
  restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
  for restype, restype_letter in enumerate(residue_constants.restypes):
    restype_name = residue_constants.restype_1to3[restype_letter]
    atom_names = residue_constants.residue_atoms[restype_name]
    for atom_name in atom_names:
      atom_type = residue_constants.atom_order[atom_name]
      restype_atom37_mask[restype, atom_type] = 1

  residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
  prot["atom37_atom_exists"] = residx_atom37_mask

  # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
  # alternative ground truth coordinates where the naming is swapped
  restype_3 = [residue_constants.restype_1to3[res]
               for res in residue_constants.restypes]
  restype_3 += ["UNK"]

  # Matrices for renaming ambiguous atoms.
  all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
  for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
    correspondences = np.arange(14)
    for source_atom_swap, target_atom_swap in swap.items():
      source_index = residue_constants.restype_name_to_atom14_names[
          resname].index(source_atom_swap)
      target_index = residue_constants.restype_name_to_atom14_names[
          resname].index(target_atom_swap)
      correspondences[source_index] = target_index
      correspondences[target_index] = source_index
      renaming_matrix = np.zeros((14, 14), dtype=np.float32)
      for index, correspondence in enumerate(correspondences):
        renaming_matrix[index, correspondence] = 1.
    all_matrices[resname] = renaming_matrix.astype(np.float32)
  renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])

  # Pick the transformation matrices for the given residue sequence
  # shape (num_res, 14, 14).
  renaming_transform = renaming_matrices[prot["aatype"]]

  # Apply it to the ground truth positions. shape (num_res, 14, 3).
  alternative_gt_positions = np.einsum("rac,rab->rbc",
                                       residx_atom14_gt_positions,
                                       renaming_transform)
  prot["atom14_alt_gt_positions"] = alternative_gt_positions

  # Create the mask for the alternative ground truth (differs from the
  # ground truth mask, if only one of the atoms in an ambiguous pair has a
  # ground truth position).
  alternative_gt_mask = np.einsum("ra,rab->rb",
                                  residx_atom14_gt_mask,
                                  renaming_transform)
  prot["atom14_alt_gt_exists"] = alternative_gt_mask

  # Create an ambiguous atoms mask. shape: (21, 14).
  restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
  for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
    for atom_name1, atom_name2 in swap.items():
      restype = residue_constants.restype_order[
          residue_constants.restype_3to1[resname]]
      atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index(
          atom_name1)
      atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index(
          atom_name2)
      restype_atom14_is_ambiguous[restype, atom_idx1] = 1
      restype_atom14_is_ambiguous[restype, atom_idx2] = 1

  # From this create an ambiguous_mask for the given sequence.
  prot["atom14_atom_is_ambiguous"] = (
      restype_atom14_is_ambiguous[prot["aatype"]])

  return prot


def find_violations(prot_np: protein.Protein):
  """Analyzes a protein and returns structural violation information.

  Args:
    prot_np: A protein.

  Returns:
    violations: A `dict` of structure components with structural violations.
    violation_metrics: A `dict` of violation metrics.
  """
  batch = {
      "aatype": prot_np.aatype,
      "all_atom_positions": prot_np.atom_positions.astype(np.float32),
      "all_atom_mask": prot_np.atom_mask.astype(np.float32),
      "residue_index": prot_np.residue_index,
  }

  batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
  batch = make_atom14_positions(batch)

  violations = folding.find_structural_violations(
      batch=batch,
      atom14_pred_positions=batch["atom14_gt_positions"],
      config=ml_collections.ConfigDict(
          {"violation_tolerance_factor": 12,  # Taken from model config.
           "clash_overlap_tolerance": 1.5,  # Taken from model config.
          }))
  violation_metrics = folding.compute_violation_metrics(
      batch=batch,
      atom14_pred_positions=batch["atom14_gt_positions"],
      violations=violations,
  )

  return violations, violation_metrics


def get_violation_metrics(prot: protein.Protein):
  """Computes violation and alignment metrics."""
  structural_violations, struct_metrics = find_violations(prot)
  violation_idx = np.flatnonzero(
      structural_violations["total_per_residue_violations_mask"])

  struct_metrics["residue_violations"] = violation_idx
  struct_metrics["num_residue_violations"] = len(violation_idx)
  struct_metrics["structural_violations"] = structural_violations
  return struct_metrics


def _run_one_iteration(
    *,
    pdb_string: str,
    max_iterations: int,
    tolerance: float,
    stiffness: float,
    restraint_set: str,
    max_attempts: int,
    use_gpu: bool,
    exclude_residues: Optional[Collection[int]] = None):
  """Runs the minimization pipeline.

  Args:
    pdb_string: A pdb string.
    max_iterations: An `int` specifying the maximum number of L-BFGS iterations.
      A value of 0 specifies no limit.
    tolerance: kcal/mol, the energy tolerance of L-BFGS.
    stiffness: kcal/mol A**2, spring constant of heavy atom restraining
      potential.
    restraint_set: The set of atoms to restrain.
    max_attempts: The maximum number of minimization attempts.
    use_gpu: Whether to run on GPU.
    exclude_residues: An optional list of zero-indexed residues to exclude from
      restraints.

  Returns:
    A `dict` of minimization info.
  """
  exclude_residues = exclude_residues or []

  # Assign physical dimensions.
  tolerance = tolerance * ENERGY
  stiffness = stiffness * ENERGY / (LENGTH**2)

  start = time.time()
  minimized = False
  attempts = 0
  while not minimized and attempts < max_attempts:
    attempts += 1
    try:
      logging.info("Minimizing protein, attempt %d of %d.",
                   attempts, max_attempts)
      ret = _openmm_minimize(
          pdb_string, max_iterations=max_iterations,
          tolerance=tolerance, stiffness=stiffness,
          restraint_set=restraint_set,
          exclude_residues=exclude_residues,
          use_gpu=use_gpu)
      minimized = True
    except Exception as e:  # pylint: disable=broad-except
      logging.info(e)
  if not minimized:
    raise ValueError(f"Minimization failed after {max_attempts} attempts.")
  ret["opt_time"] = time.time() - start
  ret["min_attempts"] = attempts
  return ret


def run_pipeline(
    prot: protein.Protein,
    stiffness: float,
    use_gpu: bool,
    max_outer_iterations: int = 1,
    place_hydrogens_every_iteration: bool = True,
    max_iterations: int = 0,
    tolerance: float = 2.39,
    restraint_set: str = "non_hydrogen",
    max_attempts: int = 100,
    checks: bool = True,
    exclude_residues: Optional[Sequence[int]] = None):
  """Run iterative amber relax.

  Successive relax iterations are performed until all violations have been
  resolved. Each iteration involves a restrained Amber minimization, with
  restraint exclusions determined by violation-participating residues.

  Args:
    prot: A protein to be relaxed.
    stiffness: kcal/mol A**2, the restraint stiffness.
    use_gpu: Whether to run on GPU.
    max_outer_iterations: The maximum number of iterative minimization.
    place_hydrogens_every_iteration: Whether hydrogens are re-initialized
        prior to every minimization.
    max_iterations: An `int` specifying the maximum number of L-BFGS steps
        per relax iteration. A value of 0 specifies no limit.
    tolerance: kcal/mol, the energy tolerance of L-BFGS.
        The default value is the OpenMM default.
    restraint_set: The set of atoms to restrain.
    max_attempts: The maximum number of minimization attempts per iteration.
    checks: Whether to perform cleaning checks.
    exclude_residues: An optional list of zero-indexed residues to exclude from
        restraints.

  Returns:
    out: A dictionary of output values.
  """

  # `protein.to_pdb` will strip any poorly-defined residues so we need to
  # perform this check before `clean_protein`.
  _check_residues_are_well_defined(prot)
  pdb_string = clean_protein(prot, checks=checks)

  exclude_residues = exclude_residues or []
  exclude_residues = set(exclude_residues)
  violations = np.inf
  iteration = 0

  while violations > 0 and iteration < max_outer_iterations:
    ret = _run_one_iteration(
        pdb_string=pdb_string,
        exclude_residues=exclude_residues,
        max_iterations=max_iterations,
        tolerance=tolerance,
        stiffness=stiffness,
        restraint_set=restraint_set,
        max_attempts=max_attempts,
        use_gpu=use_gpu)
    prot = protein.from_pdb_string(ret["min_pdb"])
    if place_hydrogens_every_iteration:
      pdb_string = clean_protein(prot, checks=True)
    else:
      pdb_string = ret["min_pdb"]
    ret.update(get_violation_metrics(prot))
    ret.update({
        "num_exclusions": len(exclude_residues),
        "iteration": iteration,
    })
    violations = ret["violations_per_residue"]
    exclude_residues = exclude_residues.union(ret["residue_violations"])

    logging.info("Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
                 "num residue violations %d num residue exclusions %d ",
                 ret["einit"], ret["efinal"], ret["opt_time"],
                 ret["num_residue_violations"], ret["num_exclusions"])
    iteration += 1
  return ret
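For intuition (not part of the commit): the restraint added by _add_restraints is a per-particle harmonic well, E = 0.5 * k * |r - r0|^2, so an atom displaced 1 Å from its reference position under a stiffness of 10 kcal/mol/Å² (the value used in the tests below) costs 5 kcal/mol. A tiny numeric sketch of that expression, using only NumPy:

import numpy as np

def restraint_energy(pos, ref, k):
  # Same form as the CustomExternalForce expression above:
  # 0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2); k in kcal/mol/A^2, pos/ref in A.
  return 0.5 * k * np.sum((np.asarray(pos) - np.asarray(ref)) ** 2)

print(restraint_energy([1.0, 0.0, 0.0], [0.0, 0.0, 0.0], k=10.0))  # 5.0 kcal/mol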
alphafold/relax/amber_minimize_test.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for amber_minimize."""
import os

from absl.testing import absltest
from alphafold.common import protein
from alphafold.relax import amber_minimize
import numpy as np
# Internal import (7716).

_USE_GPU = False


def _load_test_protein(data_path):
  pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path)
  with open(pdb_path, 'r') as f:
    return protein.from_pdb_string(f.read())


class AmberMinimizeTest(absltest.TestCase):

  def test_multiple_disulfides_target(self):
    prot = _load_test_protein(
        'alphafold/relax/testdata/multiple_disulfides_target.pdb'
    )
    ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1,
                                      stiffness=10., use_gpu=_USE_GPU)
    self.assertIn('opt_time', ret)
    self.assertIn('min_attempts', ret)

  def test_raises_invalid_protein_assertion(self):
    prot = _load_test_protein(
        'alphafold/relax/testdata/multiple_disulfides_target.pdb'
    )
    prot.atom_mask[4, :] = 0
    with self.assertRaisesRegex(
        ValueError,
        'Amber minimization can only be performed on proteins with well-defined'
        ' residues. This protein contains at least one residue with no atoms.'):
      amber_minimize.run_pipeline(prot, max_iterations=10,
                                  stiffness=1.,
                                  max_attempts=1,
                                  use_gpu=_USE_GPU)

  def test_iterative_relax(self):
    prot = _load_test_protein(
        'alphafold/relax/testdata/with_violations.pdb'
    )
    violations = amber_minimize.get_violation_metrics(prot)
    self.assertGreater(violations['num_residue_violations'], 0)
    out = amber_minimize.run_pipeline(
        prot=prot, max_outer_iterations=10, stiffness=10., use_gpu=_USE_GPU)
    self.assertLess(out['efinal'], out['einit'])
    self.assertEqual(0, out['num_residue_violations'])

  def test_find_violations(self):
    prot = _load_test_protein(
        'alphafold/relax/testdata/multiple_disulfides_target.pdb'
    )
    viols, _ = amber_minimize.find_violations(prot)

    expected_between_residues_connection_mask = np.zeros((191,), np.float32)
    for residue in (42, 43, 59, 60, 135, 136):
      expected_between_residues_connection_mask[residue] = 1.0

    expected_clash_indices = np.array([
        [8, 4], [8, 5], [13, 3], [14, 1], [14, 4], [26, 4], [26, 5],
        [31, 8], [31, 10], [39, 0], [39, 1], [39, 2], [39, 3], [39, 4],
        [42, 5], [42, 6], [42, 7], [42, 8], [47, 7], [47, 8], [47, 9],
        [47, 10], [64, 4], [85, 5], [102, 4], [102, 5], [109, 13],
        [111, 5], [118, 6], [118, 7], [118, 8], [124, 4], [124, 5],
        [131, 5], [139, 7], [147, 4], [152, 7]], dtype=np.int32)
    expected_between_residues_clash_mask = np.zeros([191, 14])
    expected_between_residues_clash_mask[expected_clash_indices[:, 0],
                                         expected_clash_indices[:, 1]] += 1
    expected_per_atom_violations = np.zeros([191, 14])

    np.testing.assert_array_equal(
        viols['between_residues']['connections_per_residue_violation_mask'],
        expected_between_residues_connection_mask)
    np.testing.assert_array_equal(
        viols['between_residues']['clashes_per_atom_clash_mask'],
        expected_between_residues_clash_mask)
    np.testing.assert_array_equal(
        viols['within_residues']['per_atom_violations'],
        expected_per_atom_violations)


if __name__ == '__main__':
  absltest.main()
alphafold/relax/cleanup.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
fix_pdb uses a third-party tool. We also support fixing some additional edge
cases like removing chains of length one (see clean_structure).
"""
import io
import pdbfixer
from simtk.openmm import app
from simtk.openmm.app import element


def fix_pdb(pdbfile, alterations_info):
  """Apply pdbfixer to the contents of a PDB file; return a PDB string result.

  1) Replaces nonstandard residues.
  2) Removes heterogens (non protein residues) including water.
  3) Adds missing residues and missing atoms within existing residues.
  4) Adds hydrogens assuming pH=7.0.
  5) KeepIds is currently true, so the fixer must keep the existing chain and
     residue identifiers. This will fail for some files in wider PDB that have
     invalid IDs.

  Args:
    pdbfile: Input PDB file handle.
    alterations_info: A dict that will store details of changes made.

  Returns:
    A PDB string representing the fixed structure.
  """
  fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
  fixer.findNonstandardResidues()
  alterations_info['nonstandard_residues'] = fixer.nonstandardResidues
  fixer.replaceNonstandardResidues()
  _remove_heterogens(fixer, alterations_info, keep_water=False)
  fixer.findMissingResidues()
  alterations_info['missing_residues'] = fixer.missingResidues
  fixer.findMissingAtoms()
  alterations_info['missing_heavy_atoms'] = fixer.missingAtoms
  alterations_info['missing_terminals'] = fixer.missingTerminals
  fixer.addMissingAtoms(seed=0)
  fixer.addMissingHydrogens()
  out_handle = io.StringIO()
  app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle,
                        keepIds=True)
  return out_handle.getvalue()


def clean_structure(pdb_structure, alterations_info):
  """Applies additional fixes to an OpenMM structure, to handle edge cases.

  Args:
    pdb_structure: An OpenMM structure to modify and fix.
    alterations_info: A dict that will store details of changes made.
  """
  _replace_met_se(pdb_structure, alterations_info)
  _remove_chains_of_length_one(pdb_structure, alterations_info)


def _remove_heterogens(fixer, alterations_info, keep_water):
  """Removes the residues that Pdbfixer considers to be heterogens.

  Args:
    fixer: A Pdbfixer instance.
    alterations_info: A dict that will store details of changes made.
    keep_water: If True, water (HOH) is not considered to be a heterogen.
  """
  initial_resnames = set()
  for chain in fixer.topology.chains():
    for residue in chain.residues():
      initial_resnames.add(residue.name)
  fixer.removeHeterogens(keepWater=keep_water)
  final_resnames = set()
  for chain in fixer.topology.chains():
    for residue in chain.residues():
      final_resnames.add(residue.name)
  alterations_info['removed_heterogens'] = (
      initial_resnames.difference(final_resnames))


def _replace_met_se(pdb_structure, alterations_info):
  """Replace the Se in any MET residues that were not marked as modified."""
  modified_met_residues = []
  for res in pdb_structure.iter_residues():
    name = res.get_name_with_spaces().strip()
    if name == 'MET':
      s_atom = res.get_atom('SD')
      if s_atom.element_symbol == 'Se':
        s_atom.element_symbol = 'S'
        s_atom.element = element.get_by_symbol('S')
        modified_met_residues.append(s_atom.residue_number)
  alterations_info['Se_in_MET'] = modified_met_residues


def _remove_chains_of_length_one(pdb_structure, alterations_info):
  """Removes chains that correspond to a single amino acid.

  A single amino acid in a chain is both N and C terminus. There is no force
  template for this case.

  Args:
    pdb_structure: An OpenMM pdb_structure to modify and fix.
    alterations_info: A dict that will store details of changes made.
  """
  removed_chains = {}
  for model in pdb_structure.iter_models():
    valid_chains = [c for c in model.iter_chains() if len(c) > 1]
    invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1]
    model.chains = valid_chains
    for chain_id in invalid_chain_ids:
      model.chains_by_id.pop(chain_id)
    removed_chains[model.number] = invalid_chain_ids
  alterations_info['removed_chains'] = removed_chains
alphafold/relax/cleanup_test.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for relax.cleanup."""
import io

from absl.testing import absltest
from alphafold.relax import cleanup
from simtk.openmm.app.internal import pdbstructure


def _pdb_to_structure(pdb_str):
  handle = io.StringIO(pdb_str)
  return pdbstructure.PdbStructure(handle)


def _lines_to_structure(pdb_lines):
  return _pdb_to_structure('\n'.join(pdb_lines))


class CleanupTest(absltest.TestCase):

  def test_missing_residues(self):
    pdb_lines = ['SEQRES 1 C 3 CYS GLY LEU',
                 'ATOM 1 N CYS C 1 -12.262 20.115 60.959 1.00 '
                 '19.08 N',
                 'ATOM 2 CA CYS C 1 -11.065 20.934 60.773 1.00 '
                 '17.23 C',
                 'ATOM 3 C CYS C 1 -10.002 20.742 61.844 1.00 '
                 '15.38 C',
                 'ATOM 4 O CYS C 1 -10.284 20.225 62.929 1.00 '
                 '16.04 O',
                 'ATOM 5 N LEU C 3 -7.688 18.700 62.045 1.00 '
                 '14.75 N',
                 'ATOM 6 CA LEU C 3 -7.256 17.320 62.234 1.00 '
                 '16.81 C',
                 'ATOM 7 C LEU C 3 -6.380 16.864 61.070 1.00 '
                 '16.95 C',
                 'ATOM 8 O LEU C 3 -6.551 17.332 59.947 1.00 '
                 '16.97 O']
    input_handle = io.StringIO('\n'.join(pdb_lines))
    alterations = {}
    result = cleanup.fix_pdb(input_handle, alterations)
    structure = _pdb_to_structure(result)
    residue_names = [r.get_name() for r in structure.iter_residues()]
    self.assertCountEqual(residue_names, ['CYS', 'GLY', 'LEU'])
    self.assertCountEqual(alterations['missing_residues'].values(), [['GLY']])

  def test_missing_atoms(self):
    pdb_lines = ['SEQRES 1 A 1 PRO',
                 'ATOM 1 CA PRO A 1 1.000 1.000 1.000 1.00 '
                 ' 0.00 C']
    input_handle = io.StringIO('\n'.join(pdb_lines))
    alterations = {}
    result = cleanup.fix_pdb(input_handle, alterations)
    structure = _pdb_to_structure(result)
    atom_names = [a.get_name() for a in structure.iter_atoms()]
    self.assertCountEqual(atom_names, ['N', 'CD', 'HD2', 'HD3', 'CG', 'HG2',
                                       'HG3', 'CB', 'HB2', 'HB3', 'CA', 'HA',
                                       'C', 'O', 'H2', 'H3', 'OXT'])
    missing_atoms_by_residue = list(alterations['missing_heavy_atoms'].values())
    self.assertLen(missing_atoms_by_residue, 1)
    atoms_added = [a.name for a in missing_atoms_by_residue[0]]
    self.assertCountEqual(atoms_added, ['N', 'CD', 'CG', 'CB', 'C', 'O'])
    missing_terminals_by_residue = alterations['missing_terminals']
    self.assertLen(missing_terminals_by_residue, 1)
    has_missing_terminal = [r.name for r in missing_terminals_by_residue.keys()]
    self.assertCountEqual(has_missing_terminal, ['PRO'])
    self.assertCountEqual([t for t in missing_terminals_by_residue.values()],
                          [['OXT']])

  def test_remove_heterogens(self):
    pdb_lines = ['SEQRES 1 A 1 GLY',
                 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 '
                 ' 0.00 C',
                 'ATOM 2 O HOH A 2 0.000 0.000 0.000 1.00 '
                 ' 0.00 O']
    input_handle = io.StringIO('\n'.join(pdb_lines))
    alterations = {}
    result = cleanup.fix_pdb(input_handle, alterations)
    structure = _pdb_to_structure(result)
    self.assertCountEqual([res.get_name() for res in structure.iter_residues()],
                          ['GLY'])
    self.assertEqual(alterations['removed_heterogens'], set(['HOH']))

  def test_fix_nonstandard_residues(self):
    pdb_lines = ['SEQRES 1 A 1 DAL',
                 'ATOM 1 CA DAL A 1 0.000 0.000 0.000 1.00 '
                 ' 0.00 C']
    input_handle = io.StringIO('\n'.join(pdb_lines))
    alterations = {}
    result = cleanup.fix_pdb(input_handle, alterations)
    structure = _pdb_to_structure(result)
    residue_names = [res.get_name() for res in structure.iter_residues()]
    self.assertCountEqual(residue_names, ['ALA'])
    self.assertLen(alterations['nonstandard_residues'], 1)
    original_res, new_name = alterations['nonstandard_residues'][0]
    self.assertEqual(original_res.id, '1')
    self.assertEqual(new_name, 'ALA')

  def test_replace_met_se(self):
    pdb_lines = ['SEQRES 1 A 1 MET',
                 'ATOM 1 SD MET A 1 0.000 0.000 0.000 1.00 '
                 ' 0.00 Se']
    structure = _lines_to_structure(pdb_lines)
    alterations = {}
    cleanup._replace_met_se(structure, alterations)
    sd = [a for a in structure.iter_atoms() if a.get_name() == 'SD']
    self.assertLen(sd, 1)
    self.assertEqual(sd[0].element_symbol, 'S')
    self.assertCountEqual(alterations['Se_in_MET'], [sd[0].residue_number])

  def test_remove_chains_of_length_one(self):
    pdb_lines = ['SEQRES 1 A 1 GLY',
                 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 '
                 ' 0.00 C']
    structure = _lines_to_structure(pdb_lines)
    alterations = {}
    cleanup._remove_chains_of_length_one(structure, alterations)
    chains = list(structure.iter_chains())
    self.assertEmpty(chains)
    self.assertCountEqual(alterations['removed_chains'].values(), [['A']])


if __name__ == '__main__':
  absltest.main()