chenzk / AlphaFold2_jax
"vscode:/vscode.git/clone" did not exist on "96a3a8be05264407870d25aab830f98b161e192d"
Commit 2f0d89e7, authored Aug 24, 2023 by zhuwenwen

remove duplicated code

Parent: a1597f3f
Changes: 103

Showing 20 changed files with 0 additions and 7144 deletions (+0 -7144):
alphafold/model/folding_multimer.py              +0  -1162
alphafold/model/geometry/__init__.py             +0  -31
alphafold/model/geometry/rigid_matrix_vector.py  +0  -106
alphafold/model/geometry/rotation_matrix.py      +0  -157
alphafold/model/geometry/struct_of_array.py      +0  -220
alphafold/model/geometry/test_utils.py           +0  -98
alphafold/model/geometry/utils.py                +0  -23
alphafold/model/geometry/vector.py               +0  -217
alphafold/model/layer_stack.py                   +0  -274
alphafold/model/layer_stack_test.py              +0  -335
alphafold/model/lddt.py                          +0  -88
alphafold/model/lddt_test.py                     +0  -79
alphafold/model/mapping.py                       +0  -223
alphafold/model/model.py                         +0  -177
alphafold/model/modules.py                       +0  -2104
alphafold/model/modules_multimer.py              +0  -1126
alphafold/model/prng.py                          +0  -69
alphafold/model/prng_test.py                     +0  -46
alphafold/model/quat_affine.py                   +0  -459
alphafold/model/quat_affine_test.py              +0  -150
alphafold/model/folding_multimer.py deleted (100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modules and utilities for the structure module in the multimer system."""
import functools
import numbers
from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union

from alphafold.common import residue_constants
from alphafold.model import all_atom_multimer
from alphafold.model import common_modules
from alphafold.model import geometry
from alphafold.model import modules
from alphafold.model import prng
from alphafold.model import utils
from alphafold.model.geometry import utils as geometry_utils
import haiku as hk
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np

EPSILON = 1e-8
Float = Union[float, jnp.ndarray]


def squared_difference(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """Computes Squared difference between two arrays."""
  return jnp.square(x - y)


def make_backbone_affine(
    positions: geometry.Vec3Array,
    mask: jnp.ndarray,
    aatype: jnp.ndarray,
    ) -> Tuple[geometry.Rigid3Array, jnp.ndarray]:
  """Make backbone Rigid3Array and mask."""
  del aatype
  a = residue_constants.atom_order['N']
  b = residue_constants.atom_order['CA']
  c = residue_constants.atom_order['C']

  rigid_mask = (mask[:, a] * mask[:, b] * mask[:, c]).astype(jnp.float32)

  rigid = all_atom_multimer.make_transform_from_reference(
      a_xyz=positions[:, a], b_xyz=positions[:, b], c_xyz=positions[:, c])

  return rigid, rigid_mask

class QuatRigid(hk.Module):
  """Module for projecting Rigids via a quaternion."""

  def __init__(self,
               global_config: ml_collections.ConfigDict,
               rigid_shape: Union[int, Iterable[int]] = tuple(),
               full_quat: bool = False,
               init: str = 'zeros',
               name: str = 'quat_rigid'):
    """Module projecting a Rigid Object.

    For this Module the Rotation is parametrized as a quaternion.
    If 'full_quat' is True a 4 vector is produced for the rotation which is
    normalized and treated as a quaternion.
    When 'full_quat' is False a 3 vector is produced and the 1st component of
    the quaternion is set to 1.

    Args:
      global_config: Global Config, used to set certain properties of
        underlying Linear module, see common_modules.Linear for details.
      rigid_shape: Shape of Rigids relative to shape of activations, e.g. when
        activations have shape (n,) and this is (m,) output will be (n, m)
      full_quat: Whether to parametrize rotation using full quaternion.
      init: initializer to use, see common_modules.Linear for details
      name: Name to use for module.
    """
    self.init = init
    self.global_config = global_config
    if isinstance(rigid_shape, int):
      self.rigid_shape = (rigid_shape,)
    else:
      self.rigid_shape = tuple(rigid_shape)
    self.full_quat = full_quat
    super(QuatRigid, self).__init__(name=name)

  def __call__(self, activations: jnp.ndarray) -> geometry.Rigid3Array:
    """Executes Module.

    This returns a set of rigid with the same shape as activations, projecting
    the channel dimension, rigid_shape controls the trailing dimensions.
    For example when activations is shape (12, 5) and rigid_shape is (3, 2)
    then the shape of the output rigids will be (12, 3, 2).
    This also supports passing in an empty tuple for rigid shape, in that case
    the example would produce a rigid of shape (12,).

    Args:
      activations: Activations to use for projection, shape [..., num_channel]

    Returns:
      Rigid transformations with shape [...] + rigid_shape
    """
    if self.full_quat:
      rigid_dim = 7
    else:
      rigid_dim = 6

    linear_dims = self.rigid_shape + (rigid_dim,)

    rigid_flat = common_modules.Linear(
        linear_dims,
        initializer=self.init,
        precision=jax.lax.Precision.HIGHEST,
        name='rigid')(
            activations)

    rigid_flat = geometry_utils.unstack(rigid_flat)
    if self.full_quat:
      qw, qx, qy, qz = rigid_flat[:4]
      translation = rigid_flat[4:]
    else:
      qx, qy, qz = rigid_flat[:3]
      qw = jnp.ones_like(qx)
      translation = rigid_flat[3:]
    rotation = geometry.Rot3Array.from_quaternion(
        qw, qx, qy, qz, normalize=True)
    translation = geometry.Vec3Array(*translation)
    return geometry.Rigid3Array(rotation, translation)

class PointProjection(hk.Module):
  """Given input representation and frame produces points in global frame."""

  def __init__(self,
               num_points: Union[Iterable[int], int],
               global_config: ml_collections.ConfigDict,
               return_local_points: bool = False,
               name: str = 'point_projection'):
    """Constructs Linear Module.

    Args:
      num_points: number of points to project. Can be tuple when outputting
        multiple dimensions
      global_config: Global Config, passed through to underlying Linear
      return_local_points: Whether to return points in local frame as well.
      name: name of module, used for name scopes.
    """
    if isinstance(num_points, numbers.Integral):
      self.num_points = (num_points,)
    else:
      self.num_points = tuple(num_points)
    self.return_local_points = return_local_points
    self.global_config = global_config

    super().__init__(name=name)

  def __call__(
      self, activations: jnp.ndarray, rigids: geometry.Rigid3Array
  ) -> Union[geometry.Vec3Array,
             Tuple[geometry.Vec3Array, geometry.Vec3Array]]:
    output_shape = self.num_points
    output_shape = output_shape[:-1] + (3 * output_shape[-1],)
    points_local = common_modules.Linear(
        output_shape,
        precision=jax.lax.Precision.HIGHEST,
        name='point_projection')(
            activations)
    points_local = jnp.split(points_local, 3, axis=-1)
    points_local = geometry.Vec3Array(*points_local)
    rigids = rigids[(...,) + (None,) * len(output_shape)]
    points_global = rigids.apply_to_point(points_local)
    if self.return_local_points:
      return points_global, points_local
    else:
      return points_global

class InvariantPointAttention(hk.Module):
  """Invariant point attention module.

  The high-level idea is that this attention module works over a set of points
  and associated orientations in 3D space (e.g. protein residues).

  Each residue outputs a set of queries and keys as points in their local
  reference frame. The attention is then defined as the euclidean distance
  between the queries and keys in the global frame.
  """

  def __init__(self,
               config: ml_collections.ConfigDict,
               global_config: ml_collections.ConfigDict,
               dist_epsilon: float = 1e-8,
               name: str = 'invariant_point_attention'):
    """Initialize.

    Args:
      config: iterative Fold Head Config
      global_config: Global Config of Model.
      dist_epsilon: Small value to avoid NaN in distance calculation.
      name: Sonnet name.
    """
    super().__init__(name=name)

    self._dist_epsilon = dist_epsilon
    self._zero_initialize_last = global_config.zero_init

    self.config = config
    self.global_config = global_config

  def __call__(
      self,
      inputs_1d: jnp.ndarray,
      inputs_2d: jnp.ndarray,
      mask: jnp.ndarray,
      rigid: geometry.Rigid3Array,
      ) -> jnp.ndarray:
    """Compute geometric aware attention.

    Given a set of query residues (defined by affines and associated scalar
    features), this function computes geometric aware attention between the
    query residues and target residues.

    The residues produce points in their local reference frame, which
    are converted into the global frame to get attention via euclidean
    distance.

    Equivalently the target residues produce points in their local frame to be
    used as attention values, which are converted into the query residues local
    frames.

    Args:
      inputs_1d: (N, C) 1D input embedding that is the basis for the
        scalar queries.
      inputs_2d: (N, M, C') 2D input embedding, used for biases values in the
        attention between query_inputs_1d and target_inputs_1d.
      mask: (N, 1) mask to indicate query_inputs_1d that participate in
        the attention.
      rigid: Rigid object describing the position and orientation of
        every element in query_inputs_1d.

    Returns:
      Transformation of the input embedding.
    """
    num_head = self.config.num_head
    attn_logits = 0.

    num_point_qk = self.config.num_point_qk
    # Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
    point_variance = max(num_point_qk, 1) * 9. / 2
    point_weights = np.sqrt(1.0 / point_variance)

    # This is equivalent to jax.nn.softplus, but avoids a bug in the test...
    softplus = lambda x: jnp.logaddexp(x, jnp.zeros_like(x))
    raw_point_weights = hk.get_parameter(
        'trainable_point_weights',
        shape=[num_head],
        # softplus^{-1} (1)
        init=hk.initializers.Constant(np.log(np.exp(1.) - 1.)))

    # Trainable per-head weights for points.
    trainable_point_weights = softplus(raw_point_weights)
    point_weights *= trainable_point_weights

    q_point = PointProjection([num_head, num_point_qk],
                              self.global_config,
                              name='q_point_projection')(inputs_1d, rigid)

    k_point = PointProjection([num_head, num_point_qk],
                              self.global_config,
                              name='k_point_projection')(inputs_1d, rigid)

    dist2 = geometry.square_euclidean_distance(
        q_point[:, None, :, :], k_point[None, :, :, :], epsilon=0.)
    attn_qk_point = -0.5 * jnp.sum(point_weights[:, None] * dist2, axis=-1)

    attn_logits += attn_qk_point

    num_scalar_qk = self.config.num_scalar_qk

    # We assume that all queries and keys come iid from N(0, 1) distribution
    # and compute the variances of the attention logits.
    # Each scalar pair (q, k) contributes Var q*k = 1
    scalar_variance = max(num_scalar_qk, 1) * 1.
    scalar_weights = np.sqrt(1.0 / scalar_variance)

    q_scalar = common_modules.Linear(
        [num_head, num_scalar_qk], use_bias=False, name='q_scalar_projection')(
            inputs_1d)

    k_scalar = common_modules.Linear(
        [num_head, num_scalar_qk], use_bias=False, name='k_scalar_projection')(
            inputs_1d)

    q_scalar *= scalar_weights
    attn_logits += jnp.einsum('qhc,khc->qkh', q_scalar, k_scalar)

    attention_2d = common_modules.Linear(num_head, name='attention_2d')(
        inputs_2d)
    attn_logits += attention_2d

    mask_2d = mask * jnp.swapaxes(mask, -1, -2)
    attn_logits -= 1e5 * (1. - mask_2d[..., None])

    attn_logits *= np.sqrt(1. / 3)  # Normalize by number of logit terms (3)
    attn = jax.nn.softmax(attn_logits, axis=-2)

    num_scalar_v = self.config.num_scalar_v
    v_scalar = common_modules.Linear(
        [num_head, num_scalar_v], use_bias=False, name='v_scalar_projection')(
            inputs_1d)

    # [num_query_residues, num_head, num_scalar_v]
    result_scalar = jnp.einsum('qkh, khc->qhc', attn, v_scalar)

    num_point_v = self.config.num_point_v

    v_point = PointProjection([num_head, num_point_v],
                              self.global_config,
                              name='v_point_projection')(inputs_1d, rigid)

    result_point_global = jax.tree_map(
        lambda x: jnp.sum(attn[..., None] * x, axis=-3), v_point[None])

    # Features used in the linear output projection. Should have the size
    # [num_query_residues, ?]
    output_features = []
    num_query_residues, _ = inputs_1d.shape
    flat_shape = [num_query_residues, -1]

    result_scalar = jnp.reshape(result_scalar, flat_shape)
    output_features.append(result_scalar)

    result_point_global = jax.tree_map(
        lambda r: jnp.reshape(r, flat_shape), result_point_global)
    result_point_local = rigid[..., None].apply_inverse_to_point(
        result_point_global)
    output_features.extend(
        [result_point_local.x, result_point_local.y, result_point_local.z])
    point_norms = result_point_local.norm(self._dist_epsilon)
    output_features.append(point_norms)

    # Dimensions: h = heads, i and j = residues,
    # c = inputs_2d channels
    # Contraction happens over the second residue dimension, similarly to how
    # the usual attention is performed.
    result_attention_over_2d = jnp.einsum('ijh, ijc->ihc', attn, inputs_2d)
    output_features.append(
        jnp.reshape(result_attention_over_2d, flat_shape))

    final_init = 'zeros' if self._zero_initialize_last else 'linear'

    final_act = jnp.concatenate(output_features, axis=-1)

    return common_modules.Linear(
        self.config.num_channel,
        initializer=final_init,
        name='output_projection')(final_act)

class FoldIteration(hk.Module):
  """A single iteration of iterative folding.

  First, each residue attends to all residues using InvariantPointAttention.
  Then, we apply transition layers to update the hidden representations.
  Finally, we use the hidden representations to produce an update to the
  affine of each residue.
  """

  def __init__(self,
               config: ml_collections.ConfigDict,
               global_config: ml_collections.ConfigDict,
               name: str = 'fold_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(
      self,
      activations: Mapping[str, Any],
      aatype: jnp.ndarray,
      sequence_mask: jnp.ndarray,
      update_rigid: bool,
      is_training: bool,
      initial_act: jnp.ndarray,
      safe_key: Optional[prng.SafeKey] = None,
      static_feat_2d: Optional[jnp.ndarray] = None,
      ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    c = self.config

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    def safe_dropout_fn(tensor, safe_key):
      return modules.apply_dropout(
          tensor=tensor,
          safe_key=safe_key,
          rate=0.0 if self.global_config.deterministic else c.dropout,
          is_training=is_training)

    rigid = activations['rigid']
    act = activations['act']

    attention_module = InvariantPointAttention(self.config, self.global_config)
    # Attention
    act += attention_module(
        inputs_1d=act,
        inputs_2d=static_feat_2d,
        mask=sequence_mask,
        rigid=rigid)

    safe_key, *sub_keys = safe_key.split(3)
    sub_keys = iter(sub_keys)
    act = safe_dropout_fn(act, next(sub_keys))
    act = hk.LayerNorm(
        axis=-1,
        create_scale=True,
        create_offset=True,
        name='attention_layer_norm')(act)

    final_init = 'zeros' if self.global_config.zero_init else 'linear'

    # Transition
    input_act = act
    for i in range(c.num_layer_in_transition):
      init = 'relu' if i < c.num_layer_in_transition - 1 else final_init
      act = common_modules.Linear(
          c.num_channel,
          initializer=init,
          name='transition')(act)
      if i < c.num_layer_in_transition - 1:
        act = jax.nn.relu(act)
    act += input_act
    act = safe_dropout_fn(act, next(sub_keys))
    act = hk.LayerNorm(
        axis=-1,
        create_scale=True,
        create_offset=True,
        name='transition_layer_norm')(act)

    if update_rigid:
      # Rigid update
      rigid_update = QuatRigid(
          self.global_config, init=final_init)(act)
      rigid = rigid @ rigid_update

    sc = MultiRigidSidechain(
        c.sidechain, self.global_config)(
            rigid.scale_translation(c.position_scale), [act, initial_act],
            aatype)

    outputs = {'rigid': rigid, 'sc': sc}

    rotation = jax.tree_map(jax.lax.stop_gradient, rigid.rotation)
    rigid = geometry.Rigid3Array(rotation, rigid.translation)

    new_activations = {'act': act, 'rigid': rigid}
    return new_activations, outputs

def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
                            batch: Mapping[str, jnp.ndarray],
                            config: ml_collections.ConfigDict,
                            global_config: ml_collections.ConfigDict,
                            is_training: bool,
                            safe_key: prng.SafeKey
                            ) -> Dict[str, Any]:
  """Generate predicted Rigid's for a single chain.

  This is the main part of the iterative fold head - it iteratively applies
  folding to produce a set of predicted residue positions.

  Args:
    representations: Embeddings dictionary.
    batch: Batch dictionary.
    config: config for the iterative fold head.
    global_config: global config.
    is_training: is training.
    safe_key: A prng.SafeKey object that wraps a PRNG key.

  Returns:
    A dictionary containing residue Rigid's and sidechain positions.
  """
  c = config
  sequence_mask = batch['seq_mask'][:, None]

  act = hk.LayerNorm(
      axis=-1,
      create_scale=True,
      create_offset=True,
      name='single_layer_norm')(
          representations['single'])

  initial_act = act
  act = common_modules.Linear(
      c.num_channel, name='initial_projection')(act)

  # Sequence Mask has extra 1 at the end.
  rigid = geometry.Rigid3Array.identity(sequence_mask.shape[:-1])

  fold_iteration = FoldIteration(
      c, global_config, name='fold_iteration')

  assert len(batch['seq_mask'].shape) == 1

  activations = {'act': act, 'rigid': rigid}

  act_2d = hk.LayerNorm(
      axis=-1,
      create_scale=True,
      create_offset=True,
      name='pair_layer_norm')(
          representations['pair'])

  safe_keys = safe_key.split(c.num_layer)
  outputs = []
  for key in safe_keys:
    activations, output = fold_iteration(
        activations,
        initial_act=initial_act,
        static_feat_2d=act_2d,
        aatype=batch['aatype'],
        safe_key=key,
        sequence_mask=sequence_mask,
        update_rigid=True,
        is_training=is_training,
        )
    outputs.append(output)

  output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
  # Pass along for LDDT-Head.
  output['act'] = activations['act']

  return output

class StructureModule(hk.Module):
  """StructureModule as a network head.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
  """

  def __init__(self,
               config: ml_collections.ConfigDict,
               global_config: ml_collections.ConfigDict,
               name: str = 'structure_module'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               representations: Mapping[str, jnp.ndarray],
               batch: Mapping[str, Any],
               is_training: bool,
               safe_key: Optional[prng.SafeKey] = None,
               compute_loss: bool = False
               ) -> Dict[str, Any]:
    c = self.config
    ret = {}

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    output = generate_monomer_rigids(
        representations=representations,
        batch=batch,
        config=self.config,
        global_config=self.global_config,
        is_training=is_training,
        safe_key=safe_key)

    ret['traj'] = output['rigid'].scale_translation(c.position_scale).to_array()
    ret['sidechains'] = output['sc']
    ret['sidechains']['atom_pos'] = ret['sidechains']['atom_pos'].to_array()
    ret['sidechains']['frames'] = ret['sidechains']['frames'].to_array()
    if 'local_atom_pos' in ret['sidechains']:
      ret['sidechains']['local_atom_pos'] = ret['sidechains'][
          'local_atom_pos'].to_array()
      ret['sidechains']['local_frames'] = ret['sidechains'][
          'local_frames'].to_array()

    aatype = batch['aatype']
    seq_mask = batch['seq_mask']

    atom14_pred_mask = all_atom_multimer.get_atom14_mask(
        aatype) * seq_mask[:, None]
    atom14_pred_positions = output['sc']['atom_pos'][-1]
    ret['final_atom14_positions'] = atom14_pred_positions  # (N, 14, 3)
    ret['final_atom14_mask'] = atom14_pred_mask  # (N, 14)

    atom37_mask = all_atom_multimer.get_atom37_mask(aatype) * seq_mask[:, None]
    atom37_pred_positions = all_atom_multimer.atom14_to_atom37(
        atom14_pred_positions, aatype)
    atom37_pred_positions *= atom37_mask[:, :, None]
    ret['final_atom_positions'] = atom37_pred_positions  # (N, 37, 3)

    ret['final_atom_mask'] = atom37_mask  # (N, 37)
    ret['final_rigids'] = ret['traj'][-1]

    ret['act'] = output['act']

    if compute_loss:
      return ret
    else:
      no_loss_features = ['final_atom_positions', 'final_atom_mask', 'act']
      no_loss_ret = {k: ret[k] for k in no_loss_features}
      return no_loss_ret

  def loss(self,
           value: Mapping[str, Any],
           batch: Mapping[str, Any]
           ) -> Dict[str, Any]:
    raise NotImplementedError(
        'This function should be called on a batch with reordered chains (see '
        'Evans et al (2021) Section 7.3. Multi-Chain Permutation Alignment.')
    ret = {'loss': 0.}

    ret['metrics'] = {}
    aatype = batch['aatype']
    all_atom_positions = batch['all_atom_positions']
    all_atom_positions = geometry.Vec3Array.from_array(all_atom_positions)
    all_atom_mask = batch['all_atom_mask']
    seq_mask = batch['seq_mask']
    residue_index = batch['residue_index']

    gt_rigid, gt_affine_mask = make_backbone_affine(
        all_atom_positions,
        all_atom_mask,
        aatype)
    chi_angles, chi_mask = all_atom_multimer.compute_chi_angles(
        all_atom_positions,
        all_atom_mask,
        aatype)
    pred_mask = all_atom_multimer.get_atom14_mask(aatype)
    pred_mask *= seq_mask[:, None]
    pred_positions = value['final_atom14_positions']
    pred_positions = geometry.Vec3Array.from_array(pred_positions)

    gt_positions, gt_mask, alt_naming_is_better = compute_atom14_gt(
        aatype, all_atom_positions, all_atom_mask, pred_positions)
    violations = find_structural_violations(
        aatype=aatype,
        residue_index=residue_index,
        mask=pred_mask,
        pred_positions=pred_positions,
        config=self.config,
        asym_id=batch['asym_id'])
    sidechains = value['sidechains']

    gt_chi_angles = get_renamed_chi_angles(aatype,
                                           chi_angles,
                                           alt_naming_is_better)

    # Several violation metrics:
    violation_metrics = compute_violation_metrics(
        residue_index=residue_index,
        mask=pred_mask,
        seq_mask=seq_mask,
        pred_positions=pred_positions,
        violations=violations)
    ret['metrics'].update(violation_metrics)

    target_rigid = geometry.Rigid3Array.from_array(value['traj'])
    gt_frames_mask = gt_affine_mask

    # Split the loss into within-chain and between-chain components.
    intra_chain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
    intra_chain_bb_loss, intra_chain_fape = backbone_loss(
        gt_rigid=gt_rigid,
        gt_frames_mask=gt_frames_mask,
        gt_positions_mask=gt_affine_mask,
        target_rigid=target_rigid,
        config=self.config.intra_chain_fape,
        pair_mask=intra_chain_mask)
    interface_bb_loss, interface_fape = backbone_loss(
        gt_rigid=gt_rigid,
        gt_frames_mask=gt_frames_mask,
        gt_positions_mask=gt_affine_mask,
        target_rigid=target_rigid,
        config=self.config.interface_fape,
        pair_mask=1. - intra_chain_mask)

    bb_loss = intra_chain_bb_loss + interface_bb_loss
    ret['fape'] = intra_chain_fape + interface_fape
    ret['bb_loss'] = bb_loss
    ret['loss'] += bb_loss

    pred_frames = geometry.Rigid3Array.from_array(sidechains['frames'])
    pred_positions = geometry.Vec3Array.from_array(sidechains['atom_pos'])
    gt_sc_frames, gt_sc_frames_mask = compute_frames(
        aatype=aatype,
        all_atom_positions=all_atom_positions,
        all_atom_mask=all_atom_mask,
        use_alt=alt_naming_is_better)

    sc_loss = sidechain_loss(
        gt_frames=gt_sc_frames,
        gt_frames_mask=gt_sc_frames_mask,
        gt_positions=gt_positions,
        gt_mask=gt_mask,
        pred_frames=pred_frames,
        pred_positions=pred_positions,
        config=self.config)

    ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] +
                   self.config.sidechain.weight_frac * sc_loss['loss'])
    ret['sidechain_fape'] = sc_loss['fape']

    unnormed_angles = sidechains['unnormalized_angles_sin_cos']
    pred_angles = sidechains['angles_sin_cos']

    sup_chi_loss, ret['chi_loss'], ret['angle_norm_loss'] = supervised_chi_loss(
        sequence_mask=seq_mask,
        target_chi_mask=chi_mask,
        target_chi_angles=gt_chi_angles,
        aatype=aatype,
        pred_angles=pred_angles,
        unnormed_angles=unnormed_angles,
        config=self.config)
    ret['loss'] += sup_chi_loss

    if self.config.structural_violation_loss_weight:
      ret['loss'] += structural_violation_loss(
          mask=pred_mask, violations=violations, config=self.config)

    return ret

def compute_atom14_gt(
    aatype: jnp.ndarray,
    all_atom_positions: geometry.Vec3Array,
    all_atom_mask: jnp.ndarray,
    pred_pos: geometry.Vec3Array
    ) -> Tuple[geometry.Vec3Array, jnp.ndarray, jnp.ndarray]:
  """Find atom14 positions, this includes finding the correct renaming."""
  gt_positions, gt_mask = all_atom_multimer.atom37_to_atom14(
      aatype, all_atom_positions, all_atom_mask)
  alt_gt_positions, alt_gt_mask = all_atom_multimer.get_alt_atom14(
      aatype, gt_positions, gt_mask)
  atom_is_ambiguous = all_atom_multimer.get_atom14_is_ambiguous(aatype)
  alt_naming_is_better = all_atom_multimer.find_optimal_renaming(
      gt_positions=gt_positions,
      alt_gt_positions=alt_gt_positions,
      atom_is_ambiguous=atom_is_ambiguous,
      gt_exists=gt_mask,
      pred_positions=pred_pos)

  use_alt = alt_naming_is_better[:, None]

  gt_mask = (1. - use_alt) * gt_mask + use_alt * alt_gt_mask
  gt_positions = (1. - use_alt) * gt_positions + use_alt * alt_gt_positions

  return gt_positions, alt_gt_mask, alt_naming_is_better

def backbone_loss(gt_rigid: geometry.Rigid3Array,
                  gt_frames_mask: jnp.ndarray,
                  gt_positions_mask: jnp.ndarray,
                  target_rigid: geometry.Rigid3Array,
                  config: ml_collections.ConfigDict,
                  pair_mask: jnp.ndarray
                  ) -> Tuple[Float, jnp.ndarray]:
  """Backbone FAPE Loss."""
  loss_fn = functools.partial(
      all_atom_multimer.frame_aligned_point_error,
      l1_clamp_distance=config.atom_clamp_distance,
      loss_unit_distance=config.loss_unit_distance)

  loss_fn = jax.vmap(loss_fn, (0, None, None, 0, None, None, None))
  fape = loss_fn(target_rigid, gt_rigid, gt_frames_mask,
                 target_rigid.translation, gt_rigid.translation,
                 gt_positions_mask, pair_mask)

  return jnp.mean(fape), fape[-1]

def compute_frames(
    aatype: jnp.ndarray,
    all_atom_positions: geometry.Vec3Array,
    all_atom_mask: jnp.ndarray,
    use_alt: jnp.ndarray
    ) -> Tuple[geometry.Rigid3Array, jnp.ndarray]:
  """Compute Frames from all atom positions.

  Args:
    aatype: array of aatypes, int of [N]
    all_atom_positions: Vector of all atom positions, shape [N, 37]
    all_atom_mask: mask, shape [N]
    use_alt: whether to use alternative orientation for ambiguous aatypes
      shape [N]

  Returns:
    Rigid corresponding to Frames w shape [N, 8],
    mask which Rigids are present w shape [N, 8]
  """
  frames_batch = all_atom_multimer.atom37_to_frames(aatype, all_atom_positions,
                                                    all_atom_mask)
  gt_frames = frames_batch['rigidgroups_gt_frames']
  alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames']
  use_alt = use_alt[:, None]

  renamed_gt_frames = jax.tree_map(
      lambda x, y: (1. - use_alt) * x + use_alt * y, gt_frames, alt_gt_frames)

  return renamed_gt_frames, frames_batch['rigidgroups_gt_exists']

def sidechain_loss(gt_frames: geometry.Rigid3Array,
                   gt_frames_mask: jnp.ndarray,
                   gt_positions: geometry.Vec3Array,
                   gt_mask: jnp.ndarray,
                   pred_frames: geometry.Rigid3Array,
                   pred_positions: geometry.Vec3Array,
                   config: ml_collections.ConfigDict
                   ) -> Dict[str, jnp.ndarray]:
  """Sidechain Loss using cleaned up rigids."""
  flat_gt_frames = jax.tree_map(jnp.ravel, gt_frames)
  flat_frames_mask = jnp.ravel(gt_frames_mask)

  flat_gt_positions = jax.tree_map(jnp.ravel, gt_positions)
  flat_positions_mask = jnp.ravel(gt_mask)

  # Compute frame_aligned_point_error score for the final layer.
  def _slice_last_layer_and_flatten(x):
    return jnp.ravel(x[-1])

  flat_pred_frames = jax.tree_map(_slice_last_layer_and_flatten, pred_frames)
  flat_pred_positions = jax.tree_map(_slice_last_layer_and_flatten,
                                     pred_positions)
  fape = all_atom_multimer.frame_aligned_point_error(
      pred_frames=flat_pred_frames,
      target_frames=flat_gt_frames,
      frames_mask=flat_frames_mask,
      pred_positions=flat_pred_positions,
      target_positions=flat_gt_positions,
      positions_mask=flat_positions_mask,
      pair_mask=None,
      length_scale=config.sidechain.loss_unit_distance,
      l1_clamp_distance=config.sidechain.atom_clamp_distance)

  return {
      'fape': fape,
      'loss': fape}

def structural_violation_loss(mask: jnp.ndarray,
                              violations: Mapping[str, Float],
                              config: ml_collections.ConfigDict
                              ) -> Float:
  """Computes Loss for structural Violations."""
  # Put all violation losses together to one large loss.
  num_atoms = jnp.sum(mask).astype(jnp.float32) + 1e-6
  between_residues = violations['between_residues']
  within_residues = violations['within_residues']
  return (config.structural_violation_loss_weight *
          (between_residues['bonds_c_n_loss_mean'] +
           between_residues['angles_ca_c_n_loss_mean'] +
           between_residues['angles_c_n_ca_loss_mean'] +
           jnp.sum(between_residues['clashes_per_atom_loss_sum'] +
                   within_residues['per_atom_loss_sum']) / num_atoms
           ))

def find_structural_violations(
    aatype: jnp.ndarray,
    residue_index: jnp.ndarray,
    mask: jnp.ndarray,
    pred_positions: geometry.Vec3Array,  # (N, 14)
    config: ml_collections.ConfigDict,
    asym_id: jnp.ndarray,
    ) -> Dict[str, Any]:
  """Computes several checks for structural Violations."""

  # Compute between residue backbone violations of bonds and angles.
  connection_violations = all_atom_multimer.between_residue_bond_loss(
      pred_atom_positions=pred_positions,
      pred_atom_mask=mask.astype(jnp.float32),
      residue_index=residue_index.astype(jnp.float32),
      aatype=aatype,
      tolerance_factor_soft=config.violation_tolerance_factor,
      tolerance_factor_hard=config.violation_tolerance_factor)

  # Compute the van der Waals radius for every atom
  # (the first letter of the atom name is the element type).
  # shape (N, 14)
  atomtype_radius = jnp.array([
      residue_constants.van_der_waals_radius[name[0]]
      for name in residue_constants.atom_types
  ])
  residx_atom14_to_atom37 = all_atom_multimer.get_atom14_to_atom37_map(aatype)
  atom_radius = mask * utils.batched_gather(atomtype_radius,
                                            residx_atom14_to_atom37)

  # Compute the between residue clash loss.
  between_residue_clashes = all_atom_multimer.between_residue_clash_loss(
      pred_positions=pred_positions,
      atom_exists=mask,
      atom_radius=atom_radius,
      residue_index=residue_index,
      overlap_tolerance_soft=config.clash_overlap_tolerance,
      overlap_tolerance_hard=config.clash_overlap_tolerance,
      asym_id=asym_id)

  # Compute all within-residue violations (clashes,
  # bond length and angle violations).
  restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
      overlap_tolerance=config.clash_overlap_tolerance,
      bond_length_tolerance_factor=config.violation_tolerance_factor)
  dists_lower_bound = utils.batched_gather(
      restype_atom14_bounds['lower_bound'], aatype)
  dists_upper_bound = utils.batched_gather(
      restype_atom14_bounds['upper_bound'], aatype)
  within_residue_violations = all_atom_multimer.within_residue_violations(
      pred_positions=pred_positions,
      atom_exists=mask,
      dists_lower_bound=dists_lower_bound,
      dists_upper_bound=dists_upper_bound,
      tighten_bounds_for_loss=0.0)

  # Combine them to a single per-residue violation mask (used later for LDDT).
  per_residue_violations_mask = jnp.max(jnp.stack([
      connection_violations['per_residue_violation_mask'],
      jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1),
      jnp.max(within_residue_violations['per_atom_violations'], axis=-1)]),
                                        axis=0)

  return {
      'between_residues': {
          'bonds_c_n_loss_mean':
              connection_violations['c_n_loss_mean'],  # ()
          'angles_ca_c_n_loss_mean':
              connection_violations['ca_c_n_loss_mean'],  # ()
          'angles_c_n_ca_loss_mean':
              connection_violations['c_n_ca_loss_mean'],  # ()
          'connections_per_residue_loss_sum':
              connection_violations['per_residue_loss_sum'],  # (N)
          'connections_per_residue_violation_mask':
              connection_violations['per_residue_violation_mask'],  # (N)
          'clashes_mean_loss':
              between_residue_clashes['mean_loss'],  # ()
          'clashes_per_atom_loss_sum':
              between_residue_clashes['per_atom_loss_sum'],  # (N, 14)
          'clashes_per_atom_clash_mask':
              between_residue_clashes['per_atom_clash_mask'],  # (N, 14)
      },
      'within_residues': {
          'per_atom_loss_sum':
              within_residue_violations['per_atom_loss_sum'],  # (N, 14)
          'per_atom_violations':
              within_residue_violations['per_atom_violations'],  # (N, 14)
      },
      'total_per_residue_violations_mask':
          per_residue_violations_mask,  # (N)
  }

def compute_violation_metrics(
    residue_index: jnp.ndarray,
    mask: jnp.ndarray,
    seq_mask: jnp.ndarray,
    pred_positions: geometry.Vec3Array,  # (N, 14)
    violations: Mapping[str, jnp.ndarray],
    ) -> Dict[str, jnp.ndarray]:
  """Compute several metrics to assess the structural violations."""
  ret = {}
  between_residues = violations['between_residues']
  within_residues = violations['within_residues']
  extreme_ca_ca_violations = all_atom_multimer.extreme_ca_ca_distance_violations(
      positions=pred_positions,
      mask=mask.astype(jnp.float32),
      residue_index=residue_index.astype(jnp.float32))
  ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations
  ret['violations_between_residue_bond'] = utils.mask_mean(
      mask=seq_mask,
      value=between_residues['connections_per_residue_violation_mask'])
  ret['violations_between_residue_clash'] = utils.mask_mean(
      mask=seq_mask,
      value=jnp.max(between_residues['clashes_per_atom_clash_mask'], axis=-1))
  ret['violations_within_residue'] = utils.mask_mean(
      mask=seq_mask,
      value=jnp.max(within_residues['per_atom_violations'], axis=-1))
  ret['violations_per_residue'] = utils.mask_mean(
      mask=seq_mask, value=violations['total_per_residue_violations_mask'])
  return ret

def supervised_chi_loss(
    sequence_mask: jnp.ndarray,
    target_chi_mask: jnp.ndarray,
    aatype: jnp.ndarray,
    target_chi_angles: jnp.ndarray,
    pred_angles: jnp.ndarray,
    unnormed_angles: jnp.ndarray,
    config: ml_collections.ConfigDict) -> Tuple[Float, Float, Float]:
  """Computes loss for direct chi angle supervision."""
  eps = 1e-6
  chi_mask = target_chi_mask.astype(jnp.float32)

  pred_angles = pred_angles[:, :, 3:]

  residue_type_one_hot = jax.nn.one_hot(
      aatype, residue_constants.restype_num + 1, dtype=jnp.float32)[None]
  chi_pi_periodic = jnp.einsum('ijk, kl->ijl', residue_type_one_hot,
                               jnp.asarray(residue_constants.chi_pi_periodic))

  true_chi = target_chi_angles[None]
  sin_true_chi = jnp.sin(true_chi)
  cos_true_chi = jnp.cos(true_chi)
  sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1)

  # This is -1 if chi is pi periodic and +1 if it's 2 pi periodic
  shifted_mask = (1 - 2 * chi_pi_periodic)[..., None]
  sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi

  sq_chi_error = jnp.sum(squared_difference(sin_cos_true_chi, pred_angles), -1)
  sq_chi_error_shifted = jnp.sum(
      squared_difference(sin_cos_true_chi_shifted, pred_angles), -1)
  sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted)

  sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error)
  angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps)
  norm_error = jnp.abs(angle_norm - 1.)
  angle_norm_loss = utils.mask_mean(
      mask=sequence_mask[None, :, None], value=norm_error)
  loss = (config.chi_weight * sq_chi_loss +
          config.angle_norm_weight * angle_norm_loss)
  return loss, sq_chi_loss, angle_norm_loss

def l2_normalize(x: jnp.ndarray,
                 axis: int = -1,
                 epsilon: float = 1e-12
                 ) -> jnp.ndarray:
  return x / jnp.sqrt(
      jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon))

def get_renamed_chi_angles(aatype: jnp.ndarray,
                           chi_angles: jnp.ndarray,
                           alt_is_better: jnp.ndarray
                           ) -> jnp.ndarray:
  """Return renamed chi angles."""
  chi_angle_is_ambiguous = utils.batched_gather(
      jnp.array(residue_constants.chi_pi_periodic, dtype=jnp.float32), aatype)
  alt_chi_angles = chi_angles + np.pi * chi_angle_is_ambiguous
  # Map back to [-pi, pi].
  alt_chi_angles = alt_chi_angles - 2 * np.pi * (alt_chi_angles > np.pi).astype(
      jnp.float32)
  alt_is_better = alt_is_better[:, None]
  return (1. - alt_is_better) * chi_angles + alt_is_better * alt_chi_angles

class MultiRigidSidechain(hk.Module):
  """Class to make side chain atoms."""

  def __init__(self,
               config: ml_collections.ConfigDict,
               global_config: ml_collections.ConfigDict,
               name: str = 'rigid_sidechain'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               rigid: geometry.Rigid3Array,
               representations_list: Iterable[jnp.ndarray],
               aatype: jnp.ndarray
               ) -> Dict[str, Any]:
    """Predict sidechains using multi-rigid representations.

    Args:
      rigid: The Rigid's for each residue (translations in angstroms)
      representations_list: A list of activations to predict sidechains from.
      aatype: amino acid types.

    Returns:
      dict containing atom positions and frames (in angstrom)
    """
    act = [
        common_modules.Linear(  # pylint: disable=g-complex-comprehension
            self.config.num_channel,
            name='input_projection')(jax.nn.relu(x))
        for x in representations_list
    ]
    # Sum the activation list (equivalent to concat then Conv1D)
    act = sum(act)

    final_init = 'zeros' if self.global_config.zero_init else 'linear'

    # Mapping with some residual blocks.
    for _ in range(self.config.num_residual_block):
      old_act = act
      act = common_modules.Linear(
          self.config.num_channel,
          initializer='relu',
          name='resblock1')(
              jax.nn.relu(act))
      act = common_modules.Linear(
          self.config.num_channel,
          initializer=final_init,
          name='resblock2')(
              jax.nn.relu(act))
      act += old_act

    # Map activations to torsion angles.
    # [batch_size, num_res, 14]
    num_res = act.shape[0]
    unnormalized_angles = common_modules.Linear(
        14, name='unnormalized_angles')(
            jax.nn.relu(act))
    unnormalized_angles = jnp.reshape(unnormalized_angles, [num_res, 7, 2])
    angles = l2_normalize(unnormalized_angles, axis=-1)

    outputs = {
        'angles_sin_cos': angles,  # jnp.ndarray (N, 7, 2)
        'unnormalized_angles_sin_cos':
            unnormalized_angles,  # jnp.ndarray (N, 7, 2)
    }

    # Map torsion angles to frames.
    # geometry.Rigid3Array with shape (N, 8)
    all_frames_to_global = all_atom_multimer.torsion_angles_to_frames(
        aatype, rigid, angles)

    # Use frames and literature positions to create the final atom coordinates.
    # geometry.Vec3Array with shape (N, 14)
    pred_positions = all_atom_multimer.frames_and_literature_positions_to_atom14_pos(
        aatype, all_frames_to_global)

    outputs.update({
        'atom_pos': pred_positions,  # geometry.Vec3Array (N, 14)
        'frames': all_frames_to_global,  # geometry.Rigid3Array (N, 8)
    })
    return outputs

alphafold/model/geometry/__init__.py deleted (100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from alphafold.model.geometry import rigid_matrix_vector
from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import vector

Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array

StructOfArray = struct_of_array.StructOfArray

Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
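(Editor's note, not part of this commit: a minimal usage sketch of the geometry API re-exported above, assuming the alphafold.model.geometry package is still importable; shapes and values are illustrative only.)

import jax.numpy as jnp
from alphafold.model import geometry

# Identity frames for 5 residues and a batch of unit-x points.
rigid = geometry.Rigid3Array.identity((5,))
points = geometry.Vec3Array(jnp.ones(5), jnp.zeros(5), jnp.zeros(5))
moved = rigid.apply_to_point(points)  # unchanged under the identity frame
dist2 = geometry.square_euclidean_distance(points, moved, epsilon=0.)  # zeros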
alphafold/model/geometry/rigid_matrix_vector.py deleted (100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations

from typing import Union

from alphafold.model.geometry import rotation_matrix
from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import vector
import jax
import jax.numpy as jnp

Float = Union[float, jnp.ndarray]

VERSION = '0.1'


@struct_of_array.StructOfArray(same_dtype=True)
class Rigid3Array:
  """Rigid Transformation, i.e. element of special euclidean group."""

  rotation: rotation_matrix.Rot3Array
  translation: vector.Vec3Array

  def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
    new_rotation = self.rotation @ other.rotation
    new_translation = self.apply_to_point(other.translation)
    return Rigid3Array(new_rotation, new_translation)

  def inverse(self) -> Rigid3Array:
    """Return Rigid3Array corresponding to inverse transform."""
    inv_rotation = self.rotation.inverse()
    inv_translation = inv_rotation.apply_to_point(-self.translation)
    return Rigid3Array(inv_rotation, inv_translation)

  def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Apply Rigid3Array transform to point."""
    return self.rotation.apply_to_point(point) + self.translation

  def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Apply inverse Rigid3Array transform to point."""
    new_point = point - self.translation
    return self.rotation.apply_inverse_to_point(new_point)

  def compose_rotation(self, other_rotation):
    rot = self.rotation @ other_rotation
    trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape),
                         self.translation)
    return Rigid3Array(rot, trans)

  @classmethod
  def identity(cls, shape, dtype=jnp.float32) -> Rigid3Array:
    """Return identity Rigid3Array of given shape."""
    return cls(
        rotation_matrix.Rot3Array.identity(shape, dtype=dtype),
        vector.Vec3Array.zeros(shape, dtype=dtype))  # pytype: disable=wrong-arg-count  # trace-all-classes

  def scale_translation(self, factor: Float) -> Rigid3Array:
    """Scale translation in Rigid3Array by 'factor'."""
    return Rigid3Array(self.rotation, self.translation * factor)

  def to_array(self):
    rot_array = self.rotation.to_array()
    vec_array = self.translation.to_array()
    return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1)

  @classmethod
  def from_array(cls, array):
    rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
    vec = vector.Vec3Array.from_array(array[..., -1])
    return cls(rot, vec)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array:
    """Construct Rigid3Array from homogeneous 4x4 array."""
    assert array.shape[-1] == 4
    assert array.shape[-2] == 4
    rotation = rotation_matrix.Rot3Array(
        array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
        array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
        array[..., 2, 0], array[..., 2, 1], array[..., 2, 2])
    translation = vector.Vec3Array(
        array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
    return cls(rotation, translation)  # pytype: disable=wrong-arg-count  # trace-all-classes

  def __getstate__(self):
    return (VERSION, (self.rotation, self.translation))

  def __setstate__(self, state):
    version, (rot, trans) = state
    del version
    object.__setattr__(self, 'rotation', rot)
    object.__setattr__(self, 'translation', trans)
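(Editor's note, an illustrative sketch rather than part of the commit: round-tripping the Rigid3Array above through its homogeneous 4x4 form; the input matrix is arbitrary.)

import jax.numpy as jnp
from alphafold.model.geometry import rigid_matrix_vector

# Build a frame from a 4x4 homogeneous transform, then flatten it back to
# the [..., 3, 4] layout produced by to_array().
frame = rigid_matrix_vector.Rigid3Array.from_array4x4(jnp.eye(4))
flat = frame.to_array()  # rotation matrix plus translation column, shape (3, 4)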
alphafold/model/geometry/rotation_matrix.py deleted (100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rot3Array Matrix Class."""
from __future__ import annotations

import dataclasses

from alphafold.model.geometry import struct_of_array
from alphafold.model.geometry import utils
from alphafold.model.geometry import vector
import jax
import jax.numpy as jnp
import numpy as np

COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']

VERSION = '0.1'


@struct_of_array.StructOfArray(same_dtype=True)
class Rot3Array:
  """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""

  xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
  xy: jnp.ndarray
  xz: jnp.ndarray
  yx: jnp.ndarray
  yy: jnp.ndarray
  yz: jnp.ndarray
  zx: jnp.ndarray
  zy: jnp.ndarray
  zz: jnp.ndarray

  __array_ufunc__ = None

  def inverse(self) -> Rot3Array:
    """Returns inverse of Rot3Array."""
    return Rot3Array(self.xx, self.yx, self.zx,
                     self.xy, self.yy, self.zy,
                     self.xz, self.yz, self.zz)

  def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Applies Rot3Array to point."""
    return vector.Vec3Array(
        self.xx * point.x + self.xy * point.y + self.xz * point.z,
        self.yx * point.x + self.yy * point.y + self.yz * point.z,
        self.zx * point.x + self.zy * point.y + self.zz * point.z)

  def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
    """Applies inverse Rot3Array to point."""
    return self.inverse().apply_to_point(point)

  def __matmul__(self, other: Rot3Array) -> Rot3Array:
    """Composes two Rot3Arrays."""
    c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
    c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
    c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
    return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)

  @classmethod
  def identity(cls, shape, dtype=jnp.float32) -> Rot3Array:
    """Returns identity of given shape."""
    ones = jnp.ones(shape, dtype=dtype)
    zeros = jnp.zeros(shape, dtype=dtype)
    return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def from_two_vectors(cls, e0: vector.Vec3Array,
                       e1: vector.Vec3Array) -> Rot3Array:
    """Construct Rot3Array from two Vectors.

    Rot3Array is constructed such that in the corresponding frame 'e0' lies on
    the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.

    Args:
      e0: Vector
      e1: Vector

    Returns:
      Rot3Array
    """
    # Normalize the unit vector for the x-axis, e0.
    e0 = e0.normalized()
    # make e1 perpendicular to e0.
    c = e1.dot(e0)
    e1 = (e1 - c * e0).normalized()
    # Compute e2 as cross product of e0 and e1.
    e2 = e0.cross(e1)
    return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def from_array(cls, array: jnp.ndarray) -> Rot3Array:
    """Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
    unstacked = utils.unstack(array, axis=-2)
    unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], [])
    return cls(*unstacked)

  def to_array(self) -> jnp.ndarray:
    """Convert Rot3Array to array of shape [..., 3, 3]."""
    return jnp.stack(
        [jnp.stack([self.xx, self.xy, self.xz], axis=-1),
         jnp.stack([self.yx, self.yy, self.yz], axis=-1),
         jnp.stack([self.zx, self.zy, self.zz], axis=-1)],
        axis=-2)

  @classmethod
  def from_quaternion(cls,
                      w: jnp.ndarray,
                      x: jnp.ndarray,
                      y: jnp.ndarray,
                      z: jnp.ndarray,
                      normalize: bool = True,
                      epsilon: float = 1e-6) -> Rot3Array:
    """Construct Rot3Array from components of quaternion."""
    if normalize:
      inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2))
      w *= inv_norm
      x *= inv_norm
      y *= inv_norm
      z *= inv_norm
    xx = 1 - 2 * (jnp.square(y) + jnp.square(z))
    xy = 2 * (x * y - w * z)
    xz = 2 * (x * z + w * y)
    yx = 2 * (x * y + w * z)
    yy = 1 - 2 * (jnp.square(x) + jnp.square(z))
    yz = 2 * (y * z - w * x)
    zx = 2 * (x * z - w * y)
    zy = 2 * (y * z + w * x)
    zz = 1 - 2 * (jnp.square(x) + jnp.square(y))
    return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)  # pytype: disable=wrong-arg-count  # trace-all-classes

  @classmethod
  def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array:
    """Samples uniform random Rot3Array according to Haar Measure."""
    quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype)
    quats = utils.unstack(quat_array)
    return cls.from_quaternion(*quats)

  def __getstate__(self):
    return (VERSION,
            [np.asarray(getattr(self, field)) for field in COMPONENTS])

  def __setstate__(self, state):
    version, state = state
    del version
    for i, field in enumerate(COMPONENTS):
      object.__setattr__(self, field, state[i])
alphafold/model/geometry/struct_of_array.py deleted (100644 → 0)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class decorator to represent (nested) struct of arrays."""
import
dataclasses
import
jax
def
get_item
(
instance
,
key
):
sliced
=
{}
for
field
in
get_array_fields
(
instance
):
num_trailing_dims
=
field
.
metadata
.
get
(
'num_trailing_dims'
,
0
)
this_key
=
key
if
isinstance
(
key
,
tuple
)
and
Ellipsis
in
this_key
:
this_key
+=
(
slice
(
None
),)
*
num_trailing_dims
sliced
[
field
.
name
]
=
getattr
(
instance
,
field
.
name
)[
this_key
]
return
dataclasses
.
replace(instance, **sliced)


@property
def get_shape(instance):
  """Returns Shape for given instance of dataclass."""
  first_field = dataclasses.fields(instance)[0]
  num_trailing_dims = first_field.metadata.get('num_trailing_dims', None)
  value = getattr(instance, first_field.name)
  if num_trailing_dims:
    return value.shape[:-num_trailing_dims]
  else:
    return value.shape


def get_len(instance):
  """Returns length for given instance of dataclass."""
  shape = instance.shape
  if shape:
    return shape[0]
  else:
    raise TypeError('len() of unsized object')  # Match jax.numpy behavior.


@property
def get_dtype(instance):
  """Returns Dtype for given instance of dataclass."""
  fields = dataclasses.fields(instance)
  sets_dtype = [
      field.name for field in fields if field.metadata.get('sets_dtype', False)
  ]
  if sets_dtype:
    assert len(sets_dtype) == 1, 'at most one field can set dtype'
    field_value = getattr(instance, sets_dtype[0])
  elif instance.same_dtype:
    field_value = getattr(instance, fields[0].name)
  else:
    # Should this be Value Error?
    raise AttributeError('Trying to access Dtype on Struct of Array without'
                         'either "same_dtype" or field setting dtype')
  if hasattr(field_value, 'dtype'):
    return field_value.dtype
  else:
    # Should this be Value Error?
    raise AttributeError(f'field_value {field_value} does not have dtype')


def replace(instance, **kwargs):
  return dataclasses.replace(instance, **kwargs)


def post_init(instance):
  """Validate instance has same shapes & dtypes."""
  array_fields = get_array_fields(instance)
  arrays = list(get_array_fields(instance, return_values=True).values())
  first_field = array_fields[0]
  # These slightly weird constructions about checking whether the leaves are
  # actual arrays is since e.g. vmap internally relies on being able to
  # construct pytree's with object() as leaves, this would break the checking
  # as such we are only validating the object when the entries in the dataclass
  # Are arrays or other dataclasses of arrays.
  try:
    dtype = instance.dtype
  except AttributeError:
    dtype = None
  if dtype is not None:
    first_shape = instance.shape
    for array, field in zip(arrays, array_fields):
      field_shape = array.shape
      num_trailing_dims = field.metadata.get('num_trailing_dims', None)
      if num_trailing_dims:
        array_shape = array.shape
        field_shape = array_shape[:-num_trailing_dims]
        msg = (f'field {field} should have number of trailing dims'
               f' {num_trailing_dims}')
        assert len(array_shape) == len(first_shape) + num_trailing_dims, msg
      else:
        field_shape = array.shape

      shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't "
                   f"match shape {first_shape} of field {first_field}")
      assert field_shape == first_shape, shape_msg

      field_dtype = array.dtype

      allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', [])
      if allowed_metadata_dtypes:
        msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}'
        assert field_dtype in allowed_metadata_dtypes, msg

      if 'dtype' in field.metadata:
        target_dtype = field.metadata['dtype']
      else:
        target_dtype = dtype

      msg = f'Dtype is {field_dtype} but must be {target_dtype}'
      assert field_dtype == target_dtype, msg


def flatten(instance):
  """Flatten Struct of Array instance."""
  array_likes = list(get_array_fields(instance, return_values=True).values())
  flat_array_likes = []
  inner_treedefs = []
  num_arrays = []
  for array_like in array_likes:
    flat_array_like, inner_treedef = jax.tree_flatten(array_like)
    inner_treedefs.append(inner_treedef)
    flat_array_likes += flat_array_like
    num_arrays.append(len(flat_array_like))
  metadata = get_metadata_fields(instance, return_values=True)
  metadata = type(instance).metadata_cls(**metadata)
  return flat_array_likes, (inner_treedefs, metadata, num_arrays)


def make_metadata_class(cls):
  metadata_fields = get_fields(cls,
                               lambda x: x.metadata.get('is_metadata', False))
  metadata_cls = dataclasses.make_dataclass(
      cls_name='Meta' + cls.__name__,
      fields=[(field.name, field.type, field) for field in metadata_fields],
      frozen=True,
      eq=True)
  return metadata_cls


def get_fields(cls_or_instance, filterfn, return_values=False):
  fields = dataclasses.fields(cls_or_instance)
  fields = [field for field in fields if filterfn(field)]
  if return_values:
    return {
        field.name: getattr(cls_or_instance, field.name) for field in fields
    }
  else:
    return fields


def get_array_fields(cls, return_values=False):
  return get_fields(
      cls,
      lambda x: not x.metadata.get('is_metadata', False),
      return_values=return_values)


def get_metadata_fields(cls, return_values=False):
  return get_fields(
      cls,
      lambda x: x.metadata.get('is_metadata', False),
      return_values=return_values)


class StructOfArray:
  """Class Decorator for Struct Of Arrays."""

  def __init__(self, same_dtype=True):
    self.same_dtype = same_dtype

  def __call__(self, cls):
    cls.__array_ufunc__ = None
    cls.replace = replace
    cls.same_dtype = self.same_dtype
    cls.dtype = get_dtype
    cls.shape = get_shape
    cls.__len__ = get_len
    cls.__getitem__ = get_item
    cls.__post_init__ = post_init
    new_cls = dataclasses.dataclass(cls, frozen=True, eq=False)  # pytype: disable=wrong-keyword-args
    # pytree claims to require metadata to be hashable, not sure why,
    # But making derived dataclass that can just hold metadata
    new_cls.metadata_cls = make_metadata_class(new_cls)

    def unflatten(aux, data):
      inner_treedefs, metadata, num_arrays = aux
      array_fields = [field.name for field in get_array_fields(new_cls)]
      value_dict = {}
      array_start = 0
      for num_array, inner_treedef, array_field in zip(num_arrays,
                                                       inner_treedefs,
                                                       array_fields):
        value_dict[array_field] = jax.tree_unflatten(
            inner_treedef, data[array_start:array_start + num_array])
        array_start += num_array
      metadata_fields = get_metadata_fields(new_cls)
      for field in metadata_fields:
        value_dict[field.name] = getattr(metadata, field.name)

      return new_cls(**value_dict)

    jax.tree_util.register_pytree_node(
        nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten)
    return new_cls
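The decorator is easiest to see on a toy dataclass. The sketch below is not part of the deleted file; `Point2` is a hypothetical class used only for illustration, mirroring the way `Vec3Array` (further down in this commit) applies `StructOfArray` so that the fields become parallel arrays and the class behaves like a JAX pytree with shape, length and indexing.

# Minimal sketch, assuming alphafold.model.geometry is importable.
import dataclasses
import jax.numpy as jnp
from alphafold.model.geometry import struct_of_array


@struct_of_array.StructOfArray(same_dtype=True)
class Point2:  # hypothetical example class, not part of AlphaFold
  x: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
  y: jnp.ndarray


points = Point2(jnp.zeros((4,)), jnp.ones((4,)))
assert points.shape == (4,)   # shared leading dims via get_shape
assert len(points) == 4       # __len__ from get_len
first = points[0]             # __getitem__ slices every field at once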
alphafold/model/geometry/test_utils.py  deleted  100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utils for tests."""
import
dataclasses
from
alphafold.model.geometry
import
rigid_matrix_vector
from
alphafold.model.geometry
import
rotation_matrix
from
alphafold.model.geometry
import
vector
import
jax.numpy
as
jnp
import
numpy
as
np
def
assert_rotation_matrix_equal
(
matrix1
:
rotation_matrix
.
Rot3Array
,
matrix2
:
rotation_matrix
.
Rot3Array
):
for
field
in
dataclasses
.
fields
(
rotation_matrix
.
Rot3Array
):
field
=
field
.
name
np
.
testing
.
assert_array_equal
(
getattr
(
matrix1
,
field
),
getattr
(
matrix2
,
field
))
def
assert_rotation_matrix_close
(
mat1
:
rotation_matrix
.
Rot3Array
,
mat2
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
mat1
.
to_array
(),
mat2
.
to_array
(),
6
)
def
assert_array_equal_to_rotation_matrix
(
array
:
jnp
.
ndarray
,
matrix
:
rotation_matrix
.
Rot3Array
):
"""Check that array and Matrix match."""
np
.
testing
.
assert_array_equal
(
matrix
.
xx
,
array
[...,
0
,
0
])
np
.
testing
.
assert_array_equal
(
matrix
.
xy
,
array
[...,
0
,
1
])
np
.
testing
.
assert_array_equal
(
matrix
.
xz
,
array
[...,
0
,
2
])
np
.
testing
.
assert_array_equal
(
matrix
.
yx
,
array
[...,
1
,
0
])
np
.
testing
.
assert_array_equal
(
matrix
.
yy
,
array
[...,
1
,
1
])
np
.
testing
.
assert_array_equal
(
matrix
.
yz
,
array
[...,
1
,
2
])
np
.
testing
.
assert_array_equal
(
matrix
.
zx
,
array
[...,
2
,
0
])
np
.
testing
.
assert_array_equal
(
matrix
.
zy
,
array
[...,
2
,
1
])
np
.
testing
.
assert_array_equal
(
matrix
.
zz
,
array
[...,
2
,
2
])
def
assert_array_close_to_rotation_matrix
(
array
:
jnp
.
ndarray
,
matrix
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
matrix
.
to_array
(),
array
,
6
)
def
assert_vectors_equal
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_equal
(
vec1
.
x
,
vec2
.
x
)
np
.
testing
.
assert_array_equal
(
vec1
.
y
,
vec2
.
y
)
np
.
testing
.
assert_array_equal
(
vec1
.
z
,
vec2
.
z
)
def
assert_vectors_close
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
np
.
testing
.
assert_allclose
(
vec1
.
x
,
vec2
.
x
,
atol
=
1e-6
,
rtol
=
0.
)
np
.
testing
.
assert_allclose
(
vec1
.
y
,
vec2
.
y
,
atol
=
1e-6
,
rtol
=
0.
)
np
.
testing
.
assert_allclose
(
vec1
.
z
,
vec2
.
z
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_close_to_vector
(
array
:
jnp
.
ndarray
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_allclose
(
vec
.
to_array
(),
array
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_equal_to_vector
(
array
:
jnp
.
ndarray
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_equal
(
vec
.
to_array
(),
array
)
def
assert_rigid_equal_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
rigid2
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rot_trans_equal_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
def
assert_rigid_close_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
rigid2
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rot_trans_close_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
def
assert_rot_trans_equal_to_rigid
(
rot
:
rotation_matrix
.
Rot3Array
,
trans
:
vector
.
Vec3Array
,
rigid
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rotation_matrix_equal
(
rot
,
rigid
.
rotation
)
assert_vectors_equal
(
trans
,
rigid
.
translation
)
def
assert_rot_trans_close_to_rigid
(
rot
:
rotation_matrix
.
Rot3Array
,
trans
:
vector
.
Vec3Array
,
rigid
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rotation_matrix_close
(
rot
,
rigid
.
rotation
)
assert_vectors_close
(
trans
,
rigid
.
translation
)
alphafold/model/geometry/utils.py  deleted  100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for geometry library."""
from
typing
import
List
import
jax.numpy
as
jnp
def
unstack
(
value
:
jnp
.
ndarray
,
axis
:
int
=
-
1
)
->
List
[
jnp
.
ndarray
]:
return
[
jnp
.
squeeze
(
v
,
axis
=
axis
)
for
v
in
jnp
.
split
(
value
,
value
.
shape
[
axis
],
axis
=
axis
)]
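A quick illustration (not from the deleted file): `unstack` splits one axis of an array into a list of arrays, which is how `Vec3Array.from_array` below recovers x, y and z components from a trailing axis of size 3.

# Minimal sketch, assuming the geometry package is importable.
import jax.numpy as jnp
from alphafold.model.geometry import utils

coords = jnp.arange(12.).reshape(4, 3)    # (N, 3)
x, y, z = utils.unstack(coords, axis=-1)  # three arrays of shape (N,)
assert x.shape == (4,)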
alphafold/model/geometry/vector.py  deleted  100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vec3Array Class."""
from
__future__
import
annotations
import
dataclasses
from
typing
import
Union
from
alphafold.model.geometry
import
struct_of_array
from
alphafold.model.geometry
import
utils
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
Float
=
Union
[
float
,
jnp
.
ndarray
]
VERSION
=
'0.1'
@
struct_of_array
.
StructOfArray
(
same_dtype
=
True
)
class
Vec3Array
:
"""Vec3Array in 3 dimensional Space implemented as struct of arrays.
This is done in order to improve performance and precision.
On TPU small matrix multiplications are very suboptimal and will waste large
compute ressources, furthermore any matrix multiplication on tpu happen in
mixed bfloat16/float32 precision, which is often undesirable when handling
physical coordinates.
In most cases this will also be faster on cpu's/gpu's since it allows for
easier use of vector instructions.
"""
x
:
jnp
.
ndarray
=
dataclasses
.
field
(
metadata
=
{
'dtype'
:
jnp
.
float32
})
y
:
jnp
.
ndarray
z
:
jnp
.
ndarray
def
__post_init__
(
self
):
if
hasattr
(
self
.
x
,
'dtype'
):
assert
self
.
x
.
dtype
==
self
.
y
.
dtype
assert
self
.
x
.
dtype
==
self
.
z
.
dtype
assert
all
([
x
==
y
for
x
,
y
in
zip
(
self
.
x
.
shape
,
self
.
y
.
shape
)])
assert
all
([
x
==
z
for
x
,
z
in
zip
(
self
.
x
.
shape
,
self
.
z
.
shape
)])
def
__add__
(
self
,
other
:
Vec3Array
)
->
Vec3Array
:
return
jax
.
tree_map
(
lambda
x
,
y
:
x
+
y
,
self
,
other
)
def
__sub__
(
self
,
other
:
Vec3Array
)
->
Vec3Array
:
return
jax
.
tree_map
(
lambda
x
,
y
:
x
-
y
,
self
,
other
)
def
__mul__
(
self
,
other
:
Float
)
->
Vec3Array
:
return
jax
.
tree_map
(
lambda
x
:
x
*
other
,
self
)
def
__rmul__
(
self
,
other
:
Float
)
->
Vec3Array
:
return
self
*
other
def
__truediv__
(
self
,
other
:
Float
)
->
Vec3Array
:
return
jax
.
tree_map
(
lambda
x
:
x
/
other
,
self
)
def
__neg__
(
self
)
->
Vec3Array
:
return
jax
.
tree_map
(
lambda
x
:
-
x
,
self
)
def
__pos__
(
self
)
->
Vec3Array
:
return
jax
.
tree_map
(
lambda
x
:
x
,
self
)
def
cross
(
self
,
other
:
Vec3Array
)
->
Vec3Array
:
"""Compute cross product between 'self' and 'other'."""
new_x
=
self
.
y
*
other
.
z
-
self
.
z
*
other
.
y
new_y
=
self
.
z
*
other
.
x
-
self
.
x
*
other
.
z
new_z
=
self
.
x
*
other
.
y
-
self
.
y
*
other
.
x
return
Vec3Array
(
new_x
,
new_y
,
new_z
)
def
dot
(
self
,
other
:
Vec3Array
)
->
Float
:
"""Compute dot product between 'self' and 'other'."""
return
self
.
x
*
other
.
x
+
self
.
y
*
other
.
y
+
self
.
z
*
other
.
z
def
norm
(
self
,
epsilon
:
float
=
1e-6
)
->
Float
:
"""Compute Norm of Vec3Array, clipped to epsilon."""
# To avoid NaN on the backward pass, we must use maximum before the sqrt
norm2
=
self
.
dot
(
self
)
if
epsilon
:
norm2
=
jnp
.
maximum
(
norm2
,
epsilon
**
2
)
return
jnp
.
sqrt
(
norm2
)
def
norm2
(
self
):
return
self
.
dot
(
self
)
def
normalized
(
self
,
epsilon
:
float
=
1e-6
)
->
Vec3Array
:
"""Return unit vector with optional clipping."""
return
self
/
self
.
norm
(
epsilon
)
@
classmethod
def
zeros
(
cls
,
shape
,
dtype
=
jnp
.
float32
):
"""Return Vec3Array corresponding to zeros of given shape."""
return
cls
(
jnp
.
zeros
(
shape
,
dtype
),
jnp
.
zeros
(
shape
,
dtype
),
jnp
.
zeros
(
shape
,
dtype
))
# pytype: disable=wrong-arg-count # trace-all-classes
def
to_array
(
self
)
->
jnp
.
ndarray
:
return
jnp
.
stack
([
self
.
x
,
self
.
y
,
self
.
z
],
axis
=-
1
)
@
classmethod
def
from_array
(
cls
,
array
):
return
cls
(
*
utils
.
unstack
(
array
))
def
__getstate__
(
self
):
return
(
VERSION
,
[
np
.
asarray
(
self
.
x
),
np
.
asarray
(
self
.
y
),
np
.
asarray
(
self
.
z
)])
def
__setstate__
(
self
,
state
):
version
,
state
=
state
del
version
for
i
,
letter
in
enumerate
(
'xyz'
):
object
.
__setattr__
(
self
,
letter
,
state
[
i
])
def
square_euclidean_distance
(
vec1
:
Vec3Array
,
vec2
:
Vec3Array
,
epsilon
:
float
=
1e-6
)
->
Float
:
"""Computes square of euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute distance to
vec2: Vec3Array to compute distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of square euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
difference
=
vec1
-
vec2
distance
=
difference
.
dot
(
difference
)
if
epsilon
:
distance
=
jnp
.
maximum
(
distance
,
epsilon
)
return
distance
def
dot
(
vector1
:
Vec3Array
,
vector2
:
Vec3Array
)
->
Float
:
return
vector1
.
dot
(
vector2
)
def
cross
(
vector1
:
Vec3Array
,
vector2
:
Vec3Array
)
->
Float
:
return
vector1
.
cross
(
vector2
)
def
norm
(
vector
:
Vec3Array
,
epsilon
:
float
=
1e-6
)
->
Float
:
return
vector
.
norm
(
epsilon
)
def
normalized
(
vector
:
Vec3Array
,
epsilon
:
float
=
1e-6
)
->
Vec3Array
:
return
vector
.
normalized
(
epsilon
)
def
euclidean_distance
(
vec1
:
Vec3Array
,
vec2
:
Vec3Array
,
epsilon
:
float
=
1e-6
)
->
Float
:
"""Computes euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute euclidean distance to
vec2: Vec3Array to compute euclidean distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
distance_sq
=
square_euclidean_distance
(
vec1
,
vec2
,
epsilon
**
2
)
distance
=
jnp
.
sqrt
(
distance_sq
)
return
distance
def
dihedral_angle
(
a
:
Vec3Array
,
b
:
Vec3Array
,
c
:
Vec3Array
,
d
:
Vec3Array
)
->
Float
:
"""Computes torsion angle for a quadruple of points.
For points (a, b, c, d), this is the angle between the planes defined by
points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
Arguments:
a: A Vec3Array of coordinates.
b: A Vec3Array of coordinates.
c: A Vec3Array of coordinates.
d: A Vec3Array of coordinates.
Returns:
A tensor of angles in radians: [-pi, pi].
"""
v1
=
a
-
b
v2
=
b
-
c
v3
=
d
-
c
c1
=
v1
.
cross
(
v2
)
c2
=
v3
.
cross
(
v2
)
c3
=
c2
.
cross
(
c1
)
v2_mag
=
v2
.
norm
()
return
jnp
.
arctan2
(
c3
.
dot
(
v2
),
v2_mag
*
c1
.
dot
(
c2
))
def
random_gaussian_vector
(
shape
,
key
,
dtype
=
jnp
.
float32
):
vec_array
=
jax
.
random
.
normal
(
key
,
shape
+
(
3
,),
dtype
)
return
Vec3Array
.
from_array
(
vec_array
)
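A small usage sketch (not part of the deleted file): `Vec3Array` behaves like a batch of 3D vectors stored as separate x/y/z arrays, and the module-level helpers mirror the methods.

# Minimal sketch, assuming the geometry package is importable.
import jax
import jax.numpy as jnp
from alphafold.model.geometry import vector

key = jax.random.PRNGKey(0)
v1 = vector.random_gaussian_vector((8,), key)        # 8 random vectors
v2 = vector.Vec3Array.from_array(jnp.ones((8, 3)))   # from a plain (..., 3) array

dots = vector.dot(v1, v2)                  # shape (8,)
dists = vector.euclidean_distance(v1, v2)  # shape (8,), clipped below by epsilon
unit = v1.normalized()                     # unit vectors, safe near the origin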
alphafold/model/layer_stack.py  deleted  100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Function to stack repeats of a layer function without shared parameters."""
import
collections
import
contextlib
import
functools
import
inspect
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
import
haiku
as
hk
import
jax
import
jax.numpy
as
jnp
LayerStackCarry
=
collections
.
namedtuple
(
'LayerStackCarry'
,
[
'x'
,
'rng'
])
LayerStackScanned
=
collections
.
namedtuple
(
'LayerStackScanned'
,
[
'i'
,
'args_ys'
])
# WrappedFn should take in arbitrarily nested `jnp.ndarray`, and return the
# exact same type. We cannot express this with `typing`. So we just use it
# to inform the user. In reality, the typing below will accept anything.
NestedArray
=
Any
WrappedFn
=
Callable
[...,
Union
[
NestedArray
,
Tuple
[
NestedArray
]]]
def
_check_no_varargs
(
f
):
if
list
(
inspect
.
signature
(
f
).
parameters
.
values
())[
0
].
kind
==
inspect
.
Parameter
.
VAR_POSITIONAL
:
raise
ValueError
(
'The function `f` should not have any `varargs` (that is *args) '
'argument. Instead, it should only use explicit positional'
'arguments.'
)
@
contextlib
.
contextmanager
def
nullcontext
():
yield
def
maybe_with_rng
(
key
):
if
key
is
not
None
:
return
hk
.
with_rng
(
key
)
else
:
return
nullcontext
()
def
maybe_fold_in
(
key
,
data
):
if
key
is
not
None
:
return
jax
.
random
.
fold_in
(
key
,
data
)
else
:
return
None
class
_LayerStack
(
hk
.
Module
):
"""Module to compose parameterized functions, implemented as a scan."""
def
__init__
(
self
,
count
:
int
,
unroll
:
int
,
name
:
Optional
[
str
]
=
None
):
"""Iterate a function `f` `count` times, with non-shared parameters."""
super
().
__init__
(
name
=
name
)
self
.
_count
=
count
self
.
_unroll
=
unroll
def
__call__
(
self
,
x
,
*
args_ys
):
count
=
self
.
_count
if
hk
.
running_init
():
# At initialization time, we run just one layer but add an extra first
# dimension to every initialized tensor, making sure to use different
# random keys for different slices.
def
creator
(
next_creator
,
shape
,
dtype
,
init
,
context
):
del
context
def
multi_init
(
shape
,
dtype
):
assert
shape
[
0
]
==
count
key
=
hk
.
maybe_next_rng_key
()
def
rng_context_init
(
slice_idx
):
slice_key
=
maybe_fold_in
(
key
,
slice_idx
)
with
maybe_with_rng
(
slice_key
):
return
init
(
shape
[
1
:],
dtype
)
return
jax
.
vmap
(
rng_context_init
)(
jnp
.
arange
(
count
))
return
next_creator
((
count
,)
+
tuple
(
shape
),
dtype
,
multi_init
)
def
getter
(
next_getter
,
value
,
context
):
trailing_dims
=
len
(
context
.
original_shape
)
+
1
sliced_value
=
jax
.
lax
.
index_in_dim
(
value
,
index
=
0
,
axis
=
value
.
ndim
-
trailing_dims
,
keepdims
=
False
)
return
next_getter
(
sliced_value
)
with
hk
.
experimental
.
custom_creator
(
creator
),
hk
.
experimental
.
custom_getter
(
getter
):
if
len
(
args_ys
)
==
1
and
args_ys
[
0
]
is
None
:
args0
=
(
None
,)
else
:
args0
=
[
jax
.
lax
.
dynamic_index_in_dim
(
ys
,
0
,
keepdims
=
False
)
for
ys
in
args_ys
]
x
,
z
=
self
.
_call_wrapped
(
x
,
*
args0
)
if
z
is
None
:
return
x
,
z
# Broadcast state to hold each layer state.
def
broadcast_state
(
layer_state
):
return
jnp
.
broadcast_to
(
layer_state
,
[
count
,]
+
list
(
layer_state
.
shape
))
zs
=
jax
.
tree_util
.
tree_map
(
broadcast_state
,
z
)
return
x
,
zs
else
:
# Use scan during apply, threading through random seed so that it's
# unique for each layer.
def
layer
(
carry
:
LayerStackCarry
,
scanned
:
LayerStackScanned
):
rng
=
carry
.
rng
def
getter
(
next_getter
,
value
,
context
):
# Getter slices the full param at the current loop index.
trailing_dims
=
len
(
context
.
original_shape
)
+
1
assert
value
.
shape
[
value
.
ndim
-
trailing_dims
]
==
count
,
(
f
'Attempting to use a parameter stack of size '
f
'
{
value
.
shape
[
value
.
ndim
-
trailing_dims
]
}
for a LayerStack of '
f
'size
{
count
}
.'
)
sliced_value
=
jax
.
lax
.
dynamic_index_in_dim
(
value
,
scanned
.
i
,
axis
=
value
.
ndim
-
trailing_dims
,
keepdims
=
False
)
return
next_getter
(
sliced_value
)
with
hk
.
experimental
.
custom_getter
(
getter
):
if
rng
is
None
:
out_x
,
z
=
self
.
_call_wrapped
(
carry
.
x
,
*
scanned
.
args_ys
)
else
:
rng
,
rng_
=
jax
.
random
.
split
(
rng
)
with
hk
.
with_rng
(
rng_
):
out_x
,
z
=
self
.
_call_wrapped
(
carry
.
x
,
*
scanned
.
args_ys
)
return
LayerStackCarry
(
x
=
out_x
,
rng
=
rng
),
z
carry
=
LayerStackCarry
(
x
=
x
,
rng
=
hk
.
maybe_next_rng_key
())
scanned
=
LayerStackScanned
(
i
=
jnp
.
arange
(
count
,
dtype
=
jnp
.
int32
),
args_ys
=
args_ys
)
carry
,
zs
=
hk
.
scan
(
layer
,
carry
,
scanned
,
length
=
count
,
unroll
=
self
.
_unroll
)
return
carry
.
x
,
zs
def
_call_wrapped
(
self
,
x
:
jnp
.
ndarray
,
*
args
,
)
->
Tuple
[
jnp
.
ndarray
,
Optional
[
jnp
.
ndarray
]]:
raise
NotImplementedError
()
class
_LayerStackNoState
(
_LayerStack
):
"""_LayerStack impl with no per-layer state provided to the function."""
def
__init__
(
self
,
f
:
WrappedFn
,
count
:
int
,
unroll
:
int
,
name
:
Optional
[
str
]
=
None
):
super
().
__init__
(
count
=
count
,
unroll
=
unroll
,
name
=
name
)
_check_no_varargs
(
f
)
self
.
_f
=
f
@
hk
.
transparent
def
_call_wrapped
(
self
,
args
,
y
):
del
y
ret
=
self
.
_f
(
*
args
)
if
len
(
args
)
==
1
:
# If the function takes a single argument, the wrapped function receives
# a tuple of length 1, and therefore it must return a tuple of length 1.
ret
=
(
ret
,)
return
ret
,
None
class
_LayerStackWithState
(
_LayerStack
):
"""_LayerStack impl with per-layer state provided to the function."""
def
__init__
(
self
,
f
:
WrappedFn
,
count
:
int
,
unroll
:
int
,
name
:
Optional
[
str
]
=
None
):
super
().
__init__
(
count
=
count
,
unroll
=
unroll
,
name
=
name
)
self
.
_f
=
f
@
hk
.
transparent
def
_call_wrapped
(
self
,
x
,
*
args
):
return
self
.
_f
(
x
,
*
args
)
def
layer_stack
(
num_layers
:
int
,
with_state
=
False
,
unroll
:
int
=
1
,
name
:
Optional
[
str
]
=
None
):
"""Utility to wrap a Haiku function and recursively apply it to an input.
A function is valid if it uses only explicit position parameters, and
its return type matches its input type. The position parameters can be
arbitrarily nested structures with `jnp.ndarray` at the leaf nodes. Note
that kwargs are not supported, neither are functions with variable number
of parameters (specified by `*args`).
If `with_state=False` then the new, wrapped function can be understood as
performing the following:
```
for i in range(num_layers):
x = f(x)
return x
```
And if `with_state=True`, assuming `f` takes two arguments on top of `x`:
```
for i in range(num_layers):
x, zs[i] = f(x, ys_0[i], ys_1[i])
return x, zs
```
The code using `layer_stack` for the above function would be:
```
def f(x, y_0, y_1):
...
return new_x, z
x, zs = layer_stack.layer_stack(num_layers,
with_state=True)(f)(x, ys_0, ys_1)
```
Crucially, any parameters created inside `f` will not be shared across
iterations.
Args:
num_layers: The number of times to iterate the wrapped function.
with_state: Whether or not to pass per-layer state to the wrapped function.
unroll: the unroll used by `scan`.
name: Name of the Haiku context.
Returns:
Callable that will produce a layer stack when called with a valid function.
"""
def
iterate
(
f
):
if
with_state
:
@
functools
.
wraps
(
f
)
def
wrapped
(
x
,
*
args
):
for
ys
in
args
:
assert
ys
.
shape
[
0
]
==
num_layers
return
_LayerStackWithState
(
f
,
num_layers
,
unroll
=
unroll
,
name
=
name
)(
x
,
*
args
)
else
:
_check_no_varargs
(
f
)
@
functools
.
wraps
(
f
)
def
wrapped
(
*
args
):
ret
=
_LayerStackNoState
(
f
,
num_layers
,
unroll
=
unroll
,
name
=
name
)(
args
,
None
)[
0
]
if
len
(
args
)
==
1
:
# If the function takes a single argument, we must also return a
# single value, and not a tuple of length 1.
ret
=
ret
[
0
]
return
ret
return
wrapped
return
iterate
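A minimal runnable sketch of the wrapper above (not part of the deleted file): ten residual blocks with unshared parameters, scanned rather than unrolled. `block` and `forward` are hypothetical names chosen for the example.

# Minimal sketch, assuming Haiku and JAX are available and the module layout
# mirrors this repository.
import haiku as hk
import jax
import jax.numpy as jnp
from alphafold.model import layer_stack


def block(x):
  # One residual layer; parameters are created fresh for each of the 10 copies.
  return x + hk.Linear(16)(x)


def forward(x):
  return layer_stack.layer_stack(10)(block)(x)


forward_t = hk.transform(forward)
x = jnp.ones([4, 16])
params = forward_t.init(jax.random.PRNGKey(0), x)
out = forward_t.apply(params, None, x)  # each parameter gets a leading dim of 10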
alphafold/model/layer_stack_test.py  deleted  100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for layer_stack."""
import
functools
from
absl.testing
import
absltest
from
absl.testing
import
parameterized
from
alphafold.model
import
layer_stack
import
haiku
as
hk
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
import
scipy
# Suffixes applied by Haiku for repeated module names.
suffixes
=
[
''
]
+
[
f
'_
{
i
}
'
for
i
in
range
(
1
,
100
)]
def
_slice_layers_params
(
layers_params
):
sliced_layers_params
=
{}
for
k
,
v
in
layers_params
.
items
():
for
inner_k
in
v
:
for
var_slice
,
suffix
in
zip
(
v
[
inner_k
],
suffixes
):
k_new
=
k
.
split
(
'/'
)[
-
1
]
+
suffix
if
k_new
not
in
sliced_layers_params
:
sliced_layers_params
[
k_new
]
=
{}
sliced_layers_params
[
k_new
][
inner_k
]
=
var_slice
return
sliced_layers_params
class
LayerStackTest
(
parameterized
.
TestCase
):
@
parameterized
.
parameters
([
1
,
2
,
4
])
def
test_layer_stack
(
self
,
unroll
):
"""Compare layer_stack to the equivalent unrolled stack.
Tests that the layer_stack application of a Haiku layer function is
equivalent to repeatedly applying the layer function in an unrolled loop.
Args:
unroll: Number of unrolled layers.
"""
num_layers
=
20
def
inner_fn
(
x
):
x
+=
hk
.
Linear
(
100
,
name
=
'linear1'
)(
x
)
x
+=
hk
.
Linear
(
100
,
name
=
'linear2'
)(
x
)
return
x
def
outer_fn_unrolled
(
x
):
for
_
in
range
(
num_layers
):
x
=
inner_fn
(
x
)
return
x
def
outer_fn_layer_stack
(
x
):
stack
=
layer_stack
.
layer_stack
(
num_layers
,
unroll
=
unroll
)(
inner_fn
)
return
stack
(
x
)
unrolled_fn
=
hk
.
transform
(
outer_fn_unrolled
)
layer_stack_fn
=
hk
.
transform
(
outer_fn_layer_stack
)
x
=
jax
.
random
.
uniform
(
jax
.
random
.
PRNGKey
(
0
),
[
10
,
256
,
100
])
rng_init
=
jax
.
random
.
PRNGKey
(
42
)
params
=
layer_stack_fn
.
init
(
rng_init
,
x
)
sliced_params
=
_slice_layers_params
(
params
)
unrolled_pred
=
unrolled_fn
.
apply
(
sliced_params
,
None
,
x
)
layer_stack_pred
=
layer_stack_fn
.
apply
(
params
,
None
,
x
)
np
.
testing
.
assert_allclose
(
unrolled_pred
,
layer_stack_pred
)
def
test_layer_stack_multi_args
(
self
):
"""Compare layer_stack to the equivalent unrolled stack.
Similar to `test_layer_stack`, but use a function that takes more than one
argument.
"""
num_layers
=
20
def
inner_fn
(
x
,
y
):
x_out
=
x
+
hk
.
Linear
(
100
,
name
=
'linear1'
)(
y
)
y_out
=
y
+
hk
.
Linear
(
100
,
name
=
'linear2'
)(
x
)
return
x_out
,
y_out
def
outer_fn_unrolled
(
x
,
y
):
for
_
in
range
(
num_layers
):
x
,
y
=
inner_fn
(
x
,
y
)
return
x
,
y
def
outer_fn_layer_stack
(
x
,
y
):
stack
=
layer_stack
.
layer_stack
(
num_layers
)(
inner_fn
)
return
stack
(
x
,
y
)
unrolled_fn
=
hk
.
transform
(
outer_fn_unrolled
)
layer_stack_fn
=
hk
.
transform
(
outer_fn_layer_stack
)
x
=
jax
.
random
.
uniform
(
jax
.
random
.
PRNGKey
(
0
),
[
10
,
256
,
100
])
y
=
jax
.
random
.
uniform
(
jax
.
random
.
PRNGKey
(
1
),
[
10
,
256
,
100
])
rng_init
=
jax
.
random
.
PRNGKey
(
42
)
params
=
layer_stack_fn
.
init
(
rng_init
,
x
,
y
)
sliced_params
=
_slice_layers_params
(
params
)
unrolled_x
,
unrolled_y
=
unrolled_fn
.
apply
(
sliced_params
,
None
,
x
,
y
)
layer_stack_x
,
layer_stack_y
=
layer_stack_fn
.
apply
(
params
,
None
,
x
,
y
)
np
.
testing
.
assert_allclose
(
unrolled_x
,
layer_stack_x
)
np
.
testing
.
assert_allclose
(
unrolled_y
,
layer_stack_y
)
def
test_layer_stack_no_varargs
(
self
):
"""Test an error is raised when using a function with varargs."""
class
VarArgsModule
(
hk
.
Module
):
"""When used, this module should cause layer_stack to raise an Error."""
def
__call__
(
self
,
*
args
):
return
args
class
NoVarArgsModule
(
hk
.
Module
):
"""This module should be fine to use with layer_stack."""
def
__call__
(
self
,
x
):
return
x
def
build_and_init_stack
(
module_class
):
def
stack_fn
(
x
):
module
=
module_class
()
return
layer_stack
.
layer_stack
(
1
)(
module
)(
x
)
stack
=
hk
.
without_apply_rng
(
hk
.
transform
(
stack_fn
))
stack
.
init
(
jax
.
random
.
PRNGKey
(
1729
),
jnp
.
ones
([
5
]))
build_and_init_stack
(
NoVarArgsModule
)
with
self
.
assertRaisesRegex
(
ValueError
,
'The function `f` should not have any `varargs`'
):
build_and_init_stack
(
VarArgsModule
)
@
parameterized
.
parameters
([
1
,
2
,
4
])
def
test_layer_stack_grads
(
self
,
unroll
):
"""Compare layer_stack gradients to the equivalent unrolled stack.
Tests that the layer_stack application of a Haiku layer function is
equivalent to repeatedly applying the layer function in an unrolled loop.
Args:
unroll: Number of unrolled layers.
"""
num_layers
=
20
def
inner_fn
(
x
):
x
+=
hk
.
Linear
(
100
,
name
=
'linear1'
)(
x
)
x
+=
hk
.
Linear
(
100
,
name
=
'linear2'
)(
x
)
return
x
def
outer_fn_unrolled
(
x
):
for
_
in
range
(
num_layers
):
x
=
inner_fn
(
x
)
return
x
def
outer_fn_layer_stack
(
x
):
stack
=
layer_stack
.
layer_stack
(
num_layers
,
unroll
=
unroll
)(
inner_fn
)
return
stack
(
x
)
unrolled_fn
=
hk
.
transform
(
outer_fn_unrolled
)
layer_stack_fn
=
hk
.
transform
(
outer_fn_layer_stack
)
x
=
jax
.
random
.
uniform
(
jax
.
random
.
PRNGKey
(
0
),
[
10
,
256
,
100
])
rng_init
=
jax
.
random
.
PRNGKey
(
42
)
params
=
layer_stack_fn
.
init
(
rng_init
,
x
)
sliced_params
=
_slice_layers_params
(
params
)
unrolled_grad
=
jax
.
grad
(
lambda
p
,
x
:
jnp
.
mean
(
unrolled_fn
.
apply
(
p
,
None
,
x
)))(
sliced_params
,
x
)
layer_stack_grad
=
jax
.
grad
(
lambda
p
,
x
:
jnp
.
mean
(
layer_stack_fn
.
apply
(
p
,
None
,
x
)))(
params
,
x
)
assert_fn
=
functools
.
partial
(
np
.
testing
.
assert_allclose
,
atol
=
1e-4
,
rtol
=
1e-4
)
jax
.
tree_map
(
assert_fn
,
unrolled_grad
,
_slice_layers_params
(
layer_stack_grad
))
def
test_random
(
self
):
"""Random numbers should be handled correctly."""
n
=
100
@
hk
.
transform
@
layer_stack
.
layer_stack
(
n
)
def
add_random
(
x
):
x
=
x
+
jax
.
random
.
normal
(
hk
.
next_rng_key
())
return
x
# Evaluate a bunch of times
key
,
*
keys
=
jax
.
random
.
split
(
jax
.
random
.
PRNGKey
(
7
),
1024
+
1
)
params
=
add_random
.
init
(
key
,
0.
)
apply_fn
=
jax
.
jit
(
add_random
.
apply
)
values
=
[
apply_fn
(
params
,
key
,
0.
)
for
key
in
keys
]
# Should be roughly N(0, sqrt(n))
cdf
=
scipy
.
stats
.
norm
(
scale
=
np
.
sqrt
(
n
)).
cdf
_
,
p
=
scipy
.
stats
.
kstest
(
values
,
cdf
)
self
.
assertLess
(
0.3
,
p
)
def
test_threading
(
self
):
"""Test @layer_stack when the function gets per-layer state."""
n
=
5
@
layer_stack
.
layer_stack
(
n
,
with_state
=
True
)
def
f
(
x
,
y
):
x
=
x
+
y
*
jax
.
nn
.
one_hot
(
y
,
len
(
x
))
/
10
return
x
,
2
*
y
@
hk
.
without_apply_rng
@
hk
.
transform
def
g
(
x
,
ys
):
x
,
zs
=
f
(
x
,
ys
)
# Check here to catch issues at init time
self
.
assertEqual
(
zs
.
shape
,
(
n
,))
return
x
,
zs
rng
=
jax
.
random
.
PRNGKey
(
7
)
x
=
np
.
zeros
(
n
)
ys
=
np
.
arange
(
n
).
astype
(
np
.
float32
)
params
=
g
.
init
(
rng
,
x
,
ys
)
x
,
zs
=
g
.
apply
(
params
,
x
,
ys
)
self
.
assertTrue
(
np
.
allclose
(
x
,
[
0
,
.
1
,
.
2
,
.
3
,
.
4
]))
self
.
assertTrue
(
np
.
all
(
zs
==
2
*
ys
))
def
test_nested_stacks
(
self
):
def
stack_fn
(
x
):
def
layer_fn
(
x
):
return
hk
.
Linear
(
100
)(
x
)
outer_fn
=
layer_stack
.
layer_stack
(
10
)(
layer_fn
)
layer_outer
=
layer_stack
.
layer_stack
(
20
)(
outer_fn
)
return
layer_outer
(
x
)
hk_mod
=
hk
.
transform
(
stack_fn
)
apply_rng
,
init_rng
=
jax
.
random
.
split
(
jax
.
random
.
PRNGKey
(
0
))
params
=
hk_mod
.
init
(
init_rng
,
jnp
.
zeros
([
10
,
100
]))
hk_mod
.
apply
(
params
,
apply_rng
,
jnp
.
zeros
([
10
,
100
]))
p
,
=
params
.
values
()
assert
p
[
'w'
].
shape
==
(
10
,
20
,
100
,
100
)
assert
p
[
'b'
].
shape
==
(
10
,
20
,
100
)
def
test_with_state_multi_args
(
self
):
"""Test layer_stack with state with multiple arguments."""
width
=
4
batch_size
=
5
stack_height
=
3
def
f_with_multi_args
(
x
,
a
,
b
):
return
hk
.
Linear
(
width
,
w_init
=
hk
.
initializers
.
Constant
(
jnp
.
eye
(
width
)))(
x
)
*
a
+
b
,
None
@
hk
.
without_apply_rng
@
hk
.
transform
def
hk_fn
(
x
):
return
layer_stack
.
layer_stack
(
stack_height
,
with_state
=
True
)(
f_with_multi_args
)(
x
,
jnp
.
full
([
stack_height
],
2.
),
jnp
.
ones
([
stack_height
]))
x
=
jnp
.
zeros
([
batch_size
,
width
])
key_seq
=
hk
.
PRNGSequence
(
19
)
params
=
hk_fn
.
init
(
next
(
key_seq
),
x
)
output
,
z
=
hk_fn
.
apply
(
params
,
x
)
self
.
assertIsNone
(
z
)
self
.
assertEqual
(
output
.
shape
,
(
batch_size
,
width
))
np
.
testing
.
assert_equal
(
output
,
np
.
full
([
batch_size
,
width
],
7.
))
def
test_with_container_state
(
self
):
width
=
2
batch_size
=
2
stack_height
=
3
def
f_with_container_state
(
x
):
hk_layer
=
hk
.
Linear
(
width
,
w_init
=
hk
.
initializers
.
Constant
(
jnp
.
eye
(
width
)))
layer_output
=
hk_layer
(
x
)
layer_state
=
{
'raw_output'
:
layer_output
,
'output_projection'
:
jnp
.
sum
(
layer_output
)
}
return
layer_output
+
jnp
.
ones_like
(
layer_output
),
layer_state
@
hk
.
without_apply_rng
@
hk
.
transform
def
hk_fn
(
x
):
return
layer_stack
.
layer_stack
(
stack_height
,
with_state
=
True
)(
f_with_container_state
)(
x
)
x
=
jnp
.
zeros
([
batch_size
,
width
])
key_seq
=
hk
.
PRNGSequence
(
19
)
params
=
hk_fn
.
init
(
next
(
key_seq
),
x
)
output
,
z
=
hk_fn
.
apply
(
params
,
x
)
self
.
assertEqual
(
z
[
'raw_output'
].
shape
,
(
stack_height
,
batch_size
,
width
))
self
.
assertEqual
(
output
.
shape
,
(
batch_size
,
width
))
self
.
assertEqual
(
z
[
'output_projection'
].
shape
,
(
stack_height
,))
np
.
testing
.
assert_equal
(
np
.
sum
(
z
[
'output_projection'
]),
np
.
array
(
12.
))
np
.
testing
.
assert_equal
(
np
.
all
(
z
[
'raw_output'
]
==
np
.
array
([
0.
,
1.
,
2.
])[...,
None
,
None
]),
np
.
array
(
True
))
if
__name__
==
'__main__'
:
absltest
.
main
()
alphafold/model/lddt.py  deleted  100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""lDDT protein distance score."""
import
jax.numpy
as
jnp
def
lddt
(
predicted_points
,
true_points
,
true_points_mask
,
cutoff
=
15.
,
per_residue
=
False
):
"""Measure (approximate) lDDT for a batch of coordinates.
lDDT reference:
Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local
superposition-free score for comparing protein structures and models using
distance difference tests. Bioinformatics 29, 2722–2728 (2013).
lDDT is a measure of the difference between the true distance matrix and the
distance matrix of the predicted points. The difference is computed only on
points closer than cutoff *in the true structure*.
This function does not compute the exact lDDT value that the original paper
describes because it does not include terms for physical feasibility
(e.g. bond length violations). Therefore this is only an approximate
lDDT score.
Args:
predicted_points: (batch, length, 3) array of predicted 3D points
true_points: (batch, length, 3) array of true 3D points
true_points_mask: (batch, length, 1) binary-valued float array. This mask
should be 1 for points that exist in the true points.
cutoff: Maximum distance for a pair of points to be included
per_residue: If true, return score for each residue. Note that the overall
lDDT is not exactly the mean of the per_residue lDDT's because some
residues have more contacts than others.
Returns:
An (approximate, see above) lDDT score in the range 0-1.
"""
assert
len
(
predicted_points
.
shape
)
==
3
assert
predicted_points
.
shape
[
-
1
]
==
3
assert
true_points_mask
.
shape
[
-
1
]
==
1
assert
len
(
true_points_mask
.
shape
)
==
3
# Compute true and predicted distance matrices.
dmat_true
=
jnp
.
sqrt
(
1e-10
+
jnp
.
sum
(
(
true_points
[:,
:,
None
]
-
true_points
[:,
None
,
:])
**
2
,
axis
=-
1
))
dmat_predicted
=
jnp
.
sqrt
(
1e-10
+
jnp
.
sum
(
(
predicted_points
[:,
:,
None
]
-
predicted_points
[:,
None
,
:])
**
2
,
axis
=-
1
))
dists_to_score
=
(
(
dmat_true
<
cutoff
).
astype
(
jnp
.
float32
)
*
true_points_mask
*
jnp
.
transpose
(
true_points_mask
,
[
0
,
2
,
1
])
*
(
1.
-
jnp
.
eye
(
dmat_true
.
shape
[
1
]))
# Exclude self-interaction.
)
# Shift unscored distances to be far away.
dist_l1
=
jnp
.
abs
(
dmat_true
-
dmat_predicted
)
# True lDDT uses a number of fixed bins.
# We ignore the physical plausibility correction to lDDT, though.
score
=
0.25
*
((
dist_l1
<
0.5
).
astype
(
jnp
.
float32
)
+
(
dist_l1
<
1.0
).
astype
(
jnp
.
float32
)
+
(
dist_l1
<
2.0
).
astype
(
jnp
.
float32
)
+
(
dist_l1
<
4.0
).
astype
(
jnp
.
float32
))
# Normalize over the appropriate axes.
reduce_axes
=
(
-
1
,)
if
per_residue
else
(
-
2
,
-
1
)
norm
=
1.
/
(
1e-10
+
jnp
.
sum
(
dists_to_score
,
axis
=
reduce_axes
))
score
=
norm
*
(
1e-10
+
jnp
.
sum
(
dists_to_score
*
score
,
axis
=
reduce_axes
))
return
score
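An illustrative call (not from the deleted file): scoring a single three-residue structure against itself should give per-residue lDDT of 1. The coordinates and mask below are synthetic placeholders.

# Minimal sketch, assuming alphafold.model.lddt is importable.
import numpy as np
from alphafold.model import lddt

coords = np.array([[[0., 0., 0.], [5., 0., 0.], [10., 0., 0.]]], np.float32)
mask = np.ones((1, 3, 1), np.float32)
per_res = lddt.lddt(coords, coords, mask, cutoff=15., per_residue=True)
# per_res is approximately [[1., 1., 1.]]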
alphafold/model/lddt_test.py  deleted  100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lddt."""
from
absl.testing
import
absltest
from
absl.testing
import
parameterized
from
alphafold.model
import
lddt
import
numpy
as
np
class
LddtTest
(
parameterized
.
TestCase
,
absltest
.
TestCase
):
@
parameterized
.
named_parameters
(
(
'same'
,
[[
0
,
0
,
0
],
[
5
,
0
,
0
],
[
10
,
0
,
0
]],
[[
0
,
0
,
0
],
[
5
,
0
,
0
],
[
10
,
0
,
0
]],
[
1
,
1
,
1
]),
(
'all_shifted'
,
[[
0
,
0
,
0
],
[
5
,
0
,
0
],
[
10
,
0
,
0
]],
[[
-
1
,
0
,
0
],
[
4
,
0
,
0
],
[
9
,
0
,
0
]],
[
1
,
1
,
1
]),
(
'all_rotated'
,
[[
0
,
0
,
0
],
[
5
,
0
,
0
],
[
10
,
0
,
0
]],
[[
0
,
0
,
0
],
[
0
,
5
,
0
],
[
0
,
10
,
0
]],
[
1
,
1
,
1
]),
(
'half_a_dist'
,
[[
0
,
0
,
0
],
[
5
,
0
,
0
]],
[[
0
,
0
,
0
],
[
5.5
-
1e-5
,
0
,
0
]],
[
1
,
1
]),
(
'one_a_dist'
,
[[
0
,
0
,
0
],
[
5
,
0
,
0
]],
[[
0
,
0
,
0
],
[
6
-
1e-5
,
0
,
0
]],
[
0.75
,
0.75
]),
(
'two_a_dist'
,
[[
0
,
0
,
0
],
[
5
,
0
,
0
]],
[[
0
,
0
,
0
],
[
7
-
1e-5
,
0
,
0
]],
[
0.5
,
0.5
]),
(
'four_a_dist'
,
[[
0
,
0
,
0
],
[
5
,
0
,
0
]],
[[
0
,
0
,
0
],
[
9
-
1e-5
,
0
,
0
]],
[
0.25
,
0.25
],),
(
'five_a_dist'
,
[[
0
,
0
,
0
],
[
16
-
1e-5
,
0
,
0
]],
[[
0
,
0
,
0
],
[
11
,
0
,
0
]],
[
0
,
0
]),
(
'no_pairs'
,
[[
0
,
0
,
0
],
[
20
,
0
,
0
]],
[[
0
,
0
,
0
],
[
25
-
1e-5
,
0
,
0
]],
[
1
,
1
]),
)
def
test_lddt
(
self
,
predicted_pos
,
true_pos
,
exp_lddt
):
predicted_pos
=
np
.
array
([
predicted_pos
],
dtype
=
np
.
float32
)
true_points_mask
=
np
.
array
([[[
1
]]
*
len
(
true_pos
)],
dtype
=
np
.
float32
)
true_pos
=
np
.
array
([
true_pos
],
dtype
=
np
.
float32
)
cutoff
=
15.0
per_residue
=
True
result
=
lddt
.
lddt
(
predicted_pos
,
true_pos
,
true_points_mask
,
cutoff
,
per_residue
)
np
.
testing
.
assert_almost_equal
(
result
,
[
exp_lddt
],
decimal
=
4
)
if
__name__
==
'__main__'
:
absltest
.
main
()
alphafold/model/mapping.py  deleted  100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specialized mapping functions."""
import
functools
import
inspect
from
typing
import
Any
,
Callable
,
Optional
,
Sequence
,
Union
import
haiku
as
hk
import
jax
import
jax.numpy
as
jnp
PYTREE
=
Any
PYTREE_JAX_ARRAY
=
Any
partial
=
functools
.
partial
PROXY
=
object
()
def
_maybe_slice
(
array
,
i
,
slice_size
,
axis
):
if
axis
is
PROXY
:
return
array
else
:
return
jax
.
lax
.
dynamic_slice_in_dim
(
array
,
i
,
slice_size
=
slice_size
,
axis
=
axis
)
def
_maybe_get_size
(
array
,
axis
):
if
axis
==
PROXY
:
return
-
1
else
:
return
array
.
shape
[
axis
]
def
_expand_axes
(
axes
,
values
,
name
=
'sharded_apply'
):
values_tree_def
=
jax
.
tree_flatten
(
values
)[
1
]
flat_axes
=
jax
.
api_util
.
flatten_axes
(
name
,
values_tree_def
,
axes
)
# Replace None's with PROXY
flat_axes
=
[
PROXY
if
x
is
None
else
x
for
x
in
flat_axes
]
return
jax
.
tree_unflatten
(
values_tree_def
,
flat_axes
)
def
sharded_map
(
fun
:
Callable
[...,
PYTREE_JAX_ARRAY
],
shard_size
:
Union
[
int
,
None
]
=
1
,
in_axes
:
Union
[
int
,
PYTREE
]
=
0
,
out_axes
:
Union
[
int
,
PYTREE
]
=
0
)
->
Callable
[...,
PYTREE_JAX_ARRAY
]:
"""Sharded vmap.
Maps `fun` over axes, in a way similar to vmap, but does so in shards of
`shard_size`. This allows a smooth trade-off between memory usage
(as in a plain map) vs higher throughput (as in a vmap).
Args:
fun: Function to apply smap transform to.
shard_size: Integer denoting shard size.
in_axes: Either integer or pytree describing which axis to map over for each
input to `fun`, None denotes broadcasting.
out_axes: integer or pytree denoting to what axis in the output the mapped
over axis maps.
Returns:
function with smap applied.
"""
if
'split_rng'
in
inspect
.
signature
(
hk
.
vmap
).
parameters
:
vmapped_fun
=
hk
.
vmap
(
fun
,
in_axes
,
out_axes
,
split_rng
=
False
)
else
:
# TODO(tomhennigan): Remove this when older versions of Haiku aren't used.
vmapped_fun
=
hk
.
vmap
(
fun
,
in_axes
,
out_axes
)
return
sharded_apply
(
vmapped_fun
,
shard_size
,
in_axes
,
out_axes
)
def
sharded_apply
(
fun
:
Callable
[...,
PYTREE_JAX_ARRAY
],
# pylint: disable=g-bare-generic
shard_size
:
Union
[
int
,
None
]
=
1
,
in_axes
:
Union
[
int
,
PYTREE
]
=
0
,
out_axes
:
Union
[
int
,
PYTREE
]
=
0
,
new_out_axes
:
bool
=
False
)
->
Callable
[...,
PYTREE_JAX_ARRAY
]:
"""Sharded apply.
Applies `fun` over shards to axes, in a way similar to vmap,
but does so in shards of `shard_size`. Shards are stacked after.
This allows a smooth trade-off between
memory usage (as in a plain map) vs higher throughput (as in a vmap).
Args:
fun: Function to apply smap transform to.
shard_size: Integer denoting shard size.
in_axes: Either integer or pytree describing which axis to map over for each
input to `fun`, None denotes broadcasting.
out_axes: integer or pytree denoting to what axis in the output the mapped
over axis maps.
new_out_axes: whether to stack outputs on new axes. This assumes that the
output sizes for each shard (including the possible remainder shard) are
the same.
Returns:
function with smap applied.
"""
docstr
=
(
'Mapped version of {fun}. Takes similar arguments to {fun} '
'but with additional array axes over which {fun} is mapped.'
)
if
new_out_axes
:
raise
NotImplementedError
(
'New output axes not yet implemented.'
)
# shard size None denotes no sharding
if
shard_size
is
None
:
return
fun
@
jax
.
util
.
wraps
(
fun
,
docstr
=
docstr
)
def
mapped_fn
(
*
args
):
# Expand in axes and Determine Loop range
in_axes_
=
_expand_axes
(
in_axes
,
args
)
in_sizes
=
jax
.
tree_map
(
_maybe_get_size
,
args
,
in_axes_
)
flat_sizes
=
jax
.
tree_flatten
(
in_sizes
)[
0
]
in_size
=
max
(
flat_sizes
)
assert
all
(
i
in
{
in_size
,
-
1
}
for
i
in
flat_sizes
)
num_extra_shards
=
(
in_size
-
1
)
//
shard_size
# Fix Up if necessary
last_shard_size
=
in_size
%
shard_size
last_shard_size
=
shard_size
if
last_shard_size
==
0
else
last_shard_size
def
apply_fun_to_slice
(
slice_start
,
slice_size
):
input_slice
=
jax
.
tree_map
(
lambda
array
,
axis
:
_maybe_slice
(
array
,
slice_start
,
slice_size
,
axis
),
args
,
in_axes_
)
return
fun
(
*
input_slice
)
remainder_shape_dtype
=
hk
.
eval_shape
(
partial
(
apply_fun_to_slice
,
0
,
last_shard_size
))
out_dtypes
=
jax
.
tree_map
(
lambda
x
:
x
.
dtype
,
remainder_shape_dtype
)
out_shapes
=
jax
.
tree_map
(
lambda
x
:
x
.
shape
,
remainder_shape_dtype
)
out_axes_
=
_expand_axes
(
out_axes
,
remainder_shape_dtype
)
if
num_extra_shards
>
0
:
regular_shard_shape_dtype
=
hk
.
eval_shape
(
partial
(
apply_fun_to_slice
,
0
,
shard_size
))
shard_shapes
=
jax
.
tree_map
(
lambda
x
:
x
.
shape
,
regular_shard_shape_dtype
)
def
make_output_shape
(
axis
,
shard_shape
,
remainder_shape
):
return
shard_shape
[:
axis
]
+
(
shard_shape
[
axis
]
*
num_extra_shards
+
remainder_shape
[
axis
],)
+
shard_shape
[
axis
+
1
:]
out_shapes
=
jax
.
tree_map
(
make_output_shape
,
out_axes_
,
shard_shapes
,
out_shapes
)
# Calls dynamic Update slice with different argument order
# This is here since tree_map only works with positional arguments
def
dynamic_update_slice_in_dim
(
full_array
,
update
,
axis
,
i
):
return
jax
.
lax
.
dynamic_update_slice_in_dim
(
full_array
,
update
,
i
,
axis
)
def
compute_shard
(
outputs
,
slice_start
,
slice_size
):
slice_out
=
apply_fun_to_slice
(
slice_start
,
slice_size
)
update_slice
=
partial
(
dynamic_update_slice_in_dim
,
i
=
slice_start
)
return
jax
.
tree_map
(
update_slice
,
outputs
,
slice_out
,
out_axes_
)
def
scan_iteration
(
outputs
,
i
):
new_outputs
=
compute_shard
(
outputs
,
i
,
shard_size
)
return
new_outputs
,
()
slice_starts
=
jnp
.
arange
(
0
,
in_size
-
shard_size
+
1
,
shard_size
)
def
allocate_buffer
(
dtype
,
shape
):
return
jnp
.
zeros
(
shape
,
dtype
=
dtype
)
outputs
=
jax
.
tree_map
(
allocate_buffer
,
out_dtypes
,
out_shapes
)
if
slice_starts
.
shape
[
0
]
>
0
:
outputs
,
_
=
hk
.
scan
(
scan_iteration
,
outputs
,
slice_starts
)
if
last_shard_size
!=
shard_size
:
remainder_start
=
in_size
-
last_shard_size
outputs
=
compute_shard
(
outputs
,
remainder_start
,
last_shard_size
)
return
outputs
return
mapped_fn
def
inference_subbatch
(
module
:
Callable
[...,
PYTREE_JAX_ARRAY
],
subbatch_size
:
int
,
batched_args
:
Sequence
[
PYTREE_JAX_ARRAY
],
nonbatched_args
:
Sequence
[
PYTREE_JAX_ARRAY
],
low_memory
:
bool
=
True
,
input_subbatch_dim
:
int
=
0
,
output_subbatch_dim
:
Optional
[
int
]
=
None
)
->
PYTREE_JAX_ARRAY
:
"""Run through subbatches (like batch apply but with split and concat)."""
assert
len
(
batched_args
)
>
0
# pylint: disable=g-explicit-length-test
if
not
low_memory
:
args
=
list
(
batched_args
)
+
list
(
nonbatched_args
)
return
module
(
*
args
)
if
output_subbatch_dim
is
None
:
output_subbatch_dim
=
input_subbatch_dim
def
run_module
(
*
batched_args
):
args
=
list
(
batched_args
)
+
list
(
nonbatched_args
)
return
module
(
*
args
)
sharded_module
=
sharded_apply
(
run_module
,
shard_size
=
subbatch_size
,
in_axes
=
input_subbatch_dim
,
out_axes
=
output_subbatch_dim
)
return
sharded_module
(
*
batched_args
)
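A small sketch (not part of the deleted file) of the memory/throughput trade-off described in the docstrings above: the same elementwise function applied along axis 0 in shards of 4 rows instead of all 10 at once. `double` and `forward` are hypothetical names; `sharded_apply` uses `hk.eval_shape`/`hk.scan`, so the call is wrapped in an `hk.transform`.

# Minimal sketch, assuming Haiku, JAX and this module are importable.
import haiku as hk
import jax.numpy as jnp
from alphafold.model import mapping


def double(x):
  return x * 2.


@hk.without_apply_rng
@hk.transform
def forward(x):
  # Shards of 4 rows along axis 0: 4 + 4, then a remainder shard of 2.
  return mapping.sharded_apply(double, shard_size=4)(x)


x = jnp.arange(30.).reshape(10, 3)
params = forward.init(None, x)      # no parameters are created
y = forward.apply(params, x)        # same result as double(x), lower peak memory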
alphafold/model/model.py  deleted  100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code for constructing the model."""
from
typing
import
Any
,
Mapping
,
Optional
,
Union
from
absl
import
logging
from
alphafold.common
import
confidence
from
alphafold.model
import
features
from
alphafold.model
import
modules
from
alphafold.model
import
modules_multimer
import
haiku
as
hk
import
jax
import
ml_collections
import
numpy
as
np
import
tensorflow.compat.v1
as
tf
import
tree
def
get_confidence_metrics
(
prediction_result
:
Mapping
[
str
,
Any
],
multimer_mode
:
bool
)
->
Mapping
[
str
,
Any
]:
"""Post processes prediction_result to get confidence metrics."""
confidence_metrics
=
{}
confidence_metrics
[
'plddt'
]
=
confidence
.
compute_plddt
(
prediction_result
[
'predicted_lddt'
][
'logits'
])
if
'predicted_aligned_error'
in
prediction_result
:
confidence_metrics
.
update
(
confidence
.
compute_predicted_aligned_error
(
logits
=
prediction_result
[
'predicted_aligned_error'
][
'logits'
],
breaks
=
prediction_result
[
'predicted_aligned_error'
][
'breaks'
]))
confidence_metrics
[
'ptm'
]
=
confidence
.
predicted_tm_score
(
logits
=
prediction_result
[
'predicted_aligned_error'
][
'logits'
],
breaks
=
prediction_result
[
'predicted_aligned_error'
][
'breaks'
],
asym_id
=
None
)
if
multimer_mode
:
# Compute the ipTM only for the multimer model.
confidence_metrics
[
'iptm'
]
=
confidence
.
predicted_tm_score
(
logits
=
prediction_result
[
'predicted_aligned_error'
][
'logits'
],
breaks
=
prediction_result
[
'predicted_aligned_error'
][
'breaks'
],
asym_id
=
prediction_result
[
'predicted_aligned_error'
][
'asym_id'
],
interface
=
True
)
confidence_metrics
[
'ranking_confidence'
]
=
(
0.8
*
confidence_metrics
[
'iptm'
]
+
0.2
*
confidence_metrics
[
'ptm'
])
if
not
multimer_mode
:
# Monomer models use mean pLDDT for model ranking.
confidence_metrics
[
'ranking_confidence'
]
=
np
.
mean
(
confidence_metrics
[
'plddt'
])
return
confidence_metrics
class
RunModel
:
"""Container for JAX model."""
def
__init__
(
self
,
config
:
ml_collections
.
ConfigDict
,
params
:
Optional
[
Mapping
[
str
,
Mapping
[
str
,
np
.
ndarray
]]]
=
None
):
self
.
config
=
config
self
.
params
=
params
self
.
multimer_mode
=
config
.
model
.
global_config
.
multimer_mode
if
self
.
multimer_mode
:
def
_forward_fn
(
batch
):
model
=
modules_multimer
.
AlphaFold
(
self
.
config
.
model
)
return
model
(
batch
,
is_training
=
False
)
else
:
def
_forward_fn
(
batch
):
model
=
modules
.
AlphaFold
(
self
.
config
.
model
)
return
model
(
batch
,
is_training
=
False
,
compute_loss
=
False
,
ensemble_representations
=
True
)
self
.
apply
=
jax
.
jit
(
hk
.
transform
(
_forward_fn
).
apply
)
self
.
init
=
jax
.
jit
(
hk
.
transform
(
_forward_fn
).
init
)
def
init_params
(
self
,
feat
:
features
.
FeatureDict
,
random_seed
:
int
=
0
):
"""Initializes the model parameters.
If none were provided when this class was instantiated then the parameters
are randomly initialized.
Args:
feat: A dictionary of NumPy feature arrays as output by
RunModel.process_features.
random_seed: A random seed to use to initialize the parameters if none
were set when this class was initialized.
"""
if
not
self
.
params
:
# Init params randomly.
rng
=
jax
.
random
.
PRNGKey
(
random_seed
)
self
.
params
=
hk
.
data_structures
.
to_mutable_dict
(
self
.
init
(
rng
,
feat
))
logging
.
warning
(
'Initialized parameters randomly'
)
def
process_features
(
self
,
raw_features
:
Union
[
tf
.
train
.
Example
,
features
.
FeatureDict
],
random_seed
:
int
)
->
features
.
FeatureDict
:
"""Processes features to prepare for feeding them into the model.
Args:
raw_features: The output of the data pipeline either as a dict of NumPy
arrays or as a tf.train.Example.
random_seed: The random seed to use when processing the features.
Returns:
A dict of NumPy feature arrays suitable for feeding into the model.
"""
if
self
.
multimer_mode
:
return
raw_features
# Single-chain mode.
if
isinstance
(
raw_features
,
dict
):
return
features
.
np_example_to_features
(
np_example
=
raw_features
,
config
=
self
.
config
,
random_seed
=
random_seed
)
else
:
return
features
.
tf_example_to_features
(
tf_example
=
raw_features
,
config
=
self
.
config
,
random_seed
=
random_seed
)
def
eval_shape
(
self
,
feat
:
features
.
FeatureDict
)
->
jax
.
ShapeDtypeStruct
:
self
.
init_params
(
feat
)
logging
.
info
(
'Running eval_shape with shape(feat) = %s'
,
tree
.
map_structure
(
lambda
x
:
x
.
shape
,
feat
))
shape
=
jax
.
eval_shape
(
self
.
apply
,
self
.
params
,
jax
.
random
.
PRNGKey
(
0
),
feat
)
logging
.
info
(
'Output shape was %s'
,
shape
)
return
shape
def
predict
(
self
,
feat
:
features
.
FeatureDict
,
random_seed
:
int
,
)
->
Mapping
[
str
,
Any
]:
"""Makes a prediction by inferencing the model on the provided features.
Args:
feat: A dictionary of NumPy feature arrays as output by
RunModel.process_features.
random_seed: The random seed to use when running the model. In the
multimer model this controls the MSA sampling.
Returns:
A dictionary of model outputs.
"""
self
.
init_params
(
feat
)
logging
.
info
(
'Running predict with shape(feat) = %s'
,
tree
.
map_structure
(
lambda
x
:
x
.
shape
,
feat
))
result
=
self
.
apply
(
self
.
params
,
jax
.
random
.
PRNGKey
(
random_seed
),
feat
)
# This block is to ensure benchmark timings are accurate. Some blocking is
# already happening when computing get_confidence_metrics, and this ensures
# all outputs are blocked on.
jax
.
tree_map
(
lambda
x
:
x
.
block_until_ready
(),
result
)
result
.
update
(
get_confidence_metrics
(
result
,
multimer_mode
=
self
.
multimer_mode
))
logging
.
info
(
'Output shape was %s'
,
tree
.
map_structure
(
lambda
x
:
x
.
shape
,
result
))
return
result
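To make the confidence post-processing above concrete without running a full prediction, here is a hedged sketch (not from the deleted file) that feeds `get_confidence_metrics` a synthetic result; the logits array and its shape are placeholders, and in normal use the dictionary comes from `RunModel.predict`.

# Minimal sketch, assuming the alphafold package is importable.
import numpy as np
from alphafold.model import model

fake_result = {
    # placeholder: 10 residues, uniform logits over the pLDDT bins
    'predicted_lddt': {'logits': np.zeros((10, 50), np.float32)},
}
metrics = model.get_confidence_metrics(fake_result, multimer_mode=False)
# metrics['plddt'] has shape (10,); 'ranking_confidence' is its mean
# because the monomer path ranks models by mean pLDDT.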
alphafold/model/modules.py  deleted  100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modules and code used in the core part of AlphaFold.
The structure generation code is in 'folding.py'.
"""
import
functools
from
alphafold.common
import
residue_constants
from
alphafold.model
import
all_atom
from
alphafold.model
import
common_modules
from
alphafold.model
import
folding
from
alphafold.model
import
layer_stack
from
alphafold.model
import
lddt
from
alphafold.model
import
mapping
from
alphafold.model
import
prng
from
alphafold.model
import
quat_affine
from
alphafold.model
import
utils
import
haiku
as
hk
import
jax
import
jax.numpy
as
jnp
def
softmax_cross_entropy
(
logits
,
labels
):
"""Computes softmax cross entropy given logits and one-hot class labels."""
loss
=
-
jnp
.
sum
(
labels
*
jax
.
nn
.
log_softmax
(
logits
),
axis
=-
1
)
return
jnp
.
asarray
(
loss
)
def
sigmoid_cross_entropy
(
logits
,
labels
):
"""Computes sigmoid cross entropy given logits and multiple class labels."""
log_p
=
jax
.
nn
.
log_sigmoid
(
logits
)
# log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable
log_not_p
=
jax
.
nn
.
log_sigmoid
(
-
logits
)
loss
=
-
labels
*
log_p
-
(
1.
-
labels
)
*
log_not_p
return
jnp
.
asarray
(
loss
)
def
apply_dropout
(
*
,
tensor
,
safe_key
,
rate
,
is_training
,
broadcast_dim
=
None
):
"""Applies dropout to a tensor."""
if
is_training
and
rate
!=
0.0
:
shape
=
list
(
tensor
.
shape
)
if
broadcast_dim
is
not
None
:
shape
[
broadcast_dim
]
=
1
keep_rate
=
1.0
-
rate
keep
=
jax
.
random
.
bernoulli
(
safe_key
.
get
(),
keep_rate
,
shape
=
shape
)
return
keep
*
tensor
/
keep_rate
else
:
return
tensor
def
dropout_wrapper
(
module
,
input_act
,
mask
,
safe_key
,
global_config
,
output_act
=
None
,
is_training
=
True
,
**
kwargs
):
"""Applies module + dropout + residual update."""
if
output_act
is
None
:
output_act
=
input_act
gc
=
global_config
residual
=
module
(
input_act
,
mask
,
is_training
=
is_training
,
**
kwargs
)
dropout_rate
=
0.0
if
gc
.
deterministic
else
module
.
config
.
dropout_rate
if
module
.
config
.
shared_dropout
:
if
module
.
config
.
orientation
==
'per_row'
:
broadcast_dim
=
0
else
:
broadcast_dim
=
1
else
:
broadcast_dim
=
None
residual
=
apply_dropout
(
tensor
=
residual
,
safe_key
=
safe_key
,
rate
=
dropout_rate
,
is_training
=
is_training
,
broadcast_dim
=
broadcast_dim
)
new_act
=
output_act
+
residual
return
new_act
def create_extra_msa_feature(batch):
  """Expand extra_msa into 1hot and concat with other extra msa features.

  We do this as late as possible as the one_hot extra msa can be very large.

  Arguments:
    batch: a dictionary with the following keys:
     * 'extra_msa': [N_extra_seq, N_res] MSA that wasn't selected as a cluster
       centre. Note, that this is not one-hot encoded.
     * 'extra_has_deletion': [N_extra_seq, N_res] Whether there is a deletion to
       the left of each position in the extra MSA.
     * 'extra_deletion_value': [N_extra_seq, N_res] The number of deletions to
       the left of each position in the extra MSA.

  Returns:
    Concatenated tensor of extra MSA features.
  """
  # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
  msa_1hot = jax.nn.one_hot(batch['extra_msa'], 23)
  msa_feat = [msa_1hot,
              jnp.expand_dims(batch['extra_has_deletion'], axis=-1),
              jnp.expand_dims(batch['extra_deletion_value'], axis=-1)]
  return jnp.concatenate(msa_feat, axis=-1)
class AlphaFoldIteration(hk.Module):
  """A single recycling iteration of AlphaFold architecture.

  Computes ensembled (averaged) representations from the provided features.
  These representations are then passed to the various heads
  that have been requested by the configuration file. Each head also returns a
  loss which is combined as a weighted sum to produce the total loss.

  Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22
  """

  def __init__(self, config, global_config, name='alphafold_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               ensembled_batch,
               non_ensembled_batch,
               is_training,
               compute_loss=False,
               ensemble_representations=False,
               return_representations=False):

    num_ensemble = jnp.asarray(ensembled_batch['seq_length'].shape[0])

    if not ensemble_representations:
      assert ensembled_batch['seq_length'].shape[0] == 1

    def slice_batch(i):
      b = {k: v[i] for k, v in ensembled_batch.items()}
      b.update(non_ensembled_batch)
      return b

    # Compute representations for each batch element and average.
    evoformer_module = EmbeddingsAndEvoformer(
        self.config.embeddings_and_evoformer, self.global_config)
    batch0 = slice_batch(0)
    representations = evoformer_module(batch0, is_training)

    # MSA representations are not ensembled so
    # we don't pass tensor into the loop.
    msa_representation = representations['msa']
    del representations['msa']

    # Average the representations (except MSA) over the batch dimension.
    if ensemble_representations:
      def body(x):
        """Add one element to the representations ensemble."""
        i, current_representations = x
        feats = slice_batch(i)
        representations_update = evoformer_module(feats, is_training)

        new_representations = {}
        for k in current_representations:
          new_representations[k] = (
              current_representations[k] + representations_update[k])
        return i + 1, new_representations

      if hk.running_init():
        # When initializing the Haiku module, run one iteration of the
        # while_loop to initialize the Haiku modules used in `body`.
        _, representations = body((1, representations))
      else:
        _, representations = hk.while_loop(
            lambda x: x[0] < num_ensemble,
            body,
            (1, representations))

      for k in representations:
        if k != 'msa':
          representations[k] /= num_ensemble.astype(representations[k].dtype)

    representations['msa'] = msa_representation
    batch = batch0  # We are not ensembled from here on.

    heads = {}
    for head_name, head_config in sorted(self.config.heads.items()):
      if not head_config.weight:
        continue  # Do not instantiate zero-weight heads.

      head_factory = {
          'masked_msa': MaskedMsaHead,
          'distogram': DistogramHead,
          'structure_module': functools.partial(
              folding.StructureModule, compute_loss=compute_loss),
          'predicted_lddt': PredictedLDDTHead,
          'predicted_aligned_error': PredictedAlignedErrorHead,
          'experimentally_resolved': ExperimentallyResolvedHead,
      }[head_name]
      heads[head_name] = (head_config,
                          head_factory(head_config, self.global_config))

    total_loss = 0.
    ret = {}
    ret['representations'] = representations

    def loss(module, head_config, ret, name, filter_ret=True):
      if filter_ret:
        value = ret[name]
      else:
        value = ret
      loss_output = module.loss(value, batch)
      ret[name].update(loss_output)
      loss = head_config.weight * ret[name]['loss']
      return loss

    for name, (head_config, module) in heads.items():
      # Skip PredictedLDDTHead and PredictedAlignedErrorHead until
      # StructureModule is executed.
      if name in ('predicted_lddt', 'predicted_aligned_error'):
        continue
      else:
        ret[name] = module(representations, batch, is_training)
        if 'representations' in ret[name]:
          # Extra representations from the head. Used by the structure module
          # to provide activations for the PredictedLDDTHead.
          representations.update(ret[name].pop('representations'))
      if compute_loss:
        total_loss += loss(module, head_config, ret, name)

    if self.config.heads.get('predicted_lddt.weight', 0.0):
      # Add PredictedLDDTHead after StructureModule executes.
      name = 'predicted_lddt'
      # Feed all previous results to give access to structure_module result.
      head_config, module = heads[name]
      ret[name] = module(representations, batch, is_training)
      if compute_loss:
        total_loss += loss(module, head_config, ret, name, filter_ret=False)

    if ('predicted_aligned_error' in self.config.heads
        and self.config.heads.get('predicted_aligned_error.weight', 0.0)):
      # Add PredictedAlignedErrorHead after StructureModule executes.
      name = 'predicted_aligned_error'
      # Feed all previous results to give access to structure_module result.
      head_config, module = heads[name]
      ret[name] = module(representations, batch, is_training)
      if compute_loss:
        total_loss += loss(module, head_config, ret, name, filter_ret=False)

    if compute_loss:
      return ret, total_loss
    else:
      return ret
class AlphaFold(hk.Module):
  """AlphaFold model with recycling.

  Jumper et al. (2021) Suppl. Alg. 2 "Inference"
  """

  def __init__(self, config, name='alphafold'):
    super().__init__(name=name)
    self.config = config
    self.global_config = config.global_config

  def __call__(self,
               batch,
               is_training,
               compute_loss=False,
               ensemble_representations=False,
               return_representations=False):
    """Run the AlphaFold model.

    Arguments:
      batch: Dictionary with inputs to the AlphaFold model.
      is_training: Whether the system is in training or inference mode.
      compute_loss: Whether to compute losses (requires extra features
        to be present in the batch and knowing the true structure).
      ensemble_representations: Whether to use ensembling of representations.
      return_representations: Whether to also return the intermediate
        representations.

    Returns:
      When compute_loss is True:
        a tuple of loss and output of AlphaFoldIteration.
      When compute_loss is False:
        just output of AlphaFoldIteration.

      The output of AlphaFoldIteration is a nested dictionary containing
      predictions from the various heads.
    """
    impl = AlphaFoldIteration(self.config, self.global_config)
    batch_size, num_residues = batch['aatype'].shape

    def get_prev(ret):
      new_prev = {
          'prev_pos': ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
      return jax.tree_map(jax.lax.stop_gradient, new_prev)

    def do_call(prev, recycle_idx, compute_loss=compute_loss):
      if self.config.resample_msa_in_recycling:
        num_ensemble = batch_size // (self.config.num_recycle + 1)

        def slice_recycle_idx(x):
          start = recycle_idx * num_ensemble
          size = num_ensemble
          return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)

        ensembled_batch = jax.tree_map(slice_recycle_idx, batch)
      else:
        num_ensemble = batch_size
        ensembled_batch = batch

      non_ensembled_batch = jax.tree_map(lambda x: x, prev)

      return impl(
          ensembled_batch=ensembled_batch,
          non_ensembled_batch=non_ensembled_batch,
          is_training=is_training,
          compute_loss=compute_loss,
          ensemble_representations=ensemble_representations)

    prev = {}
    emb_config = self.config.embeddings_and_evoformer
    if emb_config.recycle_pos:
      prev['prev_pos'] = jnp.zeros(
          [num_residues, residue_constants.atom_type_num, 3])
    if emb_config.recycle_features:
      prev['prev_msa_first_row'] = jnp.zeros(
          [num_residues, emb_config.msa_channel])
      prev['prev_pair'] = jnp.zeros(
          [num_residues, num_residues, emb_config.pair_channel])

    if self.config.num_recycle:
      if 'num_iter_recycling' in batch:
        # Training time: num_iter_recycling is in batch.
        # The value for each ensemble batch is the same, so arbitrarily taking
        # 0-th.
        num_iter = batch['num_iter_recycling'][0]

        # Add insurance that we will not run more
        # recyclings than the model is configured to run.
        num_iter = jnp.minimum(num_iter, self.config.num_recycle)
      else:
        # Eval mode or tests: use the maximum number of iterations.
        num_iter = self.config.num_recycle

      body = lambda x: (x[0] + 1,  # pylint: disable=g-long-lambda
                        get_prev(do_call(x[1], recycle_idx=x[0],
                                         compute_loss=False)))
      if hk.running_init():
        # When initializing the Haiku module, run one iteration of the
        # while_loop to initialize the Haiku modules used in `body`.
        _, prev = body((0, prev))
      else:
        _, prev = hk.while_loop(
            lambda x: x[0] < num_iter,
            body,
            (0, prev))
    else:
      num_iter = 0

    ret = do_call(prev=prev, recycle_idx=num_iter)
    if compute_loss:
      ret = ret[0], [ret[1]]

    if not return_representations:
      del (ret[0] if compute_loss else ret)['representations']  # pytype: disable=unsupported-operands
    return ret
class TemplatePairStack(hk.Module):
  """Pair stack for the templates.

  Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack"
  """

  def __init__(self, config, global_config, name='template_pair_stack'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, pair_act, pair_mask, is_training, safe_key=None):
    """Builds TemplatePairStack module.

    Arguments:
      pair_act: Pair activations for single template, shape [N_res, N_res, c_t].
      pair_mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.
      safe_key: Safe key object encapsulating the random number generation key.

    Returns:
      Updated pair_act, shape [N_res, N_res, c_t].
    """
    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    gc = self.global_config
    c = self.config

    if not c.num_block:
      return pair_act

    def block(x):
      """One block of the template pair stack."""
      pair_act, safe_key = x
      dropout_wrapper_fn = functools.partial(
          dropout_wrapper, is_training=is_training, global_config=gc)

      safe_key, *sub_keys = safe_key.split(6)
      sub_keys = iter(sub_keys)

      pair_act = dropout_wrapper_fn(
          TriangleAttention(c.triangle_attention_starting_node, gc,
                            name='triangle_attention_starting_node'),
          pair_act, pair_mask, next(sub_keys))

      pair_act = dropout_wrapper_fn(
          TriangleAttention(c.triangle_attention_ending_node, gc,
                            name='triangle_attention_ending_node'),
          pair_act, pair_mask, next(sub_keys))

      pair_act = dropout_wrapper_fn(
          TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
                                 name='triangle_multiplication_outgoing'),
          pair_act, pair_mask, next(sub_keys))

      pair_act = dropout_wrapper_fn(
          TriangleMultiplication(c.triangle_multiplication_incoming, gc,
                                 name='triangle_multiplication_incoming'),
          pair_act, pair_mask, next(sub_keys))

      pair_act = dropout_wrapper_fn(
          Transition(c.pair_transition, gc, name='pair_transition'),
          pair_act, pair_mask, next(sub_keys))

      return pair_act, safe_key

    if gc.use_remat:
      block = hk.remat(block)

    res_stack = layer_stack.layer_stack(c.num_block)(block)
    pair_act, safe_key = res_stack((pair_act, safe_key))
    return pair_act
class Transition(hk.Module):
  """Transition layer.

  Jumper et al. (2021) Suppl. Alg. 9 "MSATransition"
  Jumper et al. (2021) Suppl. Alg. 15 "PairTransition"
  """

  def __init__(self, config, global_config, name='transition_block'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, mask, is_training=True):
    """Builds Transition module.

    Arguments:
      act: A tensor of queries of size [batch_size, N_res, N_channel].
      mask: A tensor denoting the mask of size [batch_size, N_res].
      is_training: Whether the module is in training mode.

    Returns:
      A float32 tensor of size [batch_size, N_res, N_channel].
    """
    _, _, nc = act.shape

    num_intermediate = int(nc * self.config.num_intermediate_factor)
    mask = jnp.expand_dims(mask, axis=-1)

    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='input_layer_norm')(act)

    transition_module = hk.Sequential([
        common_modules.Linear(
            num_intermediate, initializer='relu', name='transition1'),
        jax.nn.relu,
        common_modules.Linear(
            nc,
            initializer=utils.final_init(self.global_config),
            name='transition2')
    ])

    act = mapping.inference_subbatch(
        transition_module,
        self.global_config.subbatch_size,
        batched_args=[act],
        nonbatched_args=[],
        low_memory=not is_training)

    return act
def glorot_uniform():
  return hk.initializers.VarianceScaling(scale=1.0,
                                         mode='fan_avg',
                                         distribution='uniform')
class Attention(hk.Module):
  """Multihead attention."""

  def __init__(self, config, global_config, output_dim, name='attention'):
    super().__init__(name=name)

    self.config = config
    self.global_config = global_config
    self.output_dim = output_dim

  def __call__(self, q_data, m_data, bias, nonbatched_bias=None):
    """Builds Attention module.

    Arguments:
      q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
      m_data: A tensor of memories from which the keys and values are
        projected, shape [batch_size, N_keys, m_channels].
      bias: A bias for the attention, shape [batch_size, N_queries, N_keys].
      nonbatched_bias: Shared bias, shape [N_queries, N_keys].

    Returns:
      A float32 tensor of shape [batch_size, N_queries, output_dim].
    """
    # Sensible default for when the config keys are missing
    key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
    value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
    num_head = self.config.num_head
    assert key_dim % num_head == 0
    assert value_dim % num_head == 0
    key_dim = key_dim // num_head
    value_dim = value_dim // num_head

    q_weights = hk.get_parameter(
        'query_w', shape=(q_data.shape[-1], num_head, key_dim),
        init=glorot_uniform())
    k_weights = hk.get_parameter(
        'key_w', shape=(m_data.shape[-1], num_head, key_dim),
        init=glorot_uniform())
    v_weights = hk.get_parameter(
        'value_w', shape=(m_data.shape[-1], num_head, value_dim),
        init=glorot_uniform())

    q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
    k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
    v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
    logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) + bias
    if nonbatched_bias is not None:
      logits += jnp.expand_dims(nonbatched_bias, axis=0)
    weights = jax.nn.softmax(logits)
    weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)

    if self.global_config.zero_init:
      init = hk.initializers.Constant(0.0)
    else:
      init = glorot_uniform()

    if self.config.gating:
      gating_weights = hk.get_parameter(
          'gating_w',
          shape=(q_data.shape[-1], num_head, value_dim),
          init=hk.initializers.Constant(0.0))
      gating_bias = hk.get_parameter(
          'gating_b',
          shape=(num_head, value_dim),
          init=hk.initializers.Constant(1.0))

      gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
                               gating_weights) + gating_bias
      gate_values = jax.nn.sigmoid(gate_values)

      weighted_avg *= gate_values

    o_weights = hk.get_parameter(
        'output_w', shape=(num_head, value_dim, self.output_dim),
        init=init)
    o_bias = hk.get_parameter(
        'output_b', shape=(self.output_dim,),
        init=hk.initializers.Constant(0.0))

    output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias

    return output
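# Illustrative sketch, added for clarity (not part of the original module):
# the additive `bias` fed into Attention throughout this file is built from a
# 0/1 padding mask as 1e9 * (mask - 1), so padded keys receive a large
# negative logit before the softmax and get near-zero attention weight.
# Shapes below are hypothetical.
def _attention_bias_sketch():
  mask = jnp.array([[1., 1., 0.]])              # [batch, N_keys], last key padded
  bias = (1e9 * (mask - 1.))[:, None, None, :]  # [batch, 1, 1, N_keys]
  return bias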
class GlobalAttention(hk.Module):
  """Global attention.

  Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7
  """

  def __init__(self, config, global_config, output_dim, name='attention'):
    super().__init__(name=name)

    self.config = config
    self.global_config = global_config
    self.output_dim = output_dim

  def __call__(self, q_data, m_data, q_mask):
    """Builds GlobalAttention module.

    Arguments:
      q_data: A tensor of queries with size [batch_size, N_queries,
        q_channels]
      m_data: A tensor of memories from which the keys and values are
        projected. Size [batch_size, N_keys, m_channels]
      q_mask: A binary mask for q_data with zeros in the padded sequence
        elements and ones otherwise. Size [batch_size, N_queries, q_channels]
        (or broadcastable to this shape).

    Returns:
      A float32 tensor of size [batch_size, N_queries, output_dim].
    """
    # Sensible default for when the config keys are missing
    key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
    value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
    num_head = self.config.num_head
    assert key_dim % num_head == 0
    assert value_dim % num_head == 0
    key_dim = key_dim // num_head
    value_dim = value_dim // num_head

    q_weights = hk.get_parameter(
        'query_w', shape=(q_data.shape[-1], num_head, key_dim),
        init=glorot_uniform())
    k_weights = hk.get_parameter(
        'key_w', shape=(m_data.shape[-1], key_dim),
        init=glorot_uniform())
    v_weights = hk.get_parameter(
        'value_w', shape=(m_data.shape[-1], value_dim),
        init=glorot_uniform())

    v = jnp.einsum('bka,ac->bkc', m_data, v_weights)

    q_avg = utils.mask_mean(q_mask, q_data, axis=1)

    q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5)
    k = jnp.einsum('bka,ac->bkc', m_data, k_weights)
    bias = (1e9 * (q_mask[:, None, :, 0] - 1.))
    logits = jnp.einsum('bhc,bkc->bhk', q, k) + bias
    weights = jax.nn.softmax(logits)
    weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v)

    if self.global_config.zero_init:
      init = hk.initializers.Constant(0.0)
    else:
      init = glorot_uniform()

    o_weights = hk.get_parameter(
        'output_w', shape=(num_head, value_dim, self.output_dim),
        init=init)
    o_bias = hk.get_parameter(
        'output_b', shape=(self.output_dim,),
        init=hk.initializers.Constant(0.0))

    if self.config.gating:
      gating_weights = hk.get_parameter(
          'gating_w',
          shape=(q_data.shape[-1], num_head, value_dim),
          init=hk.initializers.Constant(0.0))
      gating_bias = hk.get_parameter(
          'gating_b',
          shape=(num_head, value_dim),
          init=hk.initializers.Constant(1.0))

      gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights)
      gate_values = jax.nn.sigmoid(gate_values + gating_bias)
      weighted_avg = weighted_avg[:, None] * gate_values
      output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
    else:
      output = jnp.einsum('bhc,hco->bo', weighted_avg, o_weights) + o_bias
      output = output[:, None]
    return output
class MSARowAttentionWithPairBias(hk.Module):
  """MSA per-row attention biased by the pair representation.

  Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias"
  """

  def __init__(self, config, global_config,
               name='msa_row_attention_with_pair_bias'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, msa_act, msa_mask, pair_act, is_training=False):
    """Builds MSARowAttentionWithPairBias module.

    Arguments:
      msa_act: [N_seq, N_res, c_m] MSA representation.
      msa_mask: [N_seq, N_res] mask of non-padded regions.
      pair_act: [N_res, N_res, c_z] pair representation.
      is_training: Whether the module is in training mode.

    Returns:
      Update to msa_act, shape [N_seq, N_res, c_m].
    """
    c = self.config

    assert len(msa_act.shape) == 3
    assert len(msa_mask.shape) == 2
    assert c.orientation == 'per_row'

    bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
    assert len(bias.shape) == 4

    msa_act = hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True,
        name='query_norm')(msa_act)

    pair_act = hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True,
        name='feat_2d_norm')(pair_act)

    init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1]))
    weights = hk.get_parameter(
        'feat_2d_weights',
        shape=(pair_act.shape[-1], c.num_head),
        init=hk.initializers.RandomNormal(stddev=init_factor))
    nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)

    attn_mod = Attention(c, self.global_config, msa_act.shape[-1])
    msa_act = mapping.inference_subbatch(
        attn_mod,
        self.global_config.subbatch_size,
        batched_args=[msa_act, msa_act, bias],
        nonbatched_args=[nonbatched_bias],
        low_memory=not is_training)

    return msa_act
class MSAColumnAttention(hk.Module):
  """MSA per-column attention.

  Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention"
  """

  def __init__(self, config, global_config, name='msa_column_attention'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, msa_act, msa_mask, is_training=False):
    """Builds MSAColumnAttention module.

    Arguments:
      msa_act: [N_seq, N_res, c_m] MSA representation.
      msa_mask: [N_seq, N_res] mask of non-padded regions.
      is_training: Whether the module is in training mode.

    Returns:
      Update to msa_act, shape [N_seq, N_res, c_m]
    """
    c = self.config

    assert len(msa_act.shape) == 3
    assert len(msa_mask.shape) == 2
    assert c.orientation == 'per_column'

    msa_act = jnp.swapaxes(msa_act, -2, -3)
    msa_mask = jnp.swapaxes(msa_mask, -1, -2)

    bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
    assert len(bias.shape) == 4

    msa_act = hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True,
        name='query_norm')(msa_act)

    attn_mod = Attention(c, self.global_config, msa_act.shape[-1])
    msa_act = mapping.inference_subbatch(
        attn_mod,
        self.global_config.subbatch_size,
        batched_args=[msa_act, msa_act, bias],
        nonbatched_args=[],
        low_memory=not is_training)

    msa_act = jnp.swapaxes(msa_act, -2, -3)

    return msa_act
class MSAColumnGlobalAttention(hk.Module):
  """MSA per-column global attention.

  Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention"
  """

  def __init__(self, config, global_config,
               name='msa_column_global_attention'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, msa_act, msa_mask, is_training=False):
    """Builds MSAColumnGlobalAttention module.

    Arguments:
      msa_act: [N_seq, N_res, c_m] MSA representation.
      msa_mask: [N_seq, N_res] mask of non-padded regions.
      is_training: Whether the module is in training mode.

    Returns:
      Update to msa_act, shape [N_seq, N_res, c_m].
    """
    c = self.config

    assert len(msa_act.shape) == 3
    assert len(msa_mask.shape) == 2
    assert c.orientation == 'per_column'

    msa_act = jnp.swapaxes(msa_act, -2, -3)
    msa_mask = jnp.swapaxes(msa_mask, -1, -2)

    bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
    assert len(bias.shape) == 4

    msa_act = hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True,
        name='query_norm')(msa_act)

    attn_mod = GlobalAttention(
        c, self.global_config, msa_act.shape[-1], name='attention')
    # [N_seq, N_res, 1]
    msa_mask = jnp.expand_dims(msa_mask, axis=-1)
    msa_act = mapping.inference_subbatch(
        attn_mod,
        self.global_config.subbatch_size,
        batched_args=[msa_act, msa_act, msa_mask],
        nonbatched_args=[],
        low_memory=not is_training)

    msa_act = jnp.swapaxes(msa_act, -2, -3)

    return msa_act
class TriangleAttention(hk.Module):
  """Triangle Attention.

  Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode"
  Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode"
  """

  def __init__(self, config, global_config, name='triangle_attention'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, pair_act, pair_mask, is_training=False):
    """Builds TriangleAttention module.

    Arguments:
      pair_act: [N_res, N_res, c_z] pair activations tensor
      pair_mask: [N_res, N_res] mask of non-padded regions in the tensor.
      is_training: Whether the module is in training mode.

    Returns:
      Update to pair_act, shape [N_res, N_res, c_z].
    """
    c = self.config

    assert len(pair_act.shape) == 3
    assert len(pair_mask.shape) == 2
    assert c.orientation in ['per_row', 'per_column']

    if c.orientation == 'per_column':
      pair_act = jnp.swapaxes(pair_act, -2, -3)
      pair_mask = jnp.swapaxes(pair_mask, -1, -2)

    bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
    assert len(bias.shape) == 4

    pair_act = hk.LayerNorm(
        axis=[-1], create_scale=True, create_offset=True,
        name='query_norm')(pair_act)

    init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1]))
    weights = hk.get_parameter(
        'feat_2d_weights',
        shape=(pair_act.shape[-1], c.num_head),
        init=hk.initializers.RandomNormal(stddev=init_factor))
    nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)

    attn_mod = Attention(c, self.global_config, pair_act.shape[-1])
    pair_act = mapping.inference_subbatch(
        attn_mod,
        self.global_config.subbatch_size,
        batched_args=[pair_act, pair_act, bias],
        nonbatched_args=[nonbatched_bias],
        low_memory=not is_training)

    if c.orientation == 'per_column':
      pair_act = jnp.swapaxes(pair_act, -2, -3)

    return pair_act
class MaskedMsaHead(hk.Module):
  """Head to predict MSA at the masked locations.

  The MaskedMsaHead employs a BERT-style objective to reconstruct a masked
  version of the full MSA, based on a linear projection of
  the MSA representation.
  Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction"
  """

  def __init__(self, config, global_config, name='masked_msa_head'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

    if global_config.multimer_mode:
      self.num_output = len(residue_constants.restypes_with_x_and_gap)
    else:
      self.num_output = config.num_output

  def __call__(self, representations, batch, is_training):
    """Builds MaskedMsaHead module.

    Arguments:
      representations: Dictionary of representations, must contain:
        * 'msa': MSA representation, shape [N_seq, N_res, c_m].
      batch: Batch, unused.
      is_training: Whether the module is in training mode.

    Returns:
      Dictionary containing:
        * 'logits': logits of shape [N_seq, N_res, N_aatype] with
            (unnormalized) log probabilities of the predicted aatype at each
            position.
    """
    del batch
    logits = common_modules.Linear(
        self.num_output,
        initializer=utils.final_init(self.global_config),
        name='logits')(representations['msa'])
    return dict(logits=logits)

  def loss(self, value, batch):
    errors = softmax_cross_entropy(
        labels=jax.nn.one_hot(batch['true_msa'], num_classes=self.num_output),
        logits=value['logits'])
    loss = (jnp.sum(errors * batch['bert_mask'], axis=(-2, -1)) /
            (1e-8 + jnp.sum(batch['bert_mask'], axis=(-2, -1))))
    return {'loss': loss}
class PredictedLDDTHead(hk.Module):
  """Head to predict the per-residue LDDT to be used as a confidence measure.

  Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)"
  Jumper et al. (2021) Suppl. Alg. 29 "predictPerResidueLDDT_Ca"
  """

  def __init__(self, config, global_config, name='predicted_lddt_head'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, representations, batch, is_training):
    """Builds PredictedLDDTHead module.

    Arguments:
      representations: Dictionary of representations, must contain:
        * 'structure_module': Single representation from the structure module,
             shape [N_res, c_s].
      batch: Batch, unused.
      is_training: Whether the module is in training mode.

    Returns:
      Dictionary containing:
        * 'logits': logits of shape [N_res, N_bins] with
            (unnormalized) log probabilities of binned predicted lDDT.
    """
    act = representations['structure_module']

    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='input_layer_norm')(act)

    act = common_modules.Linear(
        self.config.num_channels,
        initializer='relu',
        name='act_0')(act)
    act = jax.nn.relu(act)

    act = common_modules.Linear(
        self.config.num_channels,
        initializer='relu',
        name='act_1')(act)
    act = jax.nn.relu(act)

    logits = common_modules.Linear(
        self.config.num_bins,
        initializer=utils.final_init(self.global_config),
        name='logits')(act)
    # Shape (batch_size, num_res, num_bins)
    return dict(logits=logits)

  def loss(self, value, batch):
    # Shape (num_res, 37, 3)
    pred_all_atom_pos = value['structure_module']['final_atom_positions']
    # Shape (num_res, 37, 3)
    true_all_atom_pos = batch['all_atom_positions']
    # Shape (num_res, 37)
    all_atom_mask = batch['all_atom_mask']

    # Shape (num_res,)
    lddt_ca = lddt.lddt(
        # Shape (batch_size, num_res, 3)
        predicted_points=pred_all_atom_pos[None, :, 1, :],
        # Shape (batch_size, num_res, 3)
        true_points=true_all_atom_pos[None, :, 1, :],
        # Shape (batch_size, num_res, 1)
        true_points_mask=all_atom_mask[None, :, 1:2].astype(jnp.float32),
        cutoff=15.,
        per_residue=True)
    lddt_ca = jax.lax.stop_gradient(lddt_ca)

    num_bins = self.config.num_bins
    bin_index = jnp.floor(lddt_ca * num_bins).astype(jnp.int32)

    # protect against out of range for lddt_ca == 1
    bin_index = jnp.minimum(bin_index, num_bins - 1)
    lddt_ca_one_hot = jax.nn.one_hot(bin_index, num_classes=num_bins)

    # Shape (num_res, num_channel)
    logits = value['predicted_lddt']['logits']
    errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits)

    # Shape (num_res,)
    mask_ca = all_atom_mask[:, residue_constants.atom_order['CA']]
    mask_ca = mask_ca.astype(jnp.float32)

    loss = jnp.sum(errors * mask_ca) / (jnp.sum(mask_ca) + 1e-8)

    if self.config.filter_by_resolution:
      # NMR & distillation have resolution = 0
      loss *= ((batch['resolution'] >= self.config.min_resolution)
               & (batch['resolution'] <= self.config.max_resolution)
              ).astype(jnp.float32)

    output = {'loss': loss}
    return output
class PredictedAlignedErrorHead(hk.Module):
  """Head to predict the distance errors in the backbone alignment frames.

  Can be used to compute predicted TM-Score.
  Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction"
  """

  def __init__(self, config, global_config,
               name='predicted_aligned_error_head'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, representations, batch, is_training):
    """Builds PredictedAlignedErrorHead module.

    Arguments:
      representations: Dictionary of representations, must contain:
        * 'pair': pair representation, shape [N_res, N_res, c_z].
      batch: Batch, unused.
      is_training: Whether the module is in training mode.

    Returns:
      Dictionary containing:
        * logits: logits for aligned error, shape [N_res, N_res, N_bins].
        * bin_breaks: array containing bin breaks, shape [N_bins - 1].
    """
    act = representations['pair']

    # Shape (num_res, num_res, num_bins)
    logits = common_modules.Linear(
        self.config.num_bins,
        initializer=utils.final_init(self.global_config),
        name='logits')(act)
    # Shape (num_bins,)
    breaks = jnp.linspace(
        0., self.config.max_error_bin, self.config.num_bins - 1)
    return dict(logits=logits, breaks=breaks)

  def loss(self, value, batch):
    # Shape (num_res, 7)
    predicted_affine = quat_affine.QuatAffine.from_tensor(
        value['structure_module']['final_affines'])
    # Shape (num_res, 7)
    true_affine = quat_affine.QuatAffine.from_tensor(
        batch['backbone_affine_tensor'])
    # Shape (num_res)
    mask = batch['backbone_affine_mask']
    # Shape (num_res, num_res)
    square_mask = mask[:, None] * mask[None, :]
    num_bins = self.config.num_bins
    # (1, num_bins - 1)
    breaks = value['predicted_aligned_error']['breaks']
    # (1, num_bins)
    logits = value['predicted_aligned_error']['logits']

    # Compute the squared error for each alignment.
    def _local_frame_points(affine):
      points = [jnp.expand_dims(x, axis=-2) for x in affine.translation]
      return affine.invert_point(points, extra_dims=1)

    error_dist2_xyz = [
        jnp.square(a - b)
        for a, b in zip(_local_frame_points(predicted_affine),
                        _local_frame_points(true_affine))]
    error_dist2 = sum(error_dist2_xyz)
    # Shape (num_res, num_res)
    # First num_res are alignment frames, second num_res are the residues.
    error_dist2 = jax.lax.stop_gradient(error_dist2)

    sq_breaks = jnp.square(breaks)
    true_bins = jnp.sum(
        (error_dist2[..., None] > sq_breaks).astype(jnp.int32), axis=-1)

    errors = softmax_cross_entropy(
        labels=jax.nn.one_hot(true_bins, num_bins, axis=-1), logits=logits)

    loss = (jnp.sum(errors * square_mask, axis=(-2, -1)) /
            (1e-8 + jnp.sum(square_mask, axis=(-2, -1))))

    if self.config.filter_by_resolution:
      # NMR & distillation have resolution = 0
      loss *= ((batch['resolution'] >= self.config.min_resolution)
               & (batch['resolution'] <= self.config.max_resolution)
              ).astype(jnp.float32)

    output = {'loss': loss}
    return output
class ExperimentallyResolvedHead(hk.Module):
  """Predicts if an atom is experimentally resolved in a high-res structure.

  Only trained on high-resolution X-ray crystals & cryo-EM.
  Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction'
  """

  def __init__(self, config, global_config,
               name='experimentally_resolved_head'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, representations, batch, is_training):
    """Builds ExperimentallyResolvedHead module.

    Arguments:
      representations: Dictionary of representations, must contain:
        * 'single': Single representation, shape [N_res, c_s].
      batch: Batch, unused.
      is_training: Whether the module is in training mode.

    Returns:
      Dictionary containing:
        * 'logits': logits of shape [N_res, 37],
            log probability that an atom is resolved in atom37 representation,
            can be converted to probability by applying sigmoid.
    """
    logits = common_modules.Linear(
        37,  # atom_exists.shape[-1]
        initializer=utils.final_init(self.global_config),
        name='logits')(representations['single'])
    return dict(logits=logits)

  def loss(self, value, batch):
    logits = value['logits']
    assert len(logits.shape) == 2

    # Does the atom appear in the amino acid?
    atom_exists = batch['atom37_atom_exists']
    # Is the atom resolved in the experiment? Subset of atom_exists,
    # *except for OXT*
    all_atom_mask = batch['all_atom_mask'].astype(jnp.float32)

    xent = sigmoid_cross_entropy(labels=all_atom_mask, logits=logits)
    loss = jnp.sum(xent * atom_exists) / (1e-8 + jnp.sum(atom_exists))

    if self.config.filter_by_resolution:
      # NMR & distillation examples have resolution = 0.
      loss *= ((batch['resolution'] >= self.config.min_resolution)
               & (batch['resolution'] <= self.config.max_resolution)
              ).astype(jnp.float32)

    output = {'loss': loss}
    return output
class TriangleMultiplication(hk.Module):
  """Triangle multiplication layer ("outgoing" or "incoming").

  Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"
  Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming"
  """

  def __init__(self, config, global_config, name='triangle_multiplication'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, mask, is_training=True):
    """Builds TriangleMultiplication module.

    Arguments:
      act: Pair activations, shape [N_res, N_res, c_z]
      mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.

    Returns:
      Outputs, same shape/type as act.
    """
    del is_training
    c = self.config
    gc = self.global_config

    mask = mask[..., None]

    act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
                       name='layer_norm_input')(act)
    input_act = act

    left_projection = common_modules.Linear(
        c.num_intermediate_channel, name='left_projection')
    left_proj_act = mask * left_projection(act)

    right_projection = common_modules.Linear(
        c.num_intermediate_channel, name='right_projection')
    right_proj_act = mask * right_projection(act)

    left_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='left_gate')(act))

    right_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='right_gate')(act))

    left_proj_act *= left_gate_values
    right_proj_act *= right_gate_values

    # "Outgoing" edges equation: 'ikc,jkc->ijc'
    # "Incoming" edges equation: 'kjc,kic->ijc'
    # Note on the Suppl. Alg. 11 & 12 notation:
    # For the "outgoing" edges, a = left_proj_act and b = right_proj_act
    # For the "incoming" edges, it's swapped:
    #   b = left_proj_act and a = right_proj_act
    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='center_layer_norm')(act)

    output_channel = int(input_act.shape[-1])

    act = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(gc),
        name='output_projection')(act)

    gate_values = jax.nn.sigmoid(common_modules.Linear(
        output_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='gating_linear')(input_act))
    act *= gate_values

    return act
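# Illustrative sketch, added for clarity (not part of the original module):
# the two einsum equations selected via c.equation above, spelled out on toy
# pair activations. "Outgoing" contracts over the shared third residue k seen
# from rows i and j; "incoming" contracts over the residue k pointing at them.
# Shapes are hypothetical.
def _triangle_equation_sketch():
  left = jnp.ones([5, 5, 4])   # [N_res, N_res, channels]
  right = jnp.ones([5, 5, 4])
  outgoing = jnp.einsum('ikc,jkc->ijc', left, right)  # Suppl. Alg. 11
  incoming = jnp.einsum('kjc,kic->ijc', left, right)  # Suppl. Alg. 12
  return outgoing, incoming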
class DistogramHead(hk.Module):
  """Head to predict a distogram.

  Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction"
  """

  def __init__(self, config, global_config, name='distogram_head'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, representations, batch, is_training):
    """Builds DistogramHead module.

    Arguments:
      representations: Dictionary of representations, must contain:
        * 'pair': pair representation, shape [N_res, N_res, c_z].
      batch: Batch, unused.
      is_training: Whether the module is in training mode.

    Returns:
      Dictionary containing:
        * logits: logits for distogram, shape [N_res, N_res, N_bins].
        * bin_breaks: array containing bin breaks, shape [N_bins - 1,].
    """
    half_logits = common_modules.Linear(
        self.config.num_bins,
        initializer=utils.final_init(self.global_config),
        name='half_logits')(representations['pair'])

    logits = half_logits + jnp.swapaxes(half_logits, -2, -3)
    breaks = jnp.linspace(self.config.first_break, self.config.last_break,
                          self.config.num_bins - 1)

    return dict(logits=logits, bin_edges=breaks)

  def loss(self, value, batch):
    return _distogram_log_loss(value['logits'], value['bin_edges'],
                               batch, self.config.num_bins)


def _distogram_log_loss(logits, bin_edges, batch, num_bins):
  """Log loss of a distogram."""
  assert len(logits.shape) == 3
  positions = batch['pseudo_beta']
  mask = batch['pseudo_beta_mask']

  assert positions.shape[-1] == 3

  sq_breaks = jnp.square(bin_edges)

  dist2 = jnp.sum(
      jnp.square(
          jnp.expand_dims(positions, axis=-2) -
          jnp.expand_dims(positions, axis=-3)),
      axis=-1,
      keepdims=True)

  true_bins = jnp.sum(dist2 > sq_breaks, axis=-1)

  errors = softmax_cross_entropy(
      labels=jax.nn.one_hot(true_bins, num_bins), logits=logits)

  square_mask = jnp.expand_dims(mask, axis=-2) * jnp.expand_dims(mask, axis=-1)

  avg_error = (
      jnp.sum(errors * square_mask, axis=(-2, -1)) /
      (1e-6 + jnp.sum(square_mask, axis=(-2, -1))))
  dist2 = dist2[..., 0]
  return dict(loss=avg_error, true_dist=jnp.sqrt(1e-6 + dist2))
class OuterProductMean(hk.Module):
  """Computes mean outer product.

  Jumper et al. (2021) Suppl. Alg. 10 "OuterProductMean"
  """

  def __init__(self,
               config,
               global_config,
               num_output_channel,
               name='outer_product_mean'):
    super().__init__(name=name)
    self.global_config = global_config
    self.config = config
    self.num_output_channel = num_output_channel

  def __call__(self, act, mask, is_training=True):
    """Builds OuterProductMean module.

    Arguments:
      act: MSA representation, shape [N_seq, N_res, c_m].
      mask: MSA mask, shape [N_seq, N_res].
      is_training: Whether the module is in training mode.

    Returns:
      Update to pair representation, shape [N_res, N_res, c_z].
    """
    gc = self.global_config
    c = self.config

    mask = mask[..., None]
    act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act)

    left_act = mask * common_modules.Linear(
        c.num_outer_channel,
        initializer='linear',
        name='left_projection')(act)

    right_act = mask * common_modules.Linear(
        c.num_outer_channel,
        initializer='linear',
        name='right_projection')(act)

    if gc.zero_init:
      init_w = hk.initializers.Constant(0.0)
    else:
      init_w = hk.initializers.VarianceScaling(scale=2., mode='fan_in')

    output_w = hk.get_parameter(
        'output_w',
        shape=(c.num_outer_channel, c.num_outer_channel,
               self.num_output_channel),
        init=init_w)
    output_b = hk.get_parameter(
        'output_b', shape=(self.num_output_channel,),
        init=hk.initializers.Constant(0.0))

    def compute_chunk(left_act):
      # This is equivalent to
      #
      # act = jnp.einsum('abc,ade->dceb', left_act, right_act)
      # act = jnp.einsum('dceb,cef->bdf', act, output_w) + output_b
      #
      # but faster.
      left_act = jnp.transpose(left_act, [0, 2, 1])
      act = jnp.einsum('acb,ade->dceb', left_act, right_act)
      act = jnp.einsum('dceb,cef->dbf', act, output_w) + output_b
      return jnp.transpose(act, [1, 0, 2])

    act = mapping.inference_subbatch(
        compute_chunk,
        c.chunk_size,
        batched_args=[left_act],
        nonbatched_args=[],
        low_memory=True,
        input_subbatch_dim=1,
        output_subbatch_dim=0)

    epsilon = 1e-3
    norm = jnp.einsum('abc,adc->bdc', mask, mask)
    act /= epsilon + norm

    return act
def dgram_from_positions(positions, num_bins, min_bin, max_bin):
  """Compute distogram from amino acid positions.

  Arguments:
    positions: [N_res, 3] Position coordinates.
    num_bins: The number of bins in the distogram.
    min_bin: The left edge of the first bin.
    max_bin: The left edge of the final bin. The final bin catches
      everything larger than `max_bin`.

  Returns:
    Distogram with the specified number of bins.
  """
  def squared_difference(x, y):
    return jnp.square(x - y)

  lower_breaks = jnp.linspace(min_bin, max_bin, num_bins)
  lower_breaks = jnp.square(lower_breaks)
  upper_breaks = jnp.concatenate([lower_breaks[1:],
                                  jnp.array([1e8], dtype=jnp.float32)],
                                 axis=-1)
  dist2 = jnp.sum(
      squared_difference(
          jnp.expand_dims(positions, axis=-2),
          jnp.expand_dims(positions, axis=-3)),
      axis=-1, keepdims=True)

  dgram = ((dist2 > lower_breaks).astype(jnp.float32) *
           (dist2 < upper_breaks).astype(jnp.float32))
  return dgram
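# Illustrative sketch, added for clarity (not part of the original module):
# for two atoms 5 Å apart, dgram_from_positions marks the single bin whose
# edges straddle 25 Å^2; pairs closer than min_bin (including self-distances
# of zero) get an all-zero bin vector. The bin settings below are hypothetical
# AlphaFold-like values, not taken from this file.
def _dgram_usage_sketch():
  positions = jnp.array([[0., 0., 0.], [5., 0., 0.]])
  dgram = dgram_from_positions(positions, num_bins=39,
                               min_bin=3.25, max_bin=50.75)
  return dgram  # shape [2, 2, 39]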
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
  """Create pseudo beta features."""
  is_gly = jnp.equal(aatype, residue_constants.restype_order['G'])
  ca_idx = residue_constants.atom_order['CA']
  cb_idx = residue_constants.atom_order['CB']
  pseudo_beta = jnp.where(
      jnp.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
      all_atom_positions[..., ca_idx, :],
      all_atom_positions[..., cb_idx, :])

  if all_atom_masks is not None:
    pseudo_beta_mask = jnp.where(
        is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
    pseudo_beta_mask = pseudo_beta_mask.astype(jnp.float32)
    return pseudo_beta, pseudo_beta_mask
  else:
    return pseudo_beta
class EvoformerIteration(hk.Module):
  """Single iteration (block) of Evoformer stack.

  Jumper et al. (2021) Suppl. Alg. 6 "EvoformerStack" lines 2-10
  """

  def __init__(self, config, global_config, is_extra_msa,
               name='evoformer_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    self.is_extra_msa = is_extra_msa

  def __call__(self, activations, masks, is_training=True, safe_key=None):
    """Builds EvoformerIteration module.

    Arguments:
      activations: Dictionary containing activations:
        * 'msa': MSA activations, shape [N_seq, N_res, c_m].
        * 'pair': pair activations, shape [N_res, N_res, c_z].
      masks: Dictionary of masks:
        * 'msa': MSA mask, shape [N_seq, N_res].
        * 'pair': pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.
      safe_key: prng.SafeKey encapsulating rng key.

    Returns:
      Outputs, same shape/type as act.
    """
    c = self.config
    gc = self.global_config

    msa_act, pair_act = activations['msa'], activations['pair']

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    msa_mask, pair_mask = masks['msa'], masks['pair']

    dropout_wrapper_fn = functools.partial(
        dropout_wrapper,
        is_training=is_training,
        global_config=gc)

    safe_key, *sub_keys = safe_key.split(10)
    sub_keys = iter(sub_keys)

    outer_module = OuterProductMean(
        config=c.outer_product_mean,
        global_config=self.global_config,
        num_output_channel=int(pair_act.shape[-1]),
        name='outer_product_mean')
    if c.outer_product_mean.first:
      pair_act = dropout_wrapper_fn(
          outer_module,
          msa_act,
          msa_mask,
          safe_key=next(sub_keys),
          output_act=pair_act)

    msa_act = dropout_wrapper_fn(
        MSARowAttentionWithPairBias(
            c.msa_row_attention_with_pair_bias, gc,
            name='msa_row_attention_with_pair_bias'),
        msa_act,
        msa_mask,
        safe_key=next(sub_keys),
        pair_act=pair_act)

    if not self.is_extra_msa:
      attn_mod = MSAColumnAttention(
          c.msa_column_attention, gc, name='msa_column_attention')
    else:
      attn_mod = MSAColumnGlobalAttention(
          c.msa_column_attention, gc, name='msa_column_global_attention')
    msa_act = dropout_wrapper_fn(
        attn_mod,
        msa_act,
        msa_mask,
        safe_key=next(sub_keys))

    msa_act = dropout_wrapper_fn(
        Transition(c.msa_transition, gc, name='msa_transition'),
        msa_act,
        msa_mask,
        safe_key=next(sub_keys))

    if not c.outer_product_mean.first:
      pair_act = dropout_wrapper_fn(
          outer_module,
          msa_act,
          msa_mask,
          safe_key=next(sub_keys),
          output_act=pair_act)

    pair_act = dropout_wrapper_fn(
        TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
                               name='triangle_multiplication_outgoing'),
        pair_act,
        pair_mask,
        safe_key=next(sub_keys))
    pair_act = dropout_wrapper_fn(
        TriangleMultiplication(c.triangle_multiplication_incoming, gc,
                               name='triangle_multiplication_incoming'),
        pair_act,
        pair_mask,
        safe_key=next(sub_keys))

    pair_act = dropout_wrapper_fn(
        TriangleAttention(c.triangle_attention_starting_node, gc,
                          name='triangle_attention_starting_node'),
        pair_act,
        pair_mask,
        safe_key=next(sub_keys))
    pair_act = dropout_wrapper_fn(
        TriangleAttention(c.triangle_attention_ending_node, gc,
                          name='triangle_attention_ending_node'),
        pair_act,
        pair_mask,
        safe_key=next(sub_keys))

    pair_act = dropout_wrapper_fn(
        Transition(c.pair_transition, gc, name='pair_transition'),
        pair_act,
        pair_mask,
        safe_key=next(sub_keys))

    return {'msa': msa_act, 'pair': pair_act}
class EmbeddingsAndEvoformer(hk.Module):
  """Embeds the input data and runs Evoformer.

  Produces the MSA, single and pair representations.
  Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18
  """

  def __init__(self, config, global_config, name='evoformer'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, batch, is_training, safe_key=None):

    c = self.config
    gc = self.global_config

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    # Embed clustered MSA.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5
    # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder"
    preprocess_1d = common_modules.Linear(
        c.msa_channel, name='preprocess_1d')(batch['target_feat'])

    preprocess_msa = common_modules.Linear(
        c.msa_channel, name='preprocess_msa')(batch['msa_feat'])

    msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa

    left_single = common_modules.Linear(
        c.pair_channel, name='left_single')(batch['target_feat'])
    right_single = common_modules.Linear(
        c.pair_channel, name='right_single')(batch['target_feat'])
    pair_activations = left_single[:, None] + right_single[None]
    mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]

    # Inject previous outputs for recycling.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6
    # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder"
    if c.recycle_pos:
      prev_pseudo_beta = pseudo_beta_fn(
          batch['aatype'], batch['prev_pos'], None)
      dgram = dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos)
      pair_activations += common_modules.Linear(
          c.pair_channel, name='prev_pos_linear')(dgram)

    if c.recycle_features:
      if 'prev_msa_first_row' in batch:
        prev_msa_first_row = hk.LayerNorm(
            [-1], True, True,
            name='prev_msa_first_row_norm')(batch['prev_msa_first_row'])
        msa_activations = msa_activations.at[0].add(prev_msa_first_row)

      if 'prev_pair' in batch:
        pair_activations += hk.LayerNorm(
            [-1], True, True, name='prev_pair_norm')(batch['prev_pair'])

    # Relative position encoding.
    # Jumper et al. (2021) Suppl. Alg. 4 "relpos"
    # Jumper et al. (2021) Suppl. Alg. 5 "one_hot"
    if c.max_relative_feature:
      # Add one-hot-encoded clipped residue distances to the pair activations.
      pos = batch['residue_index']
      offset = pos[:, None] - pos[None, :]
      rel_pos = jax.nn.one_hot(
          jnp.clip(
              offset + c.max_relative_feature,
              a_min=0,
              a_max=2 * c.max_relative_feature),
          2 * c.max_relative_feature + 1)
      pair_activations += common_modules.Linear(
          c.pair_channel, name='pair_activiations')(rel_pos)

    # Embed templates into the pair activations.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13
    if c.template.enabled:
      template_batch = {k: batch[k] for k in batch
                        if k.startswith('template_')}
      template_pair_representation = TemplateEmbedding(c.template, gc)(
          pair_activations,
          template_batch,
          mask_2d,
          is_training=is_training)

      pair_activations += template_pair_representation

    # Embed extra MSA features.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16
    extra_msa_feat = create_extra_msa_feature(batch)
    extra_msa_activations = common_modules.Linear(
        c.extra_msa_channel, name='extra_msa_activations')(extra_msa_feat)

    # Extra MSA Stack.
    # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack"
    extra_msa_stack_input = {
        'msa': extra_msa_activations,
        'pair': pair_activations,
    }

    extra_msa_stack_iteration = EvoformerIteration(
        c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')

    def extra_msa_stack_fn(x):
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      extra_evoformer_output = extra_msa_stack_iteration(
          activations=act,
          masks={
              'msa': batch['extra_msa_mask'],
              'pair': mask_2d
          },
          is_training=is_training,
          safe_key=safe_subkey)
      return (extra_evoformer_output, safe_key)

    if gc.use_remat:
      extra_msa_stack_fn = hk.remat(extra_msa_stack_fn)

    extra_msa_stack = layer_stack.layer_stack(
        c.extra_msa_stack_num_block)(extra_msa_stack_fn)
    extra_msa_output, safe_key = extra_msa_stack(
        (extra_msa_stack_input, safe_key))

    pair_activations = extra_msa_output['pair']

    evoformer_input = {
        'msa': msa_activations,
        'pair': pair_activations,
    }

    evoformer_masks = {'msa': batch['msa_mask'], 'pair': mask_2d}

    # Append num_templ rows to msa_activations with template embeddings.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8
    if c.template.enabled and c.template.embed_torsion_angles:
      num_templ, num_res = batch['template_aatype'].shape

      # Embed the templates aatypes.
      aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)

      # Embed the templates aatype, torsion angles and masks.
      # Shape (templates, residues, msa_channels)
      ret = all_atom.atom37_to_torsion_angles(
          aatype=batch['template_aatype'],
          all_atom_pos=batch['template_all_atom_positions'],
          all_atom_mask=batch['template_all_atom_masks'],
          # Ensure consistent behaviour during testing:
          placeholder_for_undefined=not gc.zero_init)

      template_features = jnp.concatenate([
          aatype_one_hot,
          jnp.reshape(ret['torsion_angles_sin_cos'],
                      [num_templ, num_res, 14]),
          jnp.reshape(ret['alt_torsion_angles_sin_cos'],
                      [num_templ, num_res, 14]),
          ret['torsion_angles_mask']], axis=-1)

      template_activations = common_modules.Linear(
          c.msa_channel,
          initializer='relu',
          name='template_single_embedding')(template_features)
      template_activations = jax.nn.relu(template_activations)
      template_activations = common_modules.Linear(
          c.msa_channel,
          initializer='relu',
          name='template_projection')(template_activations)

      # Concatenate the templates to the msa.
      evoformer_input['msa'] = jnp.concatenate(
          [evoformer_input['msa'], template_activations], axis=0)
      # Concatenate templates masks to the msa masks.
      # Use mask from the psi angle, as it only depends on the backbone atoms
      # from a single residue.
      torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2]
      torsion_angle_mask = torsion_angle_mask.astype(
          evoformer_masks['msa'].dtype)
      evoformer_masks['msa'] = jnp.concatenate(
          [evoformer_masks['msa'], torsion_angle_mask], axis=0)

    # Main trunk of the network
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18
    evoformer_iteration = EvoformerIteration(
        c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')

    def evoformer_fn(x):
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      evoformer_output = evoformer_iteration(
          activations=act,
          masks=evoformer_masks,
          is_training=is_training,
          safe_key=safe_subkey)
      return (evoformer_output, safe_key)

    if gc.use_remat:
      evoformer_fn = hk.remat(evoformer_fn)

    evoformer_stack = layer_stack.layer_stack(
        c.evoformer_num_block)(evoformer_fn)
    evoformer_output, safe_key = evoformer_stack(
        (evoformer_input, safe_key))

    msa_activations = evoformer_output['msa']
    pair_activations = evoformer_output['pair']

    single_activations = common_modules.Linear(
        c.seq_channel, name='single_activations')(msa_activations[0])

    num_sequences = batch['msa_feat'].shape[0]
    output = {
        'single': single_activations,
        'pair': pair_activations,
        # Crop away template rows such that they are not used in MaskedMsaHead.
        'msa': msa_activations[:num_sequences, :, :],
        'msa_first_row': msa_activations[0],
    }

    return output
class SingleTemplateEmbedding(hk.Module):
  """Embeds a single template.

  Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9+11
  """

  def __init__(self, config, global_config, name='single_template_embedding'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, query_embedding, batch, mask_2d, is_training):
    """Build the single template embedding.

    Arguments:
      query_embedding: Query pair representation, shape [N_res, N_res, c_z].
      batch: A batch of template features (note the template dimension has been
        stripped out as this module only runs over a single template).
      mask_2d: Padding mask (Note: this doesn't care if a template exists,
        unlike the template_pseudo_beta_mask).
      is_training: Whether the module is in training mode.

    Returns:
      A template embedding [N_res, N_res, c_z].
    """
    assert mask_2d.dtype == query_embedding.dtype
    dtype = query_embedding.dtype
    num_res = batch['template_aatype'].shape[0]
    num_channels = (self.config.template_pair_stack
                    .triangle_attention_ending_node.value_dim)
    template_mask = batch['template_pseudo_beta_mask']
    template_mask_2d = template_mask[:, None] * template_mask[None, :]
    template_mask_2d = template_mask_2d.astype(dtype)

    template_dgram = dgram_from_positions(batch['template_pseudo_beta'],
                                          **self.config.dgram_features)
    template_dgram = template_dgram.astype(dtype)

    to_concat = [template_dgram, template_mask_2d[:, :, None]]

    aatype = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1, dtype=dtype)

    to_concat.append(jnp.tile(aatype[None, :, :], [num_res, 1, 1]))
    to_concat.append(jnp.tile(aatype[:, None, :], [1, num_res, 1]))

    n, ca, c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')]
    rot, trans = quat_affine.make_transform_from_reference(
        n_xyz=batch['template_all_atom_positions'][:, n],
        ca_xyz=batch['template_all_atom_positions'][:, ca],
        c_xyz=batch['template_all_atom_positions'][:, c])
    affines = quat_affine.QuatAffine(
        quaternion=quat_affine.rot_to_quat(rot, unstack_inputs=True),
        translation=trans,
        rotation=rot,
        unstack_inputs=True)
    points = [jnp.expand_dims(x, axis=-2) for x in affines.translation]
    affine_vec = affines.invert_point(points, extra_dims=1)
    inv_distance_scalar = jax.lax.rsqrt(
        1e-6 + sum([jnp.square(x) for x in affine_vec]))

    # Backbone affine mask: whether the residue has C, CA, N
    # (the template mask defined above only considers pseudo CB).
    template_mask = (
        batch['template_all_atom_masks'][..., n] *
        batch['template_all_atom_masks'][..., ca] *
        batch['template_all_atom_masks'][..., c])
    template_mask_2d = template_mask[:, None] * template_mask[None, :]

    inv_distance_scalar *= template_mask_2d.astype(inv_distance_scalar.dtype)

    unit_vector = [(x * inv_distance_scalar)[..., None] for x in affine_vec]

    unit_vector = [x.astype(dtype) for x in unit_vector]
    template_mask_2d = template_mask_2d.astype(dtype)
    if not self.config.use_template_unit_vector:
      unit_vector = [jnp.zeros_like(x) for x in unit_vector]
    to_concat.extend(unit_vector)

    to_concat.append(template_mask_2d[..., None])

    act = jnp.concatenate(to_concat, axis=-1)

    # Mask out non-template regions so we don't get arbitrary values in the
    # distogram for these regions.
    act *= template_mask_2d[..., None]

    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 9
    act = common_modules.Linear(
        num_channels,
        initializer='relu',
        name='embedding2d')(act)

    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 11
    act = TemplatePairStack(
        self.config.template_pair_stack, self.global_config)(
            act, mask_2d, is_training)

    act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act)
    return act
class TemplateEmbedding(hk.Module):
  """Embeds a set of templates.

  Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12
  Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention"
  """

  def __init__(self, config, global_config, name='template_embedding'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, query_embedding, template_batch, mask_2d, is_training):
    """Build TemplateEmbedding module.

    Arguments:
      query_embedding: Query pair representation, shape [N_res, N_res, c_z].
      template_batch: A batch of template features.
      mask_2d: Padding mask (Note: this doesn't care if a template exists,
        unlike the template_pseudo_beta_mask).
      is_training: Whether the module is in training mode.

    Returns:
      A template embedding [N_res, N_res, c_z].
    """
    num_templates = template_batch['template_mask'].shape[0]
    num_channels = (self.config.template_pair_stack
                    .triangle_attention_ending_node.value_dim)
    num_res = query_embedding.shape[0]

    dtype = query_embedding.dtype
    template_mask = template_batch['template_mask']
    template_mask = template_mask.astype(dtype)

    query_num_channels = query_embedding.shape[-1]

    # Make sure the weights are shared across templates by constructing the
    # embedder here.
    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12
    template_embedder = SingleTemplateEmbedding(self.config, self.global_config)

    def map_fn(batch):
      return template_embedder(query_embedding, batch, mask_2d, is_training)

    template_pair_representation = mapping.sharded_map(map_fn, in_axes=0)(
        template_batch)

    # Cross attend from the query to the templates along the residue
    # dimension by flattening everything else into the batch dimension.
    # Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention"
    flat_query = jnp.reshape(query_embedding,
                             [num_res * num_res, 1, query_num_channels])
    flat_templates = jnp.reshape(
        jnp.transpose(template_pair_representation, [1, 2, 0, 3]),
        [num_res * num_res, num_templates, num_channels])

    bias = (1e9 * (template_mask[None, None, None, :] - 1.))

    template_pointwise_attention_module = Attention(
        self.config.attention, self.global_config, query_num_channels)
    nonbatched_args = [bias]
    batched_args = [flat_query, flat_templates]

    embedding = mapping.inference_subbatch(
        template_pointwise_attention_module,
        self.config.subbatch_size,
        batched_args=batched_args,
        nonbatched_args=nonbatched_args,
        low_memory=not is_training)
    embedding = jnp.reshape(embedding,
                            [num_res, num_res, query_num_channels])

    # No gradients if no templates.
    embedding *= (jnp.sum(template_mask) > 0.).astype(embedding.dtype)

    return embedding
alphafold/model/modules_multimer.py deleted 100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core modules, which have been refactored in AlphaFold-Multimer.
The main difference is that MSA sampling pipeline is moved inside the JAX model
for easier implementation of recycling and ensembling.
Lower-level modules up to EvoformerIteration are reused from modules.py.
"""
import functools
from typing import Sequence

from alphafold.common import residue_constants
from alphafold.model import all_atom_multimer
from alphafold.model import common_modules
from alphafold.model import folding_multimer
from alphafold.model import geometry
from alphafold.model import layer_stack
from alphafold.model import modules
from alphafold.model import prng
from alphafold.model import utils
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


def reduce_fn(x, mode):
  if mode == 'none' or mode is None:
    return jnp.asarray(x)
  elif mode == 'sum':
    return jnp.asarray(x).sum()
  elif mode == 'mean':
    return jnp.mean(jnp.asarray(x))
  else:
    raise ValueError('Unsupported reduction option.')


def gumbel_noise(key: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray:
  """Generate Gumbel Noise of given Shape.

  This generates samples from Gumbel(0, 1).

  Args:
    key: Jax random number key.
    shape: Shape of noise to return.

  Returns:
    Gumbel noise of given shape.
  """
  epsilon = 1e-6
  uniform = utils.padding_consistent_rng(jax.random.uniform)
  uniform_noise = uniform(
      key, shape=shape, dtype=jnp.float32, minval=0., maxval=1.)
  gumbel = -jnp.log(-jnp.log(uniform_noise + epsilon) + epsilon)
  return gumbel


def gumbel_max_sample(key: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
  """Samples from a probability distribution given by 'logits'.

  This uses the Gumbel-max trick to implement the sampling in an efficient
  manner.

  Args:
    key: prng key.
    logits: Logarithm of probabilities to sample from, probabilities can be
      unnormalized.

  Returns:
    Sample from logprobs in one-hot form.
  """
  z = gumbel_noise(key, logits.shape)
  return jax.nn.one_hot(
      jnp.argmax(logits + z, axis=-1),
      logits.shape[-1],
      dtype=logits.dtype)


def gumbel_argsort_sample_idx(key: jnp.ndarray,
                              logits: jnp.ndarray) -> jnp.ndarray:
  """Samples without replacement from a distribution given by 'logits'.

  This uses the Gumbel trick to implement the sampling in an efficient manner.
  For a distribution over k items this samples k times without replacement, so
  this is effectively sampling a random permutation with probabilities over the
  permutations derived from the logprobs.

  Args:
    key: prng key.
    logits: Logarithm of probabilities to sample from, probabilities can be
      unnormalized.

  Returns:
    Sample from logprobs in one-hot form.
  """
  z = gumbel_noise(key, logits.shape)
  # This construction is equivalent to jnp.argsort, but using a non-stable
  # sort, since stable sorts aren't supported by jax2tf.
  axis = len(logits.shape) - 1
  iota = jax.lax.broadcasted_iota(jnp.int64, logits.shape, axis)
  _, perm = jax.lax.sort_key_val(
      logits + z, iota, dimension=-1, is_stable=False)
  return perm[::-1]


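# Editorial note, not part of the original file: the functions above rely on
# the Gumbel-max trick, i.e. argmax(logits + Gumbel noise) draws a category
# with softmax(logits) probability. A minimal standalone sketch (hypothetical
# helper name, independent of the module above) that checks this empirically:
import jax
import jax.numpy as jnp

def _gumbel_max_sample_sketch(key, logits):
  # Gumbel(0, 1) noise via the inverse transform of uniform samples.
  u = jax.random.uniform(key, logits.shape, minval=1e-6, maxval=1. - 1e-6)
  g = -jnp.log(-jnp.log(u))
  return jnp.argmax(logits + g, axis=-1)

_logits = jnp.log(jnp.array([0.1, 0.3, 0.6]))
_keys = jax.random.split(jax.random.PRNGKey(0), 10000)
_samples = jax.vmap(lambda k: _gumbel_max_sample_sketch(k, _logits))(_keys)
_freqs = jnp.bincount(_samples, length=3) / _samples.shape[0]
print(_freqs, jax.nn.softmax(_logits))  # frequencies are close to [0.1, 0.3, 0.6]

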
def make_masked_msa(batch, key, config, epsilon=1e-6):
  """Create data for BERT on raw MSA."""
  # Add a random amino acid uniformly.
  random_aa = jnp.array([0.05] * 20 + [0., 0.], dtype=jnp.float32)

  categorical_probs = (
      config.uniform_prob * random_aa +
      config.profile_prob * batch['msa_profile'] +
      config.same_prob * jax.nn.one_hot(batch['msa'], 22))

  # Put all remaining probability on [MASK] which is a new column.
  pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
  pad_shapes[-1][1] = 1
  mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
  assert mask_prob >= 0.
  categorical_probs = jnp.pad(
      categorical_probs, pad_shapes, constant_values=mask_prob)
  sh = batch['msa'].shape
  key, mask_subkey, gumbel_subkey = key.split(3)
  uniform = utils.padding_consistent_rng(jax.random.uniform)
  mask_position = uniform(mask_subkey.get(), sh) < config.replace_fraction
  mask_position *= batch['msa_mask']

  logits = jnp.log(categorical_probs + epsilon)
  bert_msa = gumbel_max_sample(gumbel_subkey.get(), logits)
  bert_msa = jnp.where(mask_position,
                       jnp.argmax(bert_msa, axis=-1), batch['msa'])
  bert_msa *= batch['msa_mask']

  # Mix real and masked MSA.
  if 'bert_mask' in batch:
    batch['bert_mask'] *= mask_position.astype(jnp.float32)
  else:
    batch['bert_mask'] = mask_position.astype(jnp.float32)
  batch['true_msa'] = batch['msa']
  batch['msa'] = bert_msa

  return batch


def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
  """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""

  # Determine how much weight we assign to each agreement. In theory, we could
  # use a full blosum matrix here, but right now let's just down-weight gap
  # agreement because it could be spurious.
  # Never put weight on agreeing on BERT mask.
  weights = jnp.array(
      [1.] * 21 + [gap_agreement_weight] + [0.], dtype=jnp.float32)

  msa_mask = batch['msa_mask']
  msa_one_hot = jax.nn.one_hot(batch['msa'], 23)

  extra_mask = batch['extra_msa_mask']
  extra_one_hot = jax.nn.one_hot(batch['extra_msa'], 23)

  msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
  extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot

  agreement = jnp.einsum('mrc, nrc->nm', extra_one_hot_masked,
                         weights * msa_one_hot_masked)

  cluster_assignment = jax.nn.softmax(1e3 * agreement, axis=0)
  cluster_assignment *= jnp.einsum('mr, nr->mn', msa_mask, extra_mask)

  cluster_count = jnp.sum(cluster_assignment, axis=-1)
  cluster_count += 1.  # We always include the sequence itself.

  msa_sum = jnp.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
  msa_sum += msa_one_hot_masked

  cluster_profile = msa_sum / cluster_count[:, None, None]

  extra_deletion_matrix = batch['extra_deletion_matrix']
  deletion_matrix = batch['deletion_matrix']

  del_sum = jnp.einsum('nm, mc->nc', cluster_assignment,
                       extra_mask * extra_deletion_matrix)
  del_sum += deletion_matrix  # Original sequence.
  cluster_deletion_mean = del_sum / cluster_count[:, None]

  return cluster_profile, cluster_deletion_mean


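# Editorial note, not part of the original file: the 'mrc, nrc->nm' einsum
# above scores agreement between extra sequences (m) and sampled cluster
# centres (n) by counting weighted one-hot matches over residues (r) and
# residue classes (c). A minimal standalone sketch on a toy MSA:
import jax
import jax.numpy as jnp

_extra = jax.nn.one_hot(jnp.array([[0, 1], [2, 1]]), 23)   # [m=2, r=2, c=23]
_centres = jax.nn.one_hot(jnp.array([[0, 1]]), 23)         # [n=1, r=2, c=23]
_agreement = jnp.einsum('mrc, nrc->nm', _extra, _centres)
print(_agreement)  # [[2., 1.]]: the first extra row matches at both residues

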
def create_msa_feat(batch):
  """Create and concatenate MSA features."""
  msa_1hot = jax.nn.one_hot(batch['msa'], 23)
  deletion_matrix = batch['deletion_matrix']
  has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
  deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]

  deletion_mean_value = (jnp.arctan(batch['cluster_deletion_mean'] / 3.) *
                         (2. / jnp.pi))[..., None]

  msa_feat = [
      msa_1hot,
      has_deletion,
      deletion_value,
      batch['cluster_profile'],
      deletion_mean_value
  ]

  return jnp.concatenate(msa_feat, axis=-1)


def create_extra_msa_feature(batch, num_extra_msa):
  """Expand extra_msa into 1hot and concat with other extra msa features.

  We do this as late as possible as the one_hot extra msa can be very large.

  Args:
    batch: a dictionary with the following keys:
     * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
       centre. Note - This isn't one-hotted.
     * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
       position.
    num_extra_msa: Number of extra msa to use.

  Returns:
    Concatenated tensor of extra MSA features.
  """
  # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
  extra_msa = batch['extra_msa'][:num_extra_msa]
  deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa]
  msa_1hot = jax.nn.one_hot(extra_msa, 23)
  has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
  deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]
  extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa]
  return jnp.concatenate([msa_1hot, has_deletion, deletion_value],
                         axis=-1), extra_msa_mask


def sample_msa(key, batch, max_seq):
  """Sample MSA randomly, remaining sequences are stored as `extra_*`.

  Args:
    key: safe key for random number generation.
    batch: batch to sample msa from.
    max_seq: number of sequences to sample.
  Returns:
    Protein with sampled msa.
  """
  # Sample uniformly among sequences with at least one non-masked position.
  logits = (jnp.clip(jnp.sum(batch['msa_mask'], axis=-1), 0., 1.) - 1.) * 1e6
  # The cluster_bias_mask can be used to preserve the first row (target
  # sequence) for each chain, for example.
  if 'cluster_bias_mask' not in batch:
    cluster_bias_mask = jnp.pad(
        jnp.zeros(batch['msa'].shape[0] - 1), (1, 0), constant_values=1.)
  else:
    cluster_bias_mask = batch['cluster_bias_mask']

  logits += cluster_bias_mask * 1e6
  index_order = gumbel_argsort_sample_idx(key.get(), logits)
  sel_idx = index_order[:max_seq]
  extra_idx = index_order[max_seq:]

  for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
    if k in batch:
      batch['extra_' + k] = batch[k][extra_idx]
      batch[k] = batch[k][sel_idx]

  return batch


def make_msa_profile(batch):
  """Compute the MSA profile."""

  # Compute the profile for every residue (over all MSA sequences).
  return utils.mask_mean(
      batch['msa_mask'][:, :, None], jax.nn.one_hot(batch['msa'], 22), axis=0)


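# Editorial note, not part of the original file: the profile above is a masked
# mean of one-hot MSA rows per residue. A minimal standalone sketch of the
# same computation without the utils.mask_mean helper:
import jax
import jax.numpy as jnp

_msa = jnp.array([[0, 1], [0, 2]])           # [num_seq, num_res] residue ids
_msa_mask = jnp.array([[1., 1.], [1., 0.]])  # second row masked at residue 1
_one_hot = jax.nn.one_hot(_msa, 22)
_profile = jnp.sum(_msa_mask[:, :, None] * _one_hot, axis=0) / (
    jnp.sum(_msa_mask[:, :, None], axis=0) + 1e-10)
print(_profile[0, 0])  # 1.0: both unmasked rows agree on id 0 at position 0

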
class AlphaFoldIteration(hk.Module):
  """A single recycling iteration of AlphaFold architecture.

  Computes ensembled (averaged) representations from the provided features.
  These representations are then passed to the various heads
  that have been requested by the configuration file.
  """

  def __init__(self, config, global_config, name='alphafold_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               batch,
               is_training,
               return_representations=False,
               safe_key=None):

    if is_training:
      num_ensemble = np.asarray(self.config.num_ensemble_train)
    else:
      num_ensemble = np.asarray(self.config.num_ensemble_eval)

    # Compute representations for each MSA sample and average.
    embedding_module = EmbeddingsAndEvoformer(
        self.config.embeddings_and_evoformer, self.global_config)
    repr_shape = hk.eval_shape(
        lambda: embedding_module(batch, is_training))
    representations = {
        k: jnp.zeros(v.shape, v.dtype) for (k, v) in repr_shape.items()
    }

    def ensemble_body(x, unused_y):
      """Add into representations ensemble."""
      del unused_y
      representations, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      representations_update = embedding_module(
          batch, is_training, safe_key=safe_subkey)

      for k in representations:
        if k not in {'msa', 'true_msa', 'bert_mask'}:
          representations[k] += representations_update[k] * (
              1. / num_ensemble).astype(representations[k].dtype)
        else:
          representations[k] = representations_update[k]

      return (representations, safe_key), None

    (representations, _), _ = hk.scan(
        ensemble_body, (representations, safe_key), None, length=num_ensemble)

    self.representations = representations
    self.batch = batch
    self.heads = {}
    for head_name, head_config in sorted(self.config.heads.items()):
      if not head_config.weight:
        continue  # Do not instantiate zero-weight heads.

      head_factory = {
          'masked_msa': modules.MaskedMsaHead,
          'distogram': modules.DistogramHead,
          'structure_module': folding_multimer.StructureModule,
          'predicted_aligned_error': modules.PredictedAlignedErrorHead,
          'predicted_lddt': modules.PredictedLDDTHead,
          'experimentally_resolved': modules.ExperimentallyResolvedHead,
      }[head_name]
      self.heads[head_name] = (head_config,
                               head_factory(head_config, self.global_config))

    structure_module_output = None
    if 'entity_id' in batch and 'all_atom_positions' in batch:
      _, fold_module = self.heads['structure_module']
      structure_module_output = fold_module(representations, batch, is_training)

    ret = {}
    ret['representations'] = representations

    for name, (head_config, module) in self.heads.items():
      if name == 'structure_module' and structure_module_output is not None:
        ret[name] = structure_module_output
        representations['structure_module'] = structure_module_output.pop('act')
      # Skip confidence heads until StructureModule is executed.
      elif name in {'predicted_lddt',
                    'predicted_aligned_error',
                    'experimentally_resolved'}:
        continue
      else:
        ret[name] = module(representations, batch, is_training)

    # Add confidence heads after StructureModule is executed.
    if self.config.heads.get('predicted_lddt.weight', 0.0):
      name = 'predicted_lddt'
      head_config, module = self.heads[name]
      ret[name] = module(representations, batch, is_training)

    if self.config.heads.experimentally_resolved.weight:
      name = 'experimentally_resolved'
      head_config, module = self.heads[name]
      ret[name] = module(representations, batch, is_training)

    if self.config.heads.get('predicted_aligned_error.weight', 0.0):
      name = 'predicted_aligned_error'
      head_config, module = self.heads[name]
      ret[name] = module(representations, batch, is_training)
      # Will be used for ipTM computation.
      ret[name]['asym_id'] = batch['asym_id']

    return ret


class AlphaFold(hk.Module):
  """AlphaFold-Multimer model with recycling."""

  def __init__(self, config, name='alphafold'):
    super().__init__(name=name)
    self.config = config
    self.global_config = config.global_config

  def __call__(
      self,
      batch,
      is_training,
      return_representations=False,
      safe_key=None):

    c = self.config
    impl = AlphaFoldIteration(c, self.global_config)

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())
    elif isinstance(safe_key, jnp.ndarray):
      safe_key = prng.SafeKey(safe_key)

    assert isinstance(batch, dict)
    num_res = batch['aatype'].shape[0]

    def get_prev(ret):
      new_prev = {
          'prev_pos': ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
      return jax.tree_map(jax.lax.stop_gradient, new_prev)

    def apply_network(prev, safe_key):
      recycled_batch = {**batch, **prev}
      return impl(
          batch=recycled_batch,
          is_training=is_training,
          safe_key=safe_key)

    prev = {}
    emb_config = self.config.embeddings_and_evoformer
    if emb_config.recycle_pos:
      prev['prev_pos'] = jnp.zeros(
          [num_res, residue_constants.atom_type_num, 3])
    if emb_config.recycle_features:
      prev['prev_msa_first_row'] = jnp.zeros(
          [num_res, emb_config.msa_channel])
      prev['prev_pair'] = jnp.zeros(
          [num_res, num_res, emb_config.pair_channel])

    if self.config.num_recycle:
      if 'num_iter_recycling' in batch:
        # Training time: num_iter_recycling is in batch.
        # Value for each ensemble batch is the same, so arbitrarily taking 0-th.
        num_iter = batch['num_iter_recycling'][0]

        # Add insurance that even when ensembling, we will not run more
        # recyclings than the model is configured to run.
        num_iter = jnp.minimum(num_iter, c.num_recycle)
      else:
        # Eval mode or tests: use the maximum number of iterations.
        num_iter = c.num_recycle

      def recycle_body(i, x):
        del i
        prev, safe_key = x
        safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate()  # pylint: disable=line-too-long
        ret = apply_network(prev=prev, safe_key=safe_key2)
        return get_prev(ret), safe_key1

      prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))

    # Run extra iteration.
    ret = apply_network(prev=prev, safe_key=safe_key)

    if not return_representations:
      del ret['representations']
    return ret


class EmbeddingsAndEvoformer(hk.Module):
  """Embeds the input data and runs Evoformer.

  Produces the MSA, single and pair representations.
  """

  def __init__(self, config, global_config, name='evoformer'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def _relative_encoding(self, batch):
    """Add relative position encodings.

    For position (i, j), the value is (i-j) clipped to [-k, k] and one-hotted.

    When not using 'use_chain_relative' the residue indices are used as is, e.g.
    for heteromers relative positions will be computed using the positions in
    the corresponding chains.

    When using 'use_chain_relative' we add an extra bin that denotes
    'different chain'. Furthermore we also provide the relative chain index
    (i.e. sym_id) clipped and one-hotted to the network. And an extra feature
    which denotes whether they belong to the same chain type, i.e. it's 0 if
    they are in different heteromer chains and 1 otherwise.

    Args:
      batch: batch.
    Returns:
      Feature embedding using the features as described before.
    """
    c = self.config
    rel_feats = []
    pos = batch['residue_index']
    asym_id = batch['asym_id']
    asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
    offset = pos[:, None] - pos[None, :]

    clipped_offset = jnp.clip(
        offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)

    if c.use_chain_relative:
      final_offset = jnp.where(asym_id_same, clipped_offset,
                               (2 * c.max_relative_idx + 1) *
                               jnp.ones_like(clipped_offset))

      rel_pos = jax.nn.one_hot(final_offset, 2 * c.max_relative_idx + 2)
      rel_feats.append(rel_pos)

      entity_id = batch['entity_id']
      entity_id_same = jnp.equal(entity_id[:, None], entity_id[None, :])
      rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None])

      sym_id = batch['sym_id']
      rel_sym_id = sym_id[:, None] - sym_id[None, :]

      max_rel_chain = c.max_relative_chain

      clipped_rel_chain = jnp.clip(
          rel_sym_id + max_rel_chain, a_min=0, a_max=2 * max_rel_chain)

      final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain,
                                  (2 * max_rel_chain + 1) *
                                  jnp.ones_like(clipped_rel_chain))
      rel_chain = jax.nn.one_hot(final_rel_chain, 2 * c.max_relative_chain + 2)

      rel_feats.append(rel_chain)
    else:
      rel_pos = jax.nn.one_hot(clipped_offset, 2 * c.max_relative_idx + 1)
      rel_feats.append(rel_pos)

    rel_feat = jnp.concatenate(rel_feats, axis=-1)

    return common_modules.Linear(
        c.pair_channel, name='position_activations')(rel_feat)

  def __call__(self, batch, is_training, safe_key=None):

    c = self.config
    gc = self.global_config

    batch = dict(batch)

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    output = {}

    batch['msa_profile'] = make_msa_profile(batch)

    target_feat = jax.nn.one_hot(batch['aatype'], 21)

    preprocess_1d = common_modules.Linear(
        c.msa_channel, name='preprocess_1d')(target_feat)

    safe_key, sample_key, mask_key = safe_key.split(3)
    batch = sample_msa(sample_key, batch, c.num_msa)
    batch = make_masked_msa(batch, mask_key, c.masked_msa)

    (batch['cluster_profile'],
     batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)

    msa_feat = create_msa_feat(batch)

    preprocess_msa = common_modules.Linear(
        c.msa_channel, name='preprocess_msa')(msa_feat)

    msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa

    left_single = common_modules.Linear(
        c.pair_channel, name='left_single')(target_feat)
    right_single = common_modules.Linear(
        c.pair_channel, name='right_single')(target_feat)
    pair_activations = left_single[:, None] + right_single[None]
    mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
    mask_2d = mask_2d.astype(jnp.float32)

    if c.recycle_pos:
      prev_pseudo_beta = modules.pseudo_beta_fn(
          batch['aatype'], batch['prev_pos'], None)

      dgram = modules.dgram_from_positions(
          prev_pseudo_beta, **self.config.prev_pos)
      pair_activations += common_modules.Linear(
          c.pair_channel, name='prev_pos_linear')(dgram)

    if c.recycle_features:
      prev_msa_first_row = hk.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,
          name='prev_msa_first_row_norm')(batch['prev_msa_first_row'])
      msa_activations = msa_activations.at[0].add(prev_msa_first_row)

      pair_activations += hk.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,
          name='prev_pair_norm')(batch['prev_pair'])

    if c.max_relative_idx:
      pair_activations += self._relative_encoding(batch)

    if c.template.enabled:
      template_module = TemplateEmbedding(c.template, gc)
      template_batch = {
          'template_aatype': batch['template_aatype'],
          'template_all_atom_positions': batch['template_all_atom_positions'],
          'template_all_atom_mask': batch['template_all_atom_mask']
      }
      # Construct a mask such that only intra-chain template features are
      # computed, since all templates are for each chain individually.
      multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
      safe_key, safe_subkey = safe_key.split()
      template_act = template_module(
          query_embedding=pair_activations,
          template_batch=template_batch,
          padding_mask_2d=mask_2d,
          multichain_mask_2d=multichain_mask,
          is_training=is_training,
          safe_key=safe_subkey)
      pair_activations += template_act

    # Extra MSA stack.
    (extra_msa_feat,
     extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
    extra_msa_activations = common_modules.Linear(
        c.extra_msa_channel, name='extra_msa_activations')(extra_msa_feat)
    extra_msa_mask = extra_msa_mask.astype(jnp.float32)

    extra_evoformer_input = {
        'msa': extra_msa_activations,
        'pair': pair_activations,
    }
    extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}

    extra_evoformer_iteration = modules.EvoformerIteration(
        c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')

    def extra_evoformer_fn(x):
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      extra_evoformer_output = extra_evoformer_iteration(
          activations=act,
          masks=extra_masks,
          is_training=is_training,
          safe_key=safe_subkey)
      return (extra_evoformer_output, safe_key)

    if gc.use_remat:
      extra_evoformer_fn = hk.remat(extra_evoformer_fn)

    safe_key, safe_subkey = safe_key.split()
    extra_evoformer_stack = layer_stack.layer_stack(
        c.extra_msa_stack_num_block)(extra_evoformer_fn)
    extra_evoformer_output, safe_key = extra_evoformer_stack(
        (extra_evoformer_input, safe_subkey))

    pair_activations = extra_evoformer_output['pair']

    # Get the size of the MSA before potentially adding templates, so we
    # can crop out the templates later.
    num_msa_sequences = msa_activations.shape[0]
    evoformer_input = {
        'msa': msa_activations,
        'pair': pair_activations,
    }
    evoformer_masks = {
        'msa': batch['msa_mask'].astype(jnp.float32),
        'pair': mask_2d
    }
    if c.template.enabled:
      template_features, template_masks = (
          template_embedding_1d(batch=batch, num_channel=c.msa_channel))

      evoformer_input['msa'] = jnp.concatenate(
          [evoformer_input['msa'], template_features], axis=0)
      evoformer_masks['msa'] = jnp.concatenate(
          [evoformer_masks['msa'], template_masks], axis=0)

    evoformer_iteration = modules.EvoformerIteration(
        c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')

    def evoformer_fn(x):
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      evoformer_output = evoformer_iteration(
          activations=act,
          masks=evoformer_masks,
          is_training=is_training,
          safe_key=safe_subkey)
      return (evoformer_output, safe_key)

    if gc.use_remat:
      evoformer_fn = hk.remat(evoformer_fn)

    safe_key, safe_subkey = safe_key.split()
    evoformer_stack = layer_stack.layer_stack(
        c.evoformer_num_block)(evoformer_fn)

    def run_evoformer(evoformer_input):
      evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
      return evoformer_output

    evoformer_output = run_evoformer(evoformer_input)

    msa_activations = evoformer_output['msa']
    pair_activations = evoformer_output['pair']

    single_activations = common_modules.Linear(
        c.seq_channel, name='single_activations')(msa_activations[0])

    output.update({
        'single': single_activations,
        'pair': pair_activations,
        # Crop away template rows such that they are not used in MaskedMsaHead.
        'msa': msa_activations[:num_msa_sequences, :, :],
        'msa_first_row': msa_activations[0],
    })

    return output


class TemplateEmbedding(hk.Module):
  """Embed a set of templates."""

  def __init__(self, config, global_config, name='template_embedding'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, query_embedding, template_batch, padding_mask_2d,
               multichain_mask_2d, is_training, safe_key=None):
    """Generate an embedding for a set of templates.

    Args:
      query_embedding: [num_res, num_res, num_channel] a query tensor that will
        be used to attend over the templates to remove the num_templates
        dimension.
      template_batch: A dictionary containing:
        `template_aatype`: [num_templates, num_res] aatype for each template.
        `template_all_atom_positions`: [num_templates, num_res, 37, 3] atom
          positions for all templates.
        `template_all_atom_mask`: [num_templates, num_res, 37] mask for each
          template.
      padding_mask_2d: [num_res, num_res] Pair mask for attention operations.
      multichain_mask_2d: [num_res, num_res] Mask indicating which residue pairs
        are intra-chain, used to mask out residue distance based features
        between chains.
      is_training: bool indicating whether we are running in training mode.
      safe_key: random key generator.

    Returns:
      An embedding of size [num_res, num_res, num_channels].
    """
    c = self.config
    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    num_templates = template_batch['template_aatype'].shape[0]
    num_res, _, query_num_channels = query_embedding.shape

    # Embed each template separately.
    template_embedder = SingleTemplateEmbedding(self.config, self.global_config)

    def partial_template_embedder(template_aatype,
                                  template_all_atom_positions,
                                  template_all_atom_mask,
                                  unsafe_key):
      safe_key = prng.SafeKey(unsafe_key)
      return template_embedder(query_embedding,
                               template_aatype,
                               template_all_atom_positions,
                               template_all_atom_mask,
                               padding_mask_2d,
                               multichain_mask_2d,
                               is_training,
                               safe_key)

    safe_key, unsafe_key = safe_key.split()
    unsafe_keys = jax.random.split(unsafe_key._key, num_templates)

    def scan_fn(carry, x):
      return carry + partial_template_embedder(*x), None

    scan_init = jnp.zeros((num_res, num_res, c.num_channels),
                          dtype=query_embedding.dtype)
    summed_template_embeddings, _ = hk.scan(
        scan_fn, scan_init,
        (template_batch['template_aatype'],
         template_batch['template_all_atom_positions'],
         template_batch['template_all_atom_mask'],
         unsafe_keys))

    embedding = summed_template_embeddings / num_templates
    embedding = jax.nn.relu(embedding)
    embedding = common_modules.Linear(
        query_num_channels, initializer='relu', name='output_linear')(embedding)

    return embedding


class SingleTemplateEmbedding(hk.Module):
  """Embed a single template."""

  def __init__(self, config, global_config, name='single_template_embedding'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, query_embedding, template_aatype,
               template_all_atom_positions, template_all_atom_mask,
               padding_mask_2d, multichain_mask_2d, is_training, safe_key):
    """Build the single template embedding graph.

    Args:
      query_embedding: (num_res, num_res, num_channels) - embedding of the
        query sequence/msa.
      template_aatype: [num_res] aatype for each template.
      template_all_atom_positions: [num_res, 37, 3] atom positions for all
        templates.
      template_all_atom_mask: [num_res, 37] mask for each template.
      padding_mask_2d: Padding mask (Note: this doesn't care if a template
        exists, unlike the template_pseudo_beta_mask).
      multichain_mask_2d: A mask indicating intra-chain residue pairs, used
        to mask out between chain distances/features when templates are for
        single chains.
      is_training: Are we in training mode.
      safe_key: Random key generator.

    Returns:
      A template embedding (num_res, num_res, num_channels).
    """
    gc = self.global_config
    c = self.config
    assert padding_mask_2d.dtype == query_embedding.dtype
    dtype = query_embedding.dtype
    num_channels = self.config.num_channels

    def construct_input(query_embedding, template_aatype,
                        template_all_atom_positions, template_all_atom_mask,
                        multichain_mask_2d):

      # Compute distogram feature for the template.
      template_positions, pseudo_beta_mask = modules.pseudo_beta_fn(
          template_aatype, template_all_atom_positions, template_all_atom_mask)
      pseudo_beta_mask_2d = (pseudo_beta_mask[:, None] *
                             pseudo_beta_mask[None, :])
      pseudo_beta_mask_2d *= multichain_mask_2d
      template_dgram = modules.dgram_from_positions(
          template_positions, **self.config.dgram_features)
      template_dgram *= pseudo_beta_mask_2d[..., None]
      template_dgram = template_dgram.astype(dtype)
      pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype)
      to_concat = [(template_dgram, 1), (pseudo_beta_mask_2d, 0)]

      aatype = jax.nn.one_hot(template_aatype, 22, axis=-1, dtype=dtype)
      to_concat.append((aatype[None, :, :], 1))
      to_concat.append((aatype[:, None, :], 1))

      # Compute a feature representing the normalized vector between each
      # backbone affine - i.e. in each residues local frame, what direction are
      # each of the other residues.
      raw_atom_pos = template_all_atom_positions

      atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
      rigid, backbone_mask = folding_multimer.make_backbone_affine(
          atom_pos, template_all_atom_mask, template_aatype)
      points = rigid.translation
      rigid_vec = rigid[:, None].inverse().apply_to_point(points)
      unit_vector = rigid_vec.normalized()
      unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]

      backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
      backbone_mask_2d *= multichain_mask_2d
      unit_vector = [x * backbone_mask_2d for x in unit_vector]

      # Note that the backbone_mask takes into account C, CA and N (unlike
      # pseudo beta mask which just needs CB) so we add both masks as features.
      to_concat.extend([(x, 0) for x in unit_vector])
      to_concat.append((backbone_mask_2d, 0))

      query_embedding = hk.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,
          name='query_embedding_norm')(query_embedding)
      # Allow the template embedder to see the query embedding. Note this
      # contains the position relative feature, so this is how the network knows
      # which residues are next to each other.
      to_concat.append((query_embedding, 1))

      act = 0

      for i, (x, n_input_dims) in enumerate(to_concat):
        act += common_modules.Linear(
            num_channels,
            num_input_dims=n_input_dims,
            initializer='relu',
            name=f'template_pair_embedding_{i}')(x)
      return act

    act = construct_input(query_embedding, template_aatype,
                          template_all_atom_positions, template_all_atom_mask,
                          multichain_mask_2d)

    template_iteration = TemplateEmbeddingIteration(
        c.template_pair_stack, gc, name='template_embedding_iteration')

    def template_iteration_fn(x):
      act, safe_key = x

      safe_key, safe_subkey = safe_key.split()
      act = template_iteration(
          act=act,
          pair_mask=padding_mask_2d,
          is_training=is_training,
          safe_key=safe_subkey)
      return (act, safe_key)

    if gc.use_remat:
      template_iteration_fn = hk.remat(template_iteration_fn)

    safe_key, safe_subkey = safe_key.split()
    template_stack = layer_stack.layer_stack(
        c.template_pair_stack.num_block)(template_iteration_fn)
    act, safe_key = template_stack((act, safe_subkey))

    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='output_layer_norm')(act)
    return act


class TemplateEmbeddingIteration(hk.Module):
  """Single Iteration of Template Embedding."""

  def __init__(self, config, global_config,
               name='template_embedding_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, pair_mask, is_training=True, safe_key=None):
    """Build a single iteration of the template embedder.

    Args:
      act: [num_res, num_res, num_channel] Input pairwise activations.
      pair_mask: [num_res, num_res] padding mask.
      is_training: Whether to run in training mode.
      safe_key: Safe pseudo-random generator key.

    Returns:
      [num_res, num_res, num_channel] tensor of activations.
    """
    c = self.config
    gc = self.global_config

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    dropout_wrapper_fn = functools.partial(
        modules.dropout_wrapper,
        is_training=is_training,
        global_config=gc)

    safe_key, *sub_keys = safe_key.split(20)
    sub_keys = iter(sub_keys)

    act = dropout_wrapper_fn(
        modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
                                       name='triangle_multiplication_outgoing'),
        act, pair_mask, safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc,
                                       name='triangle_multiplication_incoming'),
        act, pair_mask, safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.TriangleAttention(c.triangle_attention_starting_node, gc,
                                  name='triangle_attention_starting_node'),
        act, pair_mask, safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.TriangleAttention(c.triangle_attention_ending_node, gc,
                                  name='triangle_attention_ending_node'),
        act, pair_mask, safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.Transition(c.pair_transition, gc, name='pair_transition'),
        act, pair_mask, safe_key=next(sub_keys))

    return act


def template_embedding_1d(batch, num_channel):
  """Embed templates into an (num_res, num_templates, num_channels) embedding.

  Args:
    batch: A batch containing:
      template_aatype, (num_templates, num_res) aatype for the templates.
      template_all_atom_positions, (num_templates, num_residues, 37, 3) atom
        positions for the templates.
      template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
        each template.
    num_channel: The number of channels in the output.

  Returns:
    An embedding of shape (num_templates, num_res, num_channels) and a mask of
    shape (num_templates, num_res).
  """
  # Embed the templates aatypes.
  aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)

  num_templates = batch['template_aatype'].shape[0]
  all_chi_angles = []
  all_chi_masks = []
  for i in range(num_templates):
    atom_pos = geometry.Vec3Array.from_array(
        batch['template_all_atom_positions'][i, :, :, :])
    template_chi_angles, template_chi_mask = all_atom_multimer.compute_chi_angles(
        atom_pos,
        batch['template_all_atom_mask'][i, :, :],
        batch['template_aatype'][i, :])
    all_chi_angles.append(template_chi_angles)
    all_chi_masks.append(template_chi_mask)
  chi_angles = jnp.stack(all_chi_angles, axis=0)
  chi_mask = jnp.stack(all_chi_masks, axis=0)

  template_features = jnp.concatenate([
      aatype_one_hot,
      jnp.sin(chi_angles) * chi_mask,
      jnp.cos(chi_angles) * chi_mask,
      chi_mask], axis=-1)

  template_mask = chi_mask[:, :, 0]

  template_activations = common_modules.Linear(
      num_channel,
      initializer='relu',
      name='template_single_embedding')(template_features)
  template_activations = jax.nn.relu(template_activations)
  template_activations = common_modules.Linear(
      num_channel,
      initializer='relu',
      name='template_projection')(template_activations)
  return template_activations, template_mask
alphafold/model/prng.py deleted 100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of utilities surrounding PRNG usage in protein folding."""
import haiku as hk
import jax


def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training):
  if is_training and rate != 0.0 and not is_deterministic:
    return hk.dropout(safe_key.get(), rate, tensor)
  else:
    return tensor


class SafeKey:
  """Safety wrapper for PRNG keys."""

  def __init__(self, key):
    self._key = key
    self._used = False

  def _assert_not_used(self):
    if self._used:
      raise RuntimeError('Random key has been used previously.')

  def get(self):
    self._assert_not_used()
    self._used = True
    return self._key

  def split(self, num_keys=2):
    self._assert_not_used()
    self._used = True
    new_keys = jax.random.split(self._key, num_keys)
    return jax.tree_map(SafeKey, tuple(new_keys))

  def duplicate(self, num_keys=2):
    self._assert_not_used()
    self._used = True
    return tuple(SafeKey(self._key) for _ in range(num_keys))


def _safe_key_flatten(safe_key):
  # Flatten transfers "ownership" to the tree
  return (safe_key._key,), safe_key._used  # pylint: disable=protected-access


def _safe_key_unflatten(aux_data, children):
  ret = SafeKey(children[0])
  ret._used = aux_data  # pylint: disable=protected-access
  return ret


jax.tree_util.register_pytree_node(
    SafeKey, _safe_key_flatten, _safe_key_unflatten)
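# Editorial note, not part of the original file: a minimal usage sketch,
# assuming the SafeKey wrapper defined above. Each wrapped key may be
# consumed exactly once; a second use raises RuntimeError.
import jax

_key = SafeKey(jax.random.PRNGKey(0))
_key, _subkey = _key.split()       # consumes `_key`, returns fresh wrappers
_ = _subkey.get()                  # consumes `_subkey`
try:
  _subkey.get()                    # second use of the same key
except RuntimeError:
  print('re-use of a SafeKey is rejected')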
alphafold/model/prng_test.py deleted 100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for prng."""
from absl.testing import absltest
from alphafold.model import prng
import jax


class PrngTest(absltest.TestCase):

  def test_key_reuse(self):

    init_key = jax.random.PRNGKey(42)
    safe_key = prng.SafeKey(init_key)
    _, safe_key = safe_key.split()

    raw_key = safe_key.get()

    self.assertNotEqual(raw_key[0], init_key[0])
    self.assertNotEqual(raw_key[1], init_key[1])

    with self.assertRaises(RuntimeError):
      safe_key.get()

    with self.assertRaises(RuntimeError):
      safe_key.split()

    with self.assertRaises(RuntimeError):
      safe_key.duplicate()


if __name__ == '__main__':
  absltest.main()
alphafold/model/quat_affine.py deleted 100644 → 0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quaternion geometry modules.
This introduces a representation of coordinate frames that is based around a
‘QuatAffine’ object. This object describes an array of coordinate frames.
It consists of vectors corresponding to the
origin of the frames as well as orientations which are stored in two
ways, as unit quaternions as well as a rotation matrices.
The rotation matrices are derived from the unit quaternions and the two are kept
in sync.
For an explanation of the relation between unit quaternions and rotations see
https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
This representation is used in the model for the backbone frames.
One important thing to note here, is that while we update both representations
the jit compiler is going to ensure that only the parts that are
actually used are executed.
"""
import functools
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as np

# pylint: disable=bad-whitespace
QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)

QUAT_TO_ROT[0, 0] = [[ 1, 0, 0], [ 0, 1, 0], [ 0, 0, 1]]  # rr
QUAT_TO_ROT[1, 1] = [[ 1, 0, 0], [ 0,-1, 0], [ 0, 0,-1]]  # ii
QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [ 0, 1, 0], [ 0, 0,-1]]  # jj
QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [ 0,-1, 0], [ 0, 0, 1]]  # kk

QUAT_TO_ROT[1, 2] = [[ 0, 2, 0], [ 2, 0, 0], [ 0, 0, 0]]  # ij
QUAT_TO_ROT[1, 3] = [[ 0, 0, 2], [ 0, 0, 0], [ 2, 0, 0]]  # ik
QUAT_TO_ROT[2, 3] = [[ 0, 0, 0], [ 0, 0, 2], [ 0, 2, 0]]  # jk

QUAT_TO_ROT[0, 1] = [[ 0, 0, 0], [ 0, 0,-2], [ 0, 2, 0]]  # ir
QUAT_TO_ROT[0, 2] = [[ 0, 0, 2], [ 0, 0, 0], [-2, 0, 0]]  # jr
QUAT_TO_ROT[0, 3] = [[ 0,-2, 0], [ 2, 0, 0], [ 0, 0, 0]]  # kr

QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32)
QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
                          [ 0,-1, 0, 0],
                          [ 0, 0,-1, 0],
                          [ 0, 0, 0,-1]]

QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
                          [ 1, 0, 0, 0],
                          [ 0, 0, 0, 1],
                          [ 0, 0,-1, 0]]

QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
                          [ 0, 0, 0,-1],
                          [ 1, 0, 0, 0],
                          [ 0, 1, 0, 0]]

QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
                          [ 0, 0, 1, 0],
                          [ 0,-1, 0, 0],
                          [ 1, 0, 0, 0]]

QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :]
# pylint: enable=bad-whitespace


def rot_to_quat(rot, unstack_inputs=False):
  """Convert rotation matrix to quaternion.

  Note that this function calls self_adjoint_eig which is extremely expensive on
  the GPU. If at all possible, this function should run on the CPU.

  Args:
    rot: rotation matrix (see below for format).
    unstack_inputs: If true, rotation matrix should be shape (..., 3, 3)
      otherwise the rotation matrix should be a list of lists of tensors.

  Returns:
    Quaternion as (..., 4) tensor.
  """
  if unstack_inputs:
    rot = [jnp.moveaxis(x, -1, 0) for x in jnp.moveaxis(rot, -2, 0)]

  [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot

  # pylint: disable=bad-whitespace
  k = [[ xx + yy + zz,      zy - yz,      xz - zx,      yx - xy,],
       [      zy - yz, xx - yy - zz,      xy + yx,      xz + zx,],
       [      xz - zx,      xy + yx, yy - xx - zz,      yz + zy,],
       [      yx - xy,      xz + zx,      yz + zy, zz - xx - yy,]]
  # pylint: enable=bad-whitespace

  k = (1./3.) * jnp.stack([jnp.stack(x, axis=-1) for x in k], axis=-2)

  # Get eigenvalues in non-decreasing order and associated.
  _, qs = jnp.linalg.eigh(k)
  return qs[..., -1]


def rot_list_to_tensor(rot_list):
  """Convert list of lists to rotation tensor."""
  return jnp.stack(
      [jnp.stack(rot_list[0], axis=-1),
       jnp.stack(rot_list[1], axis=-1),
       jnp.stack(rot_list[2], axis=-1)],
      axis=-2)


def vec_list_to_tensor(vec_list):
  """Convert list to vector tensor."""
  return jnp.stack(vec_list, axis=-1)


def quat_to_rot(normalized_quat):
  """Convert a normalized quaternion to a rotation matrix."""
  rot_tensor = jnp.sum(
      np.reshape(QUAT_TO_ROT, (4, 4, 9)) *
      normalized_quat[..., :, None, None] *
      normalized_quat[..., None, :, None],
      axis=(-3, -2))
  rot = jnp.moveaxis(rot_tensor, -1, 0)  # Unstack.
  return [[rot[0], rot[1], rot[2]],
          [rot[3], rot[4], rot[5]],
          [rot[6], rot[7], rot[8]]]


def quat_multiply_by_vec(quat, vec):
  """Multiply a quaternion by a pure-vector quaternion."""
  return jnp.sum(
      QUAT_MULTIPLY_BY_VEC *
      quat[..., :, None, None] *
      vec[..., None, :, None],
      axis=(-3, -2))


def quat_multiply(quat1, quat2):
  """Multiply a quaternion by another quaternion."""
  return jnp.sum(
      QUAT_MULTIPLY *
      quat1[..., :, None, None] *
      quat2[..., None, :, None],
      axis=(-3, -2))


def apply_rot_to_vec(rot, vec, unstack=False):
  """Multiply rotation matrix by a vector."""
  if unstack:
    x, y, z = [vec[:, i] for i in range(3)]
  else:
    x, y, z = vec
  return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z,
          rot[1][0] * x + rot[1][1] * y + rot[1][2] * z,
          rot[2][0] * x + rot[2][1] * y + rot[2][2] * z]


def apply_inverse_rot_to_vec(rot, vec):
  """Multiply the inverse of a rotation matrix by a vector."""
  # Inverse rotation is just transpose
  return [rot[0][0] * vec[0] + rot[1][0] * vec[1] + rot[2][0] * vec[2],
          rot[0][1] * vec[0] + rot[1][1] * vec[1] + rot[2][1] * vec[2],
          rot[0][2] * vec[0] + rot[1][2] * vec[1] + rot[2][2] * vec[2]]


class
QuatAffine
(
object
):
"""Affine transformation represented by quaternion and vector."""
def
__init__
(
self
,
quaternion
,
translation
,
rotation
=
None
,
normalize
=
True
,
unstack_inputs
=
False
):
"""Initialize from quaternion and translation.
Args:
quaternion: Rotation represented by a quaternion, to be applied
before translation. Must be a unit quaternion unless normalize==True.
translation: Translation represented as a vector.
rotation: Same rotation as the quaternion, represented as a (..., 3, 3)
tensor. If None, rotation will be calculated from the quaternion.
normalize: If True, l2 normalize the quaternion on input.
unstack_inputs: If True, translation is a vector with last component 3
"""
if
quaternion
is
not
None
:
assert
quaternion
.
shape
[
-
1
]
==
4
if
unstack_inputs
:
if
rotation
is
not
None
:
rotation
=
[
jnp
.
moveaxis
(
x
,
-
1
,
0
)
# Unstack.
for
x
in
jnp
.
moveaxis
(
rotation
,
-
2
,
0
)]
# Unstack.
translation
=
jnp
.
moveaxis
(
translation
,
-
1
,
0
)
# Unstack.
if
normalize
and
quaternion
is
not
None
:
quaternion
=
quaternion
/
jnp
.
linalg
.
norm
(
quaternion
,
axis
=-
1
,
keepdims
=
True
)
if
rotation
is
None
:
rotation
=
quat_to_rot
(
quaternion
)
self
.
quaternion
=
quaternion
self
.
rotation
=
[
list
(
row
)
for
row
in
rotation
]
self
.
translation
=
list
(
translation
)
assert
all
(
len
(
row
)
==
3
for
row
in
self
.
rotation
)
assert
len
(
self
.
translation
)
==
3
def
to_tensor
(
self
):
return
jnp
.
concatenate
(
[
self
.
quaternion
]
+
[
jnp
.
expand_dims
(
x
,
axis
=-
1
)
for
x
in
self
.
translation
],
axis
=-
1
)
def
apply_tensor_fn
(
self
,
tensor_fn
):
"""Return a new QuatAffine with tensor_fn applied (e.g. stop_gradient)."""
return
QuatAffine
(
tensor_fn
(
self
.
quaternion
),
[
tensor_fn
(
x
)
for
x
in
self
.
translation
],
rotation
=
[[
tensor_fn
(
x
)
for
x
in
row
]
for
row
in
self
.
rotation
],
normalize
=
False
)
def
apply_rotation_tensor_fn
(
self
,
tensor_fn
):
"""Return a new QuatAffine with tensor_fn applied to the rotation part."""
return
QuatAffine
(
tensor_fn
(
self
.
quaternion
),
[
x
for
x
in
self
.
translation
],
rotation
=
[[
tensor_fn
(
x
)
for
x
in
row
]
for
row
in
self
.
rotation
],
normalize
=
False
)
def
scale_translation
(
self
,
position_scale
):
"""Return a new quat affine with a different scale for translation."""
return
QuatAffine
(
self
.
quaternion
,
[
x
*
position_scale
for
x
in
self
.
translation
],
rotation
=
[[
x
for
x
in
row
]
for
row
in
self
.
rotation
],
normalize
=
False
)
@
classmethod
def
from_tensor
(
cls
,
tensor
,
normalize
=
False
):
quaternion
,
tx
,
ty
,
tz
=
jnp
.
split
(
tensor
,
[
4
,
5
,
6
],
axis
=-
1
)
return
cls
(
quaternion
,
[
tx
[...,
0
],
ty
[...,
0
],
tz
[...,
0
]],
normalize
=
normalize
)
def
pre_compose
(
self
,
update
):
"""Return a new QuatAffine which applies the transformation update first.
Args:
update: Length-6 vector. 3-vector of x, y, and z such that the quaternion
update is (1, x, y, z) and zero for the 3-vector is the identity
quaternion. 3-vector for translation concatenated.
Returns:
New QuatAffine object.
"""
vector_quaternion_update
,
x
,
y
,
z
=
jnp
.
split
(
update
,
[
3
,
4
,
5
],
axis
=-
1
)
trans_update
=
[
jnp
.
squeeze
(
x
,
axis
=-
1
),
jnp
.
squeeze
(
y
,
axis
=-
1
),
jnp
.
squeeze
(
z
,
axis
=-
1
)]
new_quaternion
=
(
self
.
quaternion
+
quat_multiply_by_vec
(
self
.
quaternion
,
vector_quaternion_update
))
trans_update
=
apply_rot_to_vec
(
self
.
rotation
,
trans_update
)
new_translation
=
[
self
.
translation
[
0
]
+
trans_update
[
0
],
self
.
translation
[
1
]
+
trans_update
[
1
],
self
.
translation
[
2
]
+
trans_update
[
2
]]
return
QuatAffine
(
new_quaternion
,
new_translation
)
def
apply_to_point
(
self
,
point
,
extra_dims
=
0
):
"""Apply affine to a point.
Args:
point: List of 3 tensors to apply affine.
extra_dims: Number of dimensions at the end of the transformed_point
shape that are not present in the rotation and translation. The most
common use is rotation N points at once with extra_dims=1 for use in a
network.
Returns:
Transformed point after applying affine.
"""
rotation
=
self
.
rotation
translation
=
self
.
translation
for
_
in
range
(
extra_dims
):
expand_fn
=
functools
.
partial
(
jnp
.
expand_dims
,
axis
=-
1
)
rotation
=
jax
.
tree_map
(
expand_fn
,
rotation
)
translation
=
jax
.
tree_map
(
expand_fn
,
translation
)
rot_point
=
apply_rot_to_vec
(
rotation
,
point
)
return
[
rot_point
[
0
]
+
translation
[
0
],
rot_point
[
1
]
+
translation
[
1
],
rot_point
[
2
]
+
translation
[
2
]]
  def invert_point(self, transformed_point, extra_dims=0):
    """Apply inverse of transformation to a point.

    Args:
      transformed_point: List of 3 tensors to apply the inverse affine to.
      extra_dims: Number of dimensions at the end of the transformed_point
        shape that are not present in the rotation and translation. The most
        common use is rotating N points at once with extra_dims=1 for use in a
        network.

    Returns:
      Point after applying the inverse of the affine.
    """
    rotation = self.rotation
    translation = self.translation
    for _ in range(extra_dims):
      expand_fn = functools.partial(jnp.expand_dims, axis=-1)
      rotation = jax.tree_map(expand_fn, rotation)
      translation = jax.tree_map(expand_fn, translation)

    rot_point = [
        transformed_point[0] - translation[0],
        transformed_point[1] - translation[1],
        transformed_point[2] - translation[2]]

    return apply_inverse_rot_to_vec(rotation, rot_point)
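A round-trip sketch (illustrative): invert_point composed with apply_to_point should recover the original point, provided the quaternion is normalized so the rotation is orthonormal.

import jax.numpy as jnp
from alphafold.model import quat_affine

affine = quat_affine.QuatAffine.from_tensor(
    jnp.array([1., 2., -1., 0.5, 0.3, -0.7, 1.2]), normalize=True)
point = [jnp.array(0.1), jnp.array(2.0), jnp.array(-0.4)]
recovered = affine.invert_point(affine.apply_to_point(point))
# Each entry of `recovered` matches the corresponding entry of `point`
# up to floating-point error.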
  def __repr__(self):
    return 'QuatAffine(%r, %r)' % (self.quaternion, self.translation)
def _multiply(a, b):
  return jnp.stack([
      jnp.array([a[0][0]*b[0][0] + a[0][1]*b[1][0] + a[0][2]*b[2][0],
                 a[0][0]*b[0][1] + a[0][1]*b[1][1] + a[0][2]*b[2][1],
                 a[0][0]*b[0][2] + a[0][1]*b[1][2] + a[0][2]*b[2][2]]),

      jnp.array([a[1][0]*b[0][0] + a[1][1]*b[1][0] + a[1][2]*b[2][0],
                 a[1][0]*b[0][1] + a[1][1]*b[1][1] + a[1][2]*b[2][1],
                 a[1][0]*b[0][2] + a[1][1]*b[1][2] + a[1][2]*b[2][2]]),

      jnp.array([a[2][0]*b[0][0] + a[2][1]*b[1][0] + a[2][2]*b[2][0],
                 a[2][0]*b[0][1] + a[2][1]*b[1][1] + a[2][2]*b[2][1],
                 a[2][0]*b[0][2] + a[2][1]*b[1][2] + a[2][2]*b[2][2]])])
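_multiply is an unrolled 3x3 matrix product over rotation components stacked along the leading two axes. A quick equivalence sketch (illustrative; it assumes the module imports as alphafold.model.quat_affine and pokes at the private helper):

import jax.numpy as jnp
import numpy as np
from alphafold.model import quat_affine

rng = np.random.RandomState(0)
a = jnp.asarray(rng.randn(3, 3, 5))
b = jnp.asarray(rng.randn(3, 3, 5))
expected = jnp.einsum('ij...,jk...->ik...', a, b)
assert jnp.allclose(quat_affine._multiply(a, b), expected, atol=1e-5)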
def make_canonical_transform(
    n_xyz: jnp.ndarray,
    ca_xyz: jnp.ndarray,
    c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Returns translation and rotation matrices to canonicalize residue atoms.

  Note that this method does not take care of symmetries. If you provide the
  atom positions in the non-standard way, the N atom will end up not at
  [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
  need to take care of such cases in your code.

  Args:
    n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
    ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
    c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.

  Returns:
    A tuple (translation, rotation) where:
      translation is an array of shape [batch, 3] defining the translation.
      rotation is an array of shape [batch, 3, 3] defining the rotation.
    After applying the translation and rotation to all atoms in a residue:
      * All atoms will be shifted so that CA is at the origin,
      * All atoms will be rotated so that C is on the x-axis,
      * All atoms will be rotated so that N is in the xy plane.
  """
  assert len(n_xyz.shape) == 2, n_xyz.shape
  assert n_xyz.shape[-1] == 3, n_xyz.shape
  assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, (
      n_xyz.shape, ca_xyz.shape, c_xyz.shape)

  # Place CA at the origin.
  translation = -ca_xyz
  n_xyz = n_xyz + translation
  c_xyz = c_xyz + translation

  # Place C on the x-axis.
  c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)]
  # Rotate by angle c1 in the x-y plane (around the z-axis).
  sin_c1 = -c_y / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
  cos_c1 = c_x / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
  zeros = jnp.zeros_like(sin_c1)
  ones = jnp.ones_like(sin_c1)
  # pylint: disable=bad-whitespace
  c1_rot_matrix = jnp.stack([jnp.array([cos_c1, -sin_c1, zeros]),
                             jnp.array([sin_c1,  cos_c1, zeros]),
                             jnp.array([zeros,    zeros,  ones])])

  # Rotate by angle c2 in the x-z plane (around the y-axis).
  sin_c2 = c_z / jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2)
  cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / jnp.sqrt(
      1e-20 + c_x**2 + c_y**2 + c_z**2)
  c2_rot_matrix = jnp.stack([jnp.array([cos_c2,  zeros, sin_c2]),
                             jnp.array([zeros,    ones,  zeros]),
                             jnp.array([-sin_c2, zeros, cos_c2])])

  c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix)
  n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T

  # Place N in the x-y plane.
  _, n_y, n_z = [n_xyz[:, i] for i in range(3)]
  # Rotate by angle alpha in the y-z plane (around the x-axis).
  sin_n = -n_z / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
  cos_n = n_y / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
  n_rot_matrix = jnp.stack([jnp.array([ones,  zeros,  zeros]),
                            jnp.array([zeros, cos_n, -sin_n]),
                            jnp.array([zeros, sin_n,  cos_n])])
  # pylint: enable=bad-whitespace

  return (translation,
          jnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1]))
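A shape-oriented sketch (coordinates are made up, for illustration only): applying the returned translation first and then the rotation places CA at the origin and C on the x-axis.

import jax.numpy as jnp
from alphafold.model import quat_affine

n = jnp.array([[-0.527, 1.359, 0.0]])
ca = jnp.array([[0.1, 0.2, 0.3]])
c = jnp.array([[1.6, 0.2, 0.3]])
translation, rotation = quat_affine.make_canonical_transform(n, ca, c)
# translation: [batch, 3], rotation: [batch, 3, 3]
canon_c = jnp.einsum('bij,bj->bi', rotation, c + translation)
# canon_c has (approximately) zero y and z components.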
def make_transform_from_reference(
    n_xyz: jnp.ndarray,
    ca_xyz: jnp.ndarray,
    c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Returns rotation and translation matrices to convert from reference.

  Note that this method does not take care of symmetries. If you provide the
  atom positions in the non-standard way, the N atom will end up not at
  [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
  need to take care of such cases in your code.

  Args:
    n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
    ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
    c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.

  Returns:
    A tuple (rotation, translation) where:
      rotation is an array of shape [batch, 3, 3] defining the rotation.
      translation is an array of shape [batch, 3] defining the translation.
    After applying the translation and rotation to the reference backbone,
    the coordinates will approximately equal the input coordinates.

    The order of translation and rotation differs from make_canonical_transform
    because the rotation from this function should be applied before the
    translation, unlike make_canonical_transform.
  """
  translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz)
  return np.transpose(rotation, (0, 2, 1)), -translation
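An illustrative consistency check: because this function simply inverts make_canonical_transform, applying the returned rotation to the canonical CA position (the origin) and then adding the translation recovers the original CA coordinates.

import jax.numpy as jnp
from alphafold.model import quat_affine

n = jnp.array([[-0.3, 1.2, 0.4]])
ca = jnp.array([[0.1, 0.2, 0.3]])
c = jnp.array([[1.4, 0.5, -0.2]])
rot, trans = quat_affine.make_transform_from_reference(n, ca, c)
ca_rebuilt = jnp.einsum('bij,bj->bi', rot, jnp.zeros_like(ca)) + trans
# ca_rebuilt ~ ca, since CA sits at the origin in the canonical frame.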
alphafold/model/quat_affine_test.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for quat_affine."""
from
absl
import
logging
from
absl.testing
import
absltest
from
alphafold.model
import
quat_affine
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
VERBOSE
=
False
np
.
set_printoptions
(
precision
=
3
,
suppress
=
True
)
r2t
=
quat_affine
.
rot_list_to_tensor
v2t
=
quat_affine
.
vec_list_to_tensor
q2r
=
lambda
q
:
r2t
(
quat_affine
.
quat_to_rot
(
q
))
class QuatAffineTest(absltest.TestCase):

  def _assert_check(self, to_check, tol=1e-5):
    for k, (correct, generated) in to_check.items():
      if VERBOSE:
        logging.info(k)
        logging.info('Correct %s', correct)
        logging.info('Predicted %s', generated)
      self.assertLess(np.max(np.abs(correct - generated)), tol)
  def test_conversion(self):
    quat = jnp.array([-2., 5., -1., 4.])

    rotation = jnp.array([
        [0.26087, 0.130435, 0.956522],
        [-0.565217, -0.782609, 0.26087],
        [0.782609, -0.608696, -0.130435]])

    translation = jnp.array([1., -3., 4.])
    point = jnp.array([0.7, 3.2, -2.9])

    a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True)
    true_new_point = jnp.matmul(rotation, point[:, None])[:, 0] + translation

    self._assert_check({
        'rot': (rotation, r2t(a.rotation)),
        'trans': (translation, v2t(a.translation)),
        'point': (true_new_point,
                  v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))),
        # Because of the double cover, we must be careful and compare rotations
        'quat': (q2r(a.quaternion),
                 q2r(quat_affine.rot_to_quat(a.rotation))),
    })
  def test_double_cover(self):
    """Test that -q is the same rotation as q."""
    rng = jax.random.PRNGKey(42)
    keys = jax.random.split(rng)
    q = jax.random.normal(keys[0], (2, 4))
    trans = jax.random.normal(keys[1], (2, 3))
    a1 = quat_affine.QuatAffine(q, trans, unstack_inputs=True)
    a2 = quat_affine.QuatAffine(-q, trans, unstack_inputs=True)

    self._assert_check({
        'rot': (r2t(a1.rotation), r2t(a2.rotation)),
        'trans': (v2t(a1.translation), v2t(a2.translation)),
    })
  def test_homomorphism(self):
    rng = jax.random.PRNGKey(42)
    keys = jax.random.split(rng, 4)
    vec_q1 = jax.random.normal(keys[0], (2, 3))

    q1 = jnp.concatenate([
        jnp.ones_like(vec_q1)[:, :1],
        vec_q1], axis=-1)

    q2 = jax.random.normal(keys[1], (2, 4))
    t1 = jax.random.normal(keys[2], (2, 3))
    t2 = jax.random.normal(keys[3], (2, 3))

    a1 = quat_affine.QuatAffine(q1, t1, unstack_inputs=True)
    a2 = quat_affine.QuatAffine(q2, t2, unstack_inputs=True)

    a21 = a2.pre_compose(jnp.concatenate([vec_q1, t1], axis=-1))

    rng, key = jax.random.split(rng)
    x = jax.random.normal(key, (2, 3))
    new_x = a21.apply_to_point(jnp.moveaxis(x, -1, 0))
    new_x_apply2 = a2.apply_to_point(a1.apply_to_point(jnp.moveaxis(x, -1, 0)))

    self._assert_check({
        'quat': (q2r(quat_affine.quat_multiply(a2.quaternion, a1.quaternion)),
                 q2r(a21.quaternion)),
        'rot': (jnp.matmul(r2t(a2.rotation), r2t(a1.rotation)),
                r2t(a21.rotation)),
        'point': (v2t(new_x_apply2), v2t(new_x)),
        'inverse': (x, v2t(a21.invert_point(new_x))),
    })
  def test_batching(self):
    """Test that affine applies batchwise."""
    rng = jax.random.PRNGKey(42)
    keys = jax.random.split(rng, 3)
    q = jax.random.uniform(keys[0], (5, 2, 4))
    t = jax.random.uniform(keys[1], (2, 3))
    x = jax.random.uniform(keys[2], (5, 1, 3))

    a = quat_affine.QuatAffine(q, t, unstack_inputs=True)
    y = v2t(a.apply_to_point(jnp.moveaxis(x, -1, 0)))

    y_list = []
    for i in range(5):
      for j in range(2):
        a_local = quat_affine.QuatAffine(q[i, j], t[j],
                                         unstack_inputs=True)
        y_local = v2t(a_local.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0)))
        y_list.append(y_local)
    y_combine = jnp.reshape(jnp.stack(y_list, axis=0), (5, 2, 3))

    self._assert_check({
        'batch': (y_combine, y),
        'quat': (q2r(a.quaternion),
                 q2r(quat_affine.rot_to_quat(a.rotation))),
    })
  def assertAllClose(self, a, b, rtol=1e-06, atol=1e-06):
    self.assertTrue(np.allclose(a, b, rtol=rtol, atol=atol))

  def assertAllEqual(self, a, b):
    self.assertTrue(np.all(np.array(a) == np.array(b)))


if __name__ == '__main__':
  absltest.main()