OpenDAS / OpenFold · Commits · bb3f51e5

Commit bb3f51e5 (Unverified)
Authored Feb 07, 2024 by Christina Floristean; committed by GitHub on Feb 07, 2024.

    Merge pull request #405 from aqlaboratory/multimer

    Full multimer merge

Parents: ce211367, c33a0bd6
Changes: 106 files changed in the commit; showing 20 changed files with 2150 additions and 309 deletions (+2150, -309).
openfold/utils/argparse_utils.py                  +0    -0
openfold/utils/feats.py                           +32   -17
openfold/utils/geometry/__init__.py               +28   -0
openfold/utils/geometry/quat_rigid.py             +38   -0
openfold/utils/geometry/rigid_matrix_vector.py    +181  -0
openfold/utils/geometry/rotation_matrix.py        +208  -0
openfold/utils/geometry/test_utils.py             +97   -0
openfold/utils/geometry/utils.py                  +22   -0
openfold/utils/geometry/vector.py                 +261  -0
openfold/utils/import_weights.py                  +396  -142
openfold/utils/loss.py                            +273  -97
openfold/utils/multi_chain_permutation.py         +421  -0
openfold/utils/rigid_utils.py                     +10   -0
openfold/utils/script_utils.py                    +5    -4
run_pretrained_openfold.py                        +92   -44
scripts/__init__.py                               +0    -0
scripts/convert_of_weights_to_jax.py              +2    -1
scripts/data_dir_to_fasta.py                      +22   -2
scripts/deepspeed_inference_test.py               +54   -0
scripts/download_alphafold_dbs.sh                 +8    -2
openfold/utils/argparse.py → openfold/utils/argparse_utils.py (file moved)
openfold/utils/feats.py

@@ -18,10 +18,11 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from typing import Dict
+from typing import Dict, Union

 from openfold.np import protein
 import openfold.np.residue_constants as rc
+from openfold.utils.geometry import rigid_matrix_vector, rotation_matrix, vector
 from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.tensor_utils import (
     batched_gather,

@@ -89,6 +90,23 @@ def build_template_angle_feat(template_feats):
     return template_angle_feat


+def dgram_from_positions(
+    pos: torch.Tensor,
+    min_bin: float = 3.25,
+    max_bin: float = 50.75,
+    no_bins: float = 39,
+    inf: float = 1e8,
+):
+    dgram = torch.sum(
+        (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
+    )
+    lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
+    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
+    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
+
+    return dgram
+
+
 def build_template_pair_feat(
     batch,
     min_bin,
     max_bin,
     no_bins,

@@ -100,12 +118,7 @@ def build_template_pair_feat(
     # Compute distogram (this seems to differ slightly from Alg. 5)
     tpb = batch["template_pseudo_beta"]
-    dgram = torch.sum(
-        (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
-    )
-    lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
-    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
-    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
+    dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf)

     to_concat = [dgram, template_mask_2d[..., None]]

@@ -170,18 +183,21 @@ def build_extra_msa_feat(batch):
 def torsion_angles_to_frames(
-    r: Rigid,
+    r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
     alpha: torch.Tensor,
     aatype: torch.Tensor,
     rrgdf: torch.Tensor,
 ):
+    rigid_type = type(r)
+
     # [*, N, 8, 4, 4]
     default_4x4 = rrgdf[aatype, ...]

     # [*, N, 8] transformations, i.e.
     #   One [*, N, 8, 3, 3] rotation matrix and
     #   One [*, N, 8, 3] translation matrix
-    default_r = r.from_tensor_4x4(default_4x4)
+    default_r = rigid_type.from_tensor_4x4(default_4x4)

     bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
     bb_rot[..., 1] = 1

@@ -201,14 +217,13 @@ def torsion_angles_to_frames(
     # This follows the original code rather than the supplement, which uses
     # different indices.

-    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
+    all_rots = alpha.new_zeros(default_r.shape + (4, 4))
     all_rots[..., 0, 0] = 1
     all_rots[..., 1, 1] = alpha[..., 1]
     all_rots[..., 1, 2] = -alpha[..., 0]
-    all_rots[..., 2, 1:] = alpha
-
-    all_rots = Rigid(Rotation(rot_mats=all_rots), None)
+    all_rots[..., 2, 1:3] = alpha

+    all_rots = rigid_type.from_tensor_4x4(all_rots)
     all_frames = default_r.compose(all_rots)

     chi2_frame_to_frame = all_frames[..., 5]

@@ -220,7 +235,7 @@ def torsion_angles_to_frames(
     chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
     chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

-    all_frames_to_bb = Rigid.cat(
+    all_frames_to_bb = rigid_type.cat(
         [
             all_frames[..., :5],
             chi2_frame_to_bb.unsqueeze(-1),

@@ -236,7 +251,7 @@ def torsion_angles_to_frames(
 def frames_and_literature_positions_to_atom14_pos(
-    r: Rigid,
+    r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
     aatype: torch.Tensor,
     default_frames,
     group_idx,

@@ -263,7 +278,7 @@ def frames_and_literature_positions_to_atom14_pos(
         lambda x: torch.sum(x, dim=-1)
     )

-    # [*, N, 14, 1]
+    # [*, N, 14]
     atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

     # [*, N, 14, 3]
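The refactor above extracts the template distogram computation into a standalone `dgram_from_positions` helper. A minimal sketch of exercising it (illustrative only, not part of the commit; shapes chosen arbitrarily):

import torch
from openfold.utils.feats import dgram_from_positions

# [*, N, 3] pseudo-beta coordinates for a toy 10-residue chain
pos = torch.randn(1, 10, 3) * 10.0

# One-hot distance bins per residue pair: [1, 10, 10, 39]
dgram = dgram_from_positions(pos, min_bin=3.25, max_bin=50.75, no_bins=39)
print(dgram.shape)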
openfold/utils/geometry/__init__.py (new file, 0 → 100644)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from openfold.utils.geometry import rigid_matrix_vector
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import vector

Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
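The package re-exports the struct-of-arrays types and vector helpers at top level, so downstream code can import them directly. A small usage sketch (illustrative, not from the commit):

import torch
from openfold.utils.geometry import Vec3Array, euclidean_distance

v1 = Vec3Array.from_array(torch.zeros(5, 3))
v2 = Vec3Array.from_array(torch.ones(5, 3))
print(euclidean_distance(v1, v2))  # ~sqrt(3) for each of the 5 entries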
openfold/utils/geometry/quat_rigid.py (new file, 0 → 100644)
import torch
import torch.nn as nn

from openfold.model.primitives import Linear
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array


class QuatRigid(nn.Module):
    def __init__(self, c_hidden, full_quat):
        super().__init__()
        self.full_quat = full_quat
        if self.full_quat:
            rigid_dim = 7
        else:
            rigid_dim = 6

        self.linear = Linear(
            c_hidden, rigid_dim, init="final", precision=torch.float32
        )

    def forward(self, activations: torch.Tensor) -> Rigid3Array:
        # NOTE: During training, this needs to be run in higher precision
        rigid_flat = self.linear(activations)

        rigid_flat = torch.unbind(rigid_flat, dim=-1)
        if self.full_quat:
            qw, qx, qy, qz = rigid_flat[:4]
            translation = rigid_flat[4:]
        else:
            qx, qy, qz = rigid_flat[:3]
            qw = torch.ones_like(qx)
            translation = rigid_flat[3:]

        rotation = Rot3Array.from_quaternion(
            qw, qx, qy, qz, normalize=True,
        )
        translation = Vec3Array(*translation)

        return Rigid3Array(rotation, translation)
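QuatRigid predicts a rigid update as a quaternion plus a translation; with full_quat=False only the imaginary components are predicted and qw is fixed to 1 before normalization, so the layer emits 6 channels instead of 7. A hedged usage sketch (hidden size and batch shape are arbitrary, not from the commit):

import torch
from openfold.utils.geometry.quat_rigid import QuatRigid

layer = QuatRigid(c_hidden=384, full_quat=False)  # 6 outputs: qx, qy, qz + 3 translation channels
s = torch.randn(2, 50, 384)                       # per-residue single representation
rigids = layer(s)                                 # Rigid3Array with shape [2, 50]
print(rigids.shape)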
openfold/utils/geometry/rigid_matrix_vector.py (new file, 0 → 100644)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations

import dataclasses
from typing import Union, List

import torch

from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import vector

Float = Union[float, torch.Tensor]


@dataclasses.dataclass(frozen=True)
class Rigid3Array:
    """Rigid Transformation, i.e. element of special euclidean group."""

    rotation: rotation_matrix.Rot3Array
    translation: vector.Vec3Array

    def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
        new_rotation = self.rotation @ other.rotation  # __matmul__
        new_translation = self.apply_to_point(other.translation)
        return Rigid3Array(new_rotation, new_translation)

    def __getitem__(self, index) -> Rigid3Array:
        return Rigid3Array(
            self.rotation[index],
            self.translation[index],
        )

    def __mul__(self, other: torch.Tensor) -> Rigid3Array:
        return Rigid3Array(
            self.rotation * other,
            self.translation * other,
        )

    def map_tensor_fn(self, fn) -> Rigid3Array:
        return Rigid3Array(
            self.rotation.map_tensor_fn(fn),
            self.translation.map_tensor_fn(fn),
        )

    def inverse(self) -> Rigid3Array:
        """Return Rigid3Array corresponding to inverse transform."""
        inv_rotation = self.rotation.inverse()
        inv_translation = inv_rotation.apply_to_point(-self.translation)
        return Rigid3Array(inv_rotation, inv_translation)

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply Rigid3Array transform to point."""
        return self.rotation.apply_to_point(point) + self.translation

    def apply(self, point: torch.Tensor) -> torch.Tensor:
        return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply inverse Rigid3Array transform to point."""
        new_point = point - self.translation
        return self.rotation.apply_inverse_to_point(new_point)

    def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
        return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor()

    def compose_rotation(self, other_rotation):
        rot = self.rotation @ other_rotation
        return Rigid3Array(rot, self.translation.clone())

    def compose(self, other_rigid):
        return self @ other_rigid

    def unsqueeze(self, dim: int):
        return Rigid3Array(
            self.rotation.unsqueeze(dim),
            self.translation.unsqueeze(dim),
        )

    @property
    def shape(self) -> torch.Size:
        return self.rotation.xx.shape

    @property
    def dtype(self) -> torch.dtype:
        return self.rotation.xx.dtype

    @property
    def device(self) -> torch.device:
        return self.rotation.xx.device

    @classmethod
    def identity(cls, shape, device) -> Rigid3Array:
        """Return identity Rigid3Array of given shape."""
        return cls(
            rotation_matrix.Rot3Array.identity(shape, device),
            vector.Vec3Array.zeros(shape, device)
        )

    @classmethod
    def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
        return cls(
            rotation_matrix.Rot3Array.cat(
                [r.rotation for r in rigids], dim=dim
            ),
            vector.Vec3Array.cat(
                [r.translation for r in rigids], dim=dim
            ),
        )

    def scale_translation(self, factor: Float) -> Rigid3Array:
        """Scale translation in Rigid3Array by 'factor'."""
        return Rigid3Array(self.rotation, self.translation * factor)

    def to_tensor(self) -> torch.Tensor:
        rot_array = self.rotation.to_tensor()
        vec_array = self.translation.to_tensor()
        array = torch.zeros(
            rot_array.shape[:-2] + (4, 4),
            device=rot_array.device,
            dtype=rot_array.dtype
        )
        array[..., :3, :3] = rot_array
        array[..., :3, 3] = vec_array
        array[..., 3, 3] = 1.
        return array

    def to_tensor_4x4(self) -> torch.Tensor:
        return self.to_tensor()

    def reshape(self, new_shape) -> Rigid3Array:
        rots = self.rotation.reshape(new_shape)
        trans = self.translation.reshape(new_shape)
        return Rigid3Array(rots, trans)

    def stop_rot_gradient(self) -> Rigid3Array:
        return Rigid3Array(
            self.rotation.stop_gradient(),
            self.translation,
        )

    @classmethod
    def from_array(cls, array):
        rot = rotation_matrix.Rot3Array.from_array(
            array[..., :3, :3],
        )
        vec = vector.Vec3Array.from_array(array[..., :3, 3])
        return cls(rot, vec)

    @classmethod
    def from_tensor_4x4(cls, array):
        return cls.from_array(array)

    @classmethod
    def from_array4x4(cls, array: torch.tensor) -> Rigid3Array:
        """Construct Rigid3Array from homogeneous 4x4 array."""
        rotation = rotation_matrix.Rot3Array(
            array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
            array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
            array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
        )
        translation = vector.Vec3Array(
            array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
        )
        return cls(rotation, translation)

    def cuda(self) -> Rigid3Array:
        return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())
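Composition follows the usual SE(3) rule, (R1, t1) ∘ (R2, t2) = (R1·R2, R1·t2 + t1), so a transform composed with its inverse is the identity. A quick sanity sketch (illustrative only, not part of the commit):

import torch
from openfold.utils.geometry import Rigid3Array, Rot3Array, Vec3Array

rot = Rot3Array.from_quaternion(
    torch.tensor(0.5), torch.tensor(0.5), torch.tensor(0.5), torch.tensor(0.5)
)
trans = Vec3Array(torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0))
rigid = Rigid3Array(rot, trans)

point = Vec3Array(torch.tensor(0.1), torch.tensor(0.2), torch.tensor(0.3))
roundtrip = rigid.inverse().apply_to_point(rigid.apply_to_point(point))
print(roundtrip.x, roundtrip.y, roundtrip.z)  # ~ (0.1, 0.2, 0.3)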
openfold/utils/geometry/rotation_matrix.py (new file, 0 → 100644)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rot3Array Matrix Class."""
from __future__ import annotations

import dataclasses
from typing import List

import torch

from openfold.utils.geometry import utils
from openfold.utils.geometry import vector
from openfold.utils.tensor_utils import tensor_tree_map

COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']


@dataclasses.dataclass(frozen=True)
class Rot3Array:
    """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""

    xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
    xy: torch.Tensor
    xz: torch.Tensor
    yx: torch.Tensor
    yy: torch.Tensor
    yz: torch.Tensor
    zx: torch.Tensor
    zy: torch.Tensor
    zz: torch.Tensor

    __array_ufunc__ = None

    def __getitem__(self, index):
        field_names = utils.get_field_names(Rot3Array)
        return Rot3Array(
            **{
                name: getattr(self, name)[index]
                for name in field_names
            }
        )

    def __mul__(self, other: torch.Tensor):
        field_names = utils.get_field_names(Rot3Array)
        return Rot3Array(
            **{
                name: getattr(self, name) * other
                for name in field_names
            }
        )

    def __matmul__(self, other: Rot3Array) -> Rot3Array:
        """Composes two Rot3Arrays."""
        c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
        c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
        c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
        return Rot3Array(
            c0.x, c1.x, c2.x,
            c0.y, c1.y, c2.y,
            c0.z, c1.z, c2.z
        )

    def map_tensor_fn(self, fn) -> Rot3Array:
        field_names = utils.get_field_names(Rot3Array)
        return Rot3Array(
            **{
                name: fn(getattr(self, name))
                for name in field_names
            }
        )

    def inverse(self) -> Rot3Array:
        """Returns inverse of Rot3Array."""
        return Rot3Array(
            self.xx, self.yx, self.zx,
            self.xy, self.yy, self.zy,
            self.xz, self.yz, self.zz
        )

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Applies Rot3Array to point."""
        return vector.Vec3Array(
            self.xx * point.x + self.xy * point.y + self.xz * point.z,
            self.yx * point.x + self.yy * point.y + self.yz * point.z,
            self.zx * point.x + self.zy * point.y + self.zz * point.z
        )

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Applies inverse Rot3Array to point."""
        return self.inverse().apply_to_point(point)

    def unsqueeze(self, dim: int):
        return Rot3Array(
            *tensor_tree_map(
                lambda t: t.unsqueeze(dim),
                [getattr(self, c) for c in COMPONENTS]
            )
        )

    def stop_gradient(self) -> Rot3Array:
        return Rot3Array(
            *[getattr(self, c).detach() for c in COMPONENTS]
        )

    @classmethod
    def identity(cls, shape, device) -> Rot3Array:
        """Returns identity of given shape."""
        ones = torch.ones(shape, dtype=torch.float32, device=device)
        zeros = torch.zeros(shape, dtype=torch.float32, device=device)
        return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)

    @classmethod
    def from_two_vectors(cls, e0: vector.Vec3Array, e1: vector.Vec3Array) -> Rot3Array:
        """Construct Rot3Array from two Vectors.

        Rot3Array is constructed such that in the corresponding frame 'e0' lies on
        the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.

        Args:
            e0: Vector
            e1: Vector
        Returns:
            Rot3Array
        """
        # Normalize the unit vector for the x-axis, e0.
        e0 = e0.normalized()
        # make e1 perpendicular to e0.
        c = e1.dot(e0)
        e1 = (e1 - c * e0).normalized()
        # Compute e2 as cross product of e0 and e1.
        e2 = e0.cross(e1)
        return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)

    @classmethod
    def from_array(cls, array: torch.Tensor) -> Rot3Array:
        """Construct Rot3Array Matrix from array of shape [..., 3, 3]."""
        rows = torch.unbind(array, dim=-2)
        rc = [torch.unbind(e, dim=-1) for e in rows]
        return cls(*[e for row in rc for e in row])

    def to_tensor(self) -> torch.Tensor:
        """Convert Rot3Array to array of shape [..., 3, 3]."""
        return torch.stack(
            [
                torch.stack([self.xx, self.xy, self.xz], dim=-1),
                torch.stack([self.yx, self.yy, self.yz], dim=-1),
                torch.stack([self.zx, self.zy, self.zz], dim=-1)
            ],
            dim=-2
        )

    @classmethod
    def from_quaternion(
        cls,
        w: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        normalize: bool = True,
        eps: float = 1e-6
    ) -> Rot3Array:
        """Construct Rot3Array from components of quaternion."""
        if normalize:
            inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
            w = w * inv_norm
            x = x * inv_norm
            y = y * inv_norm
            z = z * inv_norm
        xx = 1.0 - 2.0 * (y**2 + z**2)
        xy = 2.0 * (x * y - w * z)
        xz = 2.0 * (x * z + w * y)
        yx = 2.0 * (x * y + w * z)
        yy = 1.0 - 2.0 * (x**2 + z**2)
        yz = 2.0 * (y * z - w * x)
        zx = 2.0 * (x * z - w * y)
        zy = 2.0 * (y * z + w * x)
        zz = 1.0 - 2.0 * (x**2 + y**2)
        return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)

    def reshape(self, new_shape):
        field_names = utils.get_field_names(Rot3Array)
        reshape_fn = lambda t: t.reshape(new_shape)
        return Rot3Array(
            **{
                name: reshape_fn(getattr(self, name))
                for name in field_names
            }
        )

    @classmethod
    def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
        field_names = utils.get_field_names(Rot3Array)
        cat_fn = lambda l: torch.cat(l, dim=dim)
        return cls(
            **{
                name: cat_fn([getattr(r, name) for r in rots])
                for name in field_names
            }
        )
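from_two_vectors is the Gram-Schmidt construction used to build frames from atom positions: 'e0' becomes the x-axis and 'e1' is projected into the xy-plane. A short sketch (values arbitrary, not from the commit):

import torch
from openfold.utils.geometry import Rot3Array, Vec3Array

e0 = Vec3Array(torch.tensor(2.0), torch.tensor(0.0), torch.tensor(0.0))
e1 = Vec3Array(torch.tensor(1.0), torch.tensor(1.0), torch.tensor(0.0))
rot = Rot3Array.from_two_vectors(e0, e1)

# e0 maps onto the frame's positive x-axis, so the first column is (1, 0, 0)
print(rot.xx, rot.yx, rot.zx)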
openfold/utils/geometry/test_utils.py (new file, 0 → 100644)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utils for tests."""
import dataclasses

import torch

from openfold.utils.geometry import rigid_matrix_vector
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import vector


def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
                                 matrix2: rotation_matrix.Rot3Array):
    for field in dataclasses.fields(rotation_matrix.Rot3Array):
        field = field.name
        assert torch.equal(getattr(matrix1, field), getattr(matrix2, field))


def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
                                 mat2: rotation_matrix.Rot3Array):
    assert torch.allclose(mat1.to_tensor(), mat2.to_tensor(), atol=1e-6)


def assert_array_equal_to_rotation_matrix(array: torch.Tensor,
                                          matrix: rotation_matrix.Rot3Array):
    """Check that array and Matrix match."""
    assert torch.equal(matrix.xx, array[..., 0, 0])
    assert torch.equal(matrix.xy, array[..., 0, 1])
    assert torch.equal(matrix.xz, array[..., 0, 2])
    assert torch.equal(matrix.yx, array[..., 1, 0])
    assert torch.equal(matrix.yy, array[..., 1, 1])
    assert torch.equal(matrix.yz, array[..., 1, 2])
    assert torch.equal(matrix.zx, array[..., 2, 0])
    assert torch.equal(matrix.zy, array[..., 2, 1])
    assert torch.equal(matrix.zz, array[..., 2, 2])


def assert_array_close_to_rotation_matrix(array: torch.Tensor,
                                          matrix: rotation_matrix.Rot3Array):
    assert torch.allclose(matrix.to_tensor(), array, atol=1e-6)


def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
    assert torch.equal(vec1.x, vec2.x)
    assert torch.equal(vec1.y, vec2.y)
    assert torch.equal(vec1.z, vec2.z)


def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
    assert torch.allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.)
    assert torch.allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.)
    assert torch.allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.)


def assert_array_close_to_vector(array: torch.Tensor, vec: vector.Vec3Array):
    assert torch.allclose(vec.to_tensor(), array, atol=1e-6, rtol=0.)


def assert_array_equal_to_vector(array: torch.Tensor, vec: vector.Vec3Array):
    assert torch.equal(vec.to_tensor(), array)


def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
                                rigid2: rigid_matrix_vector.Rigid3Array):
    assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)


def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
                                rigid2: rigid_matrix_vector.Rigid3Array):
    assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)


def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
                                    trans: vector.Vec3Array,
                                    rigid: rigid_matrix_vector.Rigid3Array):
    assert_rotation_matrix_equal(rot, rigid.rotation)
    assert_vectors_equal(trans, rigid.translation)


def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
                                    trans: vector.Vec3Array,
                                    rigid: rigid_matrix_vector.Rigid3Array):
    assert_rotation_matrix_close(rot, rigid.rotation)
    assert_vectors_close(trans, rigid.translation)
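These assertions compare the struct-of-arrays objects field by field rather than via a single tensor comparison. A trivial sketch of how they might be called in a test (illustrative only):

import torch
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry.test_utils import assert_rotation_matrix_close

eye = rotation_matrix.Rot3Array.identity((4,), device="cpu")
assert_rotation_matrix_close(eye, eye)  # passes: all components agree to atol=1e-6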
openfold/utils/geometry/utils.py (new file, 0 → 100644)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for geometry library."""
import dataclasses


def get_field_names(cls):
    fields = dataclasses.fields(cls)
    field_names = [f.name for f in fields]
    return field_names
openfold/utils/geometry/vector.py (new file, 0 → 100644)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vec3Array Class."""
from __future__ import annotations

import dataclasses
from typing import Union, List

import torch

Float = Union[float, torch.Tensor]


@dataclasses.dataclass(frozen=True)
class Vec3Array:
    x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
    y: torch.Tensor
    z: torch.Tensor

    def __post_init__(self):
        if hasattr(self.x, 'dtype'):
            assert self.x.dtype == self.y.dtype
            assert self.x.dtype == self.z.dtype
            assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
            assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])

    def __add__(self, other: Vec3Array) -> Vec3Array:
        return Vec3Array(
            self.x + other.x,
            self.y + other.y,
            self.z + other.z,
        )

    def __sub__(self, other: Vec3Array) -> Vec3Array:
        return Vec3Array(
            self.x - other.x,
            self.y - other.y,
            self.z - other.z,
        )

    def __mul__(self, other: Float) -> Vec3Array:
        return Vec3Array(
            self.x * other,
            self.y * other,
            self.z * other,
        )

    def __rmul__(self, other: Float) -> Vec3Array:
        return self * other

    def __truediv__(self, other: Float) -> Vec3Array:
        return Vec3Array(
            self.x / other,
            self.y / other,
            self.z / other,
        )

    def __neg__(self) -> Vec3Array:
        return self * -1

    def __pos__(self) -> Vec3Array:
        return self * 1

    def __getitem__(self, index) -> Vec3Array:
        return Vec3Array(
            self.x[index],
            self.y[index],
            self.z[index],
        )

    def __iter__(self):
        return iter((self.x, self.y, self.z))

    @property
    def shape(self):
        return self.x.shape

    def map_tensor_fn(self, fn) -> Vec3Array:
        return Vec3Array(
            fn(self.x),
            fn(self.y),
            fn(self.z),
        )

    def cross(self, other: Vec3Array) -> Vec3Array:
        """Compute cross product between 'self' and 'other'."""
        new_x = self.y * other.z - self.z * other.y
        new_y = self.z * other.x - self.x * other.z
        new_z = self.x * other.y - self.y * other.x
        return Vec3Array(new_x, new_y, new_z)

    def dot(self, other: Vec3Array) -> Float:
        """Compute dot product between 'self' and 'other'."""
        return self.x * other.x + self.y * other.y + self.z * other.z

    def norm(self, epsilon: float = 1e-6) -> Float:
        """Compute Norm of Vec3Array, clipped to epsilon."""
        # To avoid NaN on the backward pass, we must use maximum before the sqrt
        norm2 = self.dot(self)
        if epsilon:
            norm2 = torch.clamp(norm2, min=epsilon**2)
        return torch.sqrt(norm2)

    def norm2(self):
        return self.dot(self)

    def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
        """Return unit vector with optional clipping."""
        return self / self.norm(epsilon)

    def clone(self) -> Vec3Array:
        return Vec3Array(
            self.x.clone(),
            self.y.clone(),
            self.z.clone(),
        )

    def reshape(self, new_shape) -> Vec3Array:
        x = self.x.reshape(new_shape)
        y = self.y.reshape(new_shape)
        z = self.z.reshape(new_shape)
        return Vec3Array(x, y, z)

    def sum(self, dim: int) -> Vec3Array:
        return Vec3Array(
            torch.sum(self.x, dim=dim),
            torch.sum(self.y, dim=dim),
            torch.sum(self.z, dim=dim),
        )

    def unsqueeze(self, dim: int):
        return Vec3Array(
            self.x.unsqueeze(dim),
            self.y.unsqueeze(dim),
            self.z.unsqueeze(dim),
        )

    @classmethod
    def zeros(cls, shape, device="cpu"):
        """Return Vec3Array corresponding to zeros of given shape."""
        return cls(
            torch.zeros(shape, dtype=torch.float32, device=device),
            torch.zeros(shape, dtype=torch.float32, device=device),
            torch.zeros(shape, dtype=torch.float32, device=device)
        )

    def to_tensor(self) -> torch.Tensor:
        return torch.stack([self.x, self.y, self.z], dim=-1)

    @classmethod
    def from_array(cls, tensor):
        return cls(*torch.unbind(tensor, dim=-1))

    @classmethod
    def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
        return cls(
            torch.cat([v.x for v in vecs], dim=dim),
            torch.cat([v.y for v in vecs], dim=dim),
            torch.cat([v.z for v in vecs], dim=dim),
        )


def square_euclidean_distance(vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6) -> Float:
    """Computes square of euclidean distance between 'vec1' and 'vec2'.

    Args:
        vec1: Vec3Array to compute distance to
        vec2: Vec3Array to compute distance from, should be
              broadcast compatible with 'vec1'
        epsilon: distance is clipped from below to be at least epsilon

    Returns:
        Array of square euclidean distances;
        shape will be result of broadcasting 'vec1' and 'vec2'
    """
    difference = vec1 - vec2
    distance = difference.dot(difference)
    if epsilon:
        distance = torch.clamp(distance, min=epsilon)
    return distance


def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
    return vector1.dot(vector2)


def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float:
    return vector1.cross(vector2)


def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
    return vector.norm(epsilon)


def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
    return vector.normalized(epsilon)


def euclidean_distance(vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6) -> Float:
    """Computes euclidean distance between 'vec1' and 'vec2'.

    Args:
        vec1: Vec3Array to compute euclidean distance to
        vec2: Vec3Array to compute euclidean distance from, should be
              broadcast compatible with 'vec1'
        epsilon: distance is clipped from below to be at least epsilon

    Returns:
        Array of euclidean distances;
        shape will be result of broadcasting 'vec1' and 'vec2'
    """
    distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
    distance = torch.sqrt(distance_sq)
    return distance


def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array) -> Float:
    """Computes torsion angle for a quadruple of points.

    For points (a, b, c, d), this is the angle between the planes defined by
    points (a, b, c) and (b, c, d). It is also known as the dihedral angle.

    Arguments:
        a: A Vec3Array of coordinates.
        b: A Vec3Array of coordinates.
        c: A Vec3Array of coordinates.
        d: A Vec3Array of coordinates.

    Returns:
        A tensor of angles in radians: [-pi, pi].
    """
    v1 = a - b
    v2 = b - c
    v3 = d - c

    c1 = v1.cross(v2)
    c2 = v3.cross(v2)
    c3 = c2.cross(c1)

    v2_mag = v2.norm()
    return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
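dihedral_angle is the standard atan2 torsion formula over three bond vectors. A worked sketch producing a +pi/2 dihedral from four hand-picked points (values invented for illustration, not from the commit):

import torch
from openfold.utils.geometry.vector import Vec3Array, dihedral_angle

a = Vec3Array(torch.tensor(1.0), torch.tensor(0.0), torch.tensor(0.0))
b = Vec3Array(torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0))
c = Vec3Array(torch.tensor(0.0), torch.tensor(1.0), torch.tensor(0.0))
d = Vec3Array(torch.tensor(0.0), torch.tensor(1.0), torch.tensor(-1.0))
print(dihedral_angle(a, b, c, d))  # ~ pi/2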
openfold/utils/import_weights.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 from enum import Enum
 from dataclasses import dataclass
 from functools import partial

@@ -27,7 +28,7 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
 # With Param, a poor man's enum with attributes (Rust-style)
 class ParamType(Enum):
     LinearWeight = partial(  # hack: partial prevents fns from becoming methods
-        lambda w: w.transpose(-1, -2)
+        lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else w.transpose(-1, -2)
     )
     LinearWeightMHA = partial(
         lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)

@@ -39,6 +40,13 @@ class ParamType(Enum):
     LinearWeightOPM = partial(
         lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
     )
+    LinearWeightMultimer = partial(
+        lambda w: w.unsqueeze(-1)
+        if len(w.shape) == 1
+        else w.reshape(w.shape[0], -1).transpose(-1, -2)
+    )
+    LinearBiasMultimer = partial(lambda w: w.reshape(-1))
     Other = partial(lambda w: w)

     def __init__(self, fn):

@@ -50,6 +58,7 @@ class Param:
     param: Union[torch.Tensor, List[torch.Tensor]]
     param_type: ParamType = ParamType.Other
     stacked: bool = False
+    swap: bool = False


 def process_translation_dict(d, top_layer=True):

@@ -93,6 +102,7 @@ def stacked(param_dict_list, out=None):
                 param=[param.param for param in v],
                 param_type=v[0].param_type,
                 stacked=True,
+                swap=v[0].swap,
             )
             out[k] = stacked_param

@@ -114,6 +124,11 @@ def assign(translation_dict, orig_weights):
         try:
             weights = list(map(param_type.transformation, weights))
             for p, w in zip(ref, weights):
-                p.copy_(w)
+                if param.swap:
+                    index = p.shape[0] // 2
+                    p[:index].copy_(w[index:])
+                    p[index:].copy_(w[:index])
+                else:
+                    p.copy_(w)
         except:
             print(k)

@@ -122,26 +137,44 @@ def assign(translation_dict, orig_weights):
             raise


-def generate_translation_dict(model, version):
+def generate_translation_dict(model, version, is_multimer=False):
     #######################
     # Some templates
     #######################
     LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
     LinearBias = lambda l: (Param(l))
     LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
     LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
     LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
+    LinearWeightMultimer = lambda l: (
+        Param(l, param_type=ParamType.LinearWeightMultimer)
+    )
+    LinearBiasMultimer = lambda l: (
+        Param(l, param_type=ParamType.LinearBiasMultimer)
+    )
+    LinearWeightSwap = lambda l: (
+        Param(l, param_type=ParamType.LinearWeight, swap=True)
+    )
+    LinearBiasSwap = lambda l: (Param(l, swap=True))

     LinearParams = lambda l: {
         "weights": LinearWeight(l.weight),
         "bias": LinearBias(l.bias),
     }

     LinearParamsMHA = lambda l: {
         "weights": LinearWeightMHA(l.weight),
         "bias": LinearBiasMHA(l.bias),
     }

+    LinearParamsSwap = lambda l: {
+        "weights": LinearWeightSwap(l.weight),
+        "bias": LinearBiasSwap(l.bias),
+    }
+
+    LinearParamsMultimer = lambda l: {
+        "weights": LinearWeightMultimer(l.weight),
+        "bias": LinearBiasMultimer(l.bias),
+    }
+
     LayerNormParams = lambda l: {
         "scale": Param(l.weight),
         "offset": Param(l.bias),

@@ -178,31 +211,48 @@ def generate_translation_dict(model, version):
         "attention": AttentionGatedParams(tri_att.mha),
     }

-    TriMulOutParams = lambda tri_mul: {
-        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
-        "left_projection": LinearParams(tri_mul.linear_a_p),
-        "right_projection": LinearParams(tri_mul.linear_b_p),
-        "left_gate": LinearParams(tri_mul.linear_a_g),
-        "right_gate": LinearParams(tri_mul.linear_b_g),
-        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
-        "output_projection": LinearParams(tri_mul.linear_z),
-        "gating_linear": LinearParams(tri_mul.linear_g),
-    }
-
-    TriMulInParams = lambda tri_mul: {
-        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
-        "left_projection": LinearParams(tri_mul.linear_b_p),
-        "right_projection": LinearParams(tri_mul.linear_a_p),
-        "left_gate": LinearParams(tri_mul.linear_b_g),
-        "right_gate": LinearParams(tri_mul.linear_a_g),
-        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
-        "output_projection": LinearParams(tri_mul.linear_z),
-        "gating_linear": LinearParams(tri_mul.linear_g),
-    }
+    def TriMulOutParams(tri_mul, outgoing=True):
+        if re.fullmatch("^model_[1-5]_multimer_v3$", version):
+            lin_param_type = LinearParams if outgoing else LinearParamsSwap
+            d = {
+                "left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+                "projection": lin_param_type(tri_mul.linear_ab_p),
+                "gate": lin_param_type(tri_mul.linear_ab_g),
+                "center_norm": LayerNormParams(tri_mul.layer_norm_out),
+            }
+        else:
+            # see commit b88f8da on the Alphafold repo
+            # Alphafold swaps the pseudocode's a and b between the incoming/outcoming
+            # iterations of triangle multiplication, which is confusing and not
+            # reproduced in our implementation.
+            if outgoing:
+                left_projection = LinearParams(tri_mul.linear_a_p)
+                right_projection = LinearParams(tri_mul.linear_b_p)
+                left_gate = LinearParams(tri_mul.linear_a_g)
+                right_gate = LinearParams(tri_mul.linear_b_g)
+            else:
+                left_projection = LinearParams(tri_mul.linear_b_p)
+                right_projection = LinearParams(tri_mul.linear_a_p)
+                left_gate = LinearParams(tri_mul.linear_b_g)
+                right_gate = LinearParams(tri_mul.linear_a_g)
+
+            d = {
+                "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
+                "left_projection": left_projection,
+                "right_projection": right_projection,
+                "left_gate": left_gate,
+                "right_gate": right_gate,
+                "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
+            }
+
+        d.update({
+            "output_projection": LinearParams(tri_mul.linear_z),
+            "gating_linear": LinearParams(tri_mul.linear_g),
+        })
+
+        return d
+
+    TriMulInParams = partial(TriMulOutParams, outgoing=False)

     PairTransitionParams = lambda pt: {
         "input_layer_norm": LayerNormParams(pt.layer_norm),

@@ -236,8 +286,46 @@ def generate_translation_dict(model, version):
     IPAParams = lambda ipa: {
         "q_scalar": LinearParams(ipa.linear_q),
         "kv_scalar": LinearParams(ipa.linear_kv),
-        "q_point_local": LinearParams(ipa.linear_q_points),
-        "kv_point_local": LinearParams(ipa.linear_kv_points),
+        "q_point_local": LinearParams(ipa.linear_q_points.linear),
+        "kv_point_local": LinearParams(ipa.linear_kv_points.linear),
         "trainable_point_weights": Param(
             param=ipa.head_weights, param_type=ParamType.Other
         ),
         "attention_2d": LinearParams(ipa.linear_b),
         "output_projection": LinearParams(ipa.linear_out),
     }

+    PointProjectionParams = lambda pp: {
+        "point_projection": LinearParamsMHA(
+            pp.linear,
+        ),
+    }
+
+    IPAParamsMultimer = lambda ipa: {
+        "q_scalar_projection": {
+            "weights": LinearWeightMHA(
+                ipa.linear_q.weight,
+            ),
+        },
+        "k_scalar_projection": {
+            "weights": LinearWeightMHA(
+                ipa.linear_k.weight,
+            ),
+        },
+        "v_scalar_projection": {
+            "weights": LinearWeightMHA(
+                ipa.linear_v.weight,
+            ),
+        },
+        "q_point_projection": PointProjectionParams(ipa.linear_q_points),
+        "k_point_projection": PointProjectionParams(ipa.linear_k_points),
+        "v_point_projection": PointProjectionParams(ipa.linear_v_points),
+        "trainable_point_weights": Param(
+            param=ipa.head_weights, param_type=ParamType.Other
+        ),

@@ -280,27 +368,29 @@ def generate_translation_dict(model, version):
                 b.msa_att_row),
             col_att_name: msa_col_att_params,
-            "msa_transition": MSATransitionParams(b.core.msa_transition),
+            "msa_transition": MSATransitionParams(b.msa_transition),
-            "outer_product_mean": OuterProductMeanParams(b.core.outer_product_mean),
+            "outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
-            "triangle_multiplication_outgoing": TriMulOutParams(b.core.tri_mul_out),
+            "triangle_multiplication_outgoing": TriMulOutParams(b.pair_stack.tri_mul_out),
-            "triangle_multiplication_incoming": TriMulInParams(b.core.tri_mul_in),
+            "triangle_multiplication_incoming": TriMulInParams(b.pair_stack.tri_mul_in),
-            "triangle_attention_starting_node": TriAttParams(b.core.tri_att_start),
+            "triangle_attention_starting_node": TriAttParams(b.pair_stack.tri_att_start),
-            "triangle_attention_ending_node": TriAttParams(b.core.tri_att_end),
+            "triangle_attention_ending_node": TriAttParams(b.pair_stack.tri_att_end),
-            "pair_transition": PairTransitionParams(b.core.pair_transition),
+            "pair_transition": PairTransitionParams(b.pair_stack.pair_transition),
         }
         return d

     ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)

-    FoldIterationParams = lambda sm: {
-        "invariant_point_attention": IPAParams(sm.ipa),
+    def FoldIterationParams(sm):
+        d = {
+            "invariant_point_attention": (
+                IPAParamsMultimer(sm.ipa) if is_multimer else IPAParams(sm.ipa)
+            ),
             "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
             "transition": LinearParams(sm.transition.layers[0].linear_1),
             "transition_1": LinearParams(sm.transition.layers[0].linear_2),

@@ -309,25 +399,39 @@ def generate_translation_dict(model, version):
             "affine_update": LinearParams(sm.bb_update.linear),
             "rigid_sidechain": {
                 "input_projection": LinearParams(sm.angle_resnet.linear_in),
-                "input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
+                "input_projection_1": LinearParams(
+                    sm.angle_resnet.linear_initial
+                ),
                 "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
                 "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
-                "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
-                "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
-                "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
+                "resblock1_1": LinearParams(
+                    sm.angle_resnet.layers[1].linear_1
+                ),
+                "resblock2_1": LinearParams(
+                    sm.angle_resnet.layers[1].linear_2
+                ),
+                "unnormalized_angles": LinearParams(
+                    sm.angle_resnet.linear_out
+                ),
             },
         }

+        if (is_multimer):
+            d.pop("affine_update")
+            d["quat_rigid"] = {"rigid": LinearParams(sm.bb_update.linear)}
+
+        return d

     ############################
     # translations dict overflow
     ############################

     ems_blocks = model.extra_msa_stack.blocks
     ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

     evo_blocks = model.evoformer.blocks
     evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

+    if (not is_multimer):
         translations = {
             "evoformer": {
                 "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),

@@ -383,6 +487,64 @@ def generate_translation_dict(model, version):
                 "logits": LinearParams(model.aux_heads.masked_msa.linear),
             },
         }
+    else:
+        translations = {
+            "evoformer": {
+                "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
+                "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
+                "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
+                "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
+                "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
+                "prev_msa_first_row_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_m
+                ),
+                "prev_pair_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_z
+                ),
+                "~_relative_encoding": {
+                    "position_activations": LinearParams(
+                        model.input_embedder.linear_relpos
+                    ),
+                },
+                "extra_msa_activations": LinearParams(
+                    model.extra_msa_embedder.linear
+                ),
+                "extra_msa_stack": ems_blocks_params,
+                "evoformer_iteration": evo_blocks_params,
+                "single_activations": LinearParams(model.evoformer.linear),
+            },
+            "structure_module": {
+                "single_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_s
+                ),
+                "initial_projection": LinearParams(
+                    model.structure_module.linear_in
+                ),
+                "pair_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_z
+                ),
+                "fold_iteration": FoldIterationParams(model.structure_module),
+            },
+            "predicted_lddt_head": {
+                "input_layer_norm": LayerNormParams(
+                    model.aux_heads.plddt.layer_norm
+                ),
+                "act_0": LinearParams(model.aux_heads.plddt.linear_1),
+                "act_1": LinearParams(model.aux_heads.plddt.linear_2),
+                "logits": LinearParams(model.aux_heads.plddt.linear_3),
+            },
+            "distogram_head": {
+                "half_logits": LinearParams(model.aux_heads.distogram.linear),
+            },
+            "experimentally_resolved_head": {
+                "logits": LinearParams(
+                    model.aux_heads.experimentally_resolved.linear
+                ),
+            },
+            "masked_msa_head": {
+                "logits": LinearParams(model.aux_heads.masked_msa.linear),
+            },
+        }

     no_templ = [
         "model_3",

@@ -394,48 +556,98 @@ def generate_translation_dict(model, version):
     ]

     if version not in no_templ:
-        tps_blocks = model.template_pair_stack.blocks
+        tps_blocks = model.template_embedder.template_pair_stack.blocks
         tps_blocks_params = stacked(
             [TemplatePairBlockParams(b) for b in tps_blocks]
         )
+        if (not is_multimer):
             template_param_dict = {
                 "template_embedding": {
                     "single_template_embedding": {
                         "embedding2d": LinearParams(
-                            model.template_pair_embedder.linear
+                            model.template_embedder.template_pair_embedder.linear
                         ),
                         "template_pair_stack": {
                             "__layer_stack_no_state": tps_blocks_params,
                         },
                         "output_layer_norm": LayerNormParams(
-                            model.template_pair_stack.layer_norm
+                            model.template_embedder.template_pair_stack.layer_norm
                         ),
                     },
-                    "attention": AttentionParams(model.template_pointwise_att.mha),
+                    "attention": AttentionParams(
+                        model.template_embedder.template_pointwise_att.mha
+                    ),
                 },
                 "template_single_embedding": LinearParams(
-                    model.template_angle_embedder.linear_1
+                    model.template_embedder.template_single_embedder.linear_1
                 ),
                 "template_projection": LinearParams(
-                    model.template_angle_embedder.linear_2
+                    model.template_embedder.template_single_embedder.linear_2
                 ),
             }
+        else:
+            temp_embedder = model.template_embedder
+            template_param_dict = {
+                "template_embedding": {
+                    "single_template_embedding": {
+                        "query_embedding_norm": LayerNormParams(
+                            temp_embedder.template_pair_embedder.query_embedding_layer_norm
+                        ),
+                        "template_pair_embedding_0": LinearParams(
+                            temp_embedder.template_pair_embedder.dgram_linear
+                        ),
+                        "template_pair_embedding_1": LinearParams(
+                            temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
+                        ),
+                        "template_pair_embedding_2": LinearParams(
+                            temp_embedder.template_pair_embedder.aatype_linear_1
+                        ),
+                        "template_pair_embedding_3": LinearParams(
+                            temp_embedder.template_pair_embedder.aatype_linear_2
+                        ),
+                        "template_pair_embedding_4": LinearParams(
+                            temp_embedder.template_pair_embedder.x_linear
+                        ),
+                        "template_pair_embedding_5": LinearParams(
+                            temp_embedder.template_pair_embedder.y_linear
+                        ),
+                        "template_pair_embedding_6": LinearParams(
+                            temp_embedder.template_pair_embedder.z_linear
+                        ),
+                        "template_pair_embedding_7": LinearParams(
+                            temp_embedder.template_pair_embedder.backbone_mask_linear
+                        ),
+                        "template_pair_embedding_8": LinearParams(
+                            temp_embedder.template_pair_embedder.query_embedding_linear
+                        ),
+                        "template_embedding_iteration": tps_blocks_params,
+                        "output_layer_norm": LayerNormParams(
+                            temp_embedder.template_pair_stack.layer_norm
+                        ),
+                    },
+                    "output_linear": LinearParams(temp_embedder.linear_t),
+                },
+                "template_projection": LinearParams(
+                    temp_embedder.template_single_embedder.template_projector,
+                ),
+                "template_single_embedding": LinearParams(
+                    temp_embedder.template_single_embedder.template_single_embedder,
+                ),
+            }

         translations["evoformer"].update(template_param_dict)

-    if "_ptm" in version:
+    if is_multimer or "_ptm" in version:
         translations["predicted_aligned_error_head"] = {
             "logits": LinearParams(model.aux_heads.tm.linear)
         }

     return translations


 def import_jax_weights_(model, npz_path, version="model_1"):
     data = np.load(npz_path)

-    translations = generate_translation_dict(model, version)
+    translations = generate_translation_dict(
+        model, version, is_multimer=("multimer" in version)
+    )

     # Flatten keys and insert missing key prefixes
     flat = process_translation_dict(translations)

@@ -453,3 +665,45 @@ def import_jax_weights_(model, npz_path, version="model_1"):
     # Set weights
     assign(flat, data)

+
+def convert_deprecated_v1_keys(state_dict):
+    """Update older OpenFold model weight names to match the current model code."""
+    replacements = {
+        'template_angle_embedder': 'template_single_embedder',
+        'core.msa_transition': 'msa_transition',
+        'core.outer_product_mean': 'outer_product_mean',
+        'core.tri_': 'pair_stack.tri_',
+        'core.pair_transition': 'pair_stack.pair_transition',
+        'ipa.linear_q_points': 'ipa.linear_q_points.linear',
+        'ipa.linear_kv_points': 'ipa.linear_kv_points.linear'
+    }
+
+    convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
+
+    converted_state_dict = {}
+    for key, value in state_dict.items():
+        # For each match, look-up replacement value in the dictionary
+        new_key = convert_key_re.sub(
+            lambda m: replacements[m.group()], key
+        )
+
+        # Add prefix for template modules
+        if new_key.startswith('template'):
+            new_key = f'template_embedder.{new_key}'
+
+        converted_state_dict[new_key] = value
+
+    return converted_state_dict
+
+
+def import_openfold_weights_(model, state_dict):
+    """
+    Import model weights. Several parts of the model were refactored in the process
+    of adding support for Multimer. The state dicts of older models are translated
+    to match the refactored model code.
+    """
+    try:
+        model.load_state_dict(state_dict)
+    except RuntimeError:
+        converted_state_dict = convert_deprecated_v1_keys(state_dict)
+        model.load_state_dict(converted_state_dict)
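The new import_openfold_weights_ entry point first tries a plain load_state_dict and only falls back to key translation when an older checkpoint no longer matches the refactored module tree. A sketch of the renaming convert_deprecated_v1_keys performs (keys are hypothetical examples):

from openfold.utils.import_weights import convert_deprecated_v1_keys

old_sd = {
    "evoformer.blocks.0.core.msa_transition.linear_1.weight": 0,
    "template_angle_embedder.linear_1.bias": 0,
}
new_sd = convert_deprecated_v1_keys(old_sd)
print(list(new_sd))
# ['evoformer.blocks.0.msa_transition.linear_1.weight',
#  'template_embedder.template_single_embedder.linear_1.bias']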
openfold/utils/loss.py
...
...
@@ -13,25 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
functools
import
partial
import
logging
import
ml_collections
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
torch.distributions.bernoulli
import
Bernoulli
from
typing
import
Dict
,
Optional
,
Tuple
from
openfold.np
import
residue_constants
from
openfold.utils
import
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.vector
import
Vec3Array
,
euclidean_distance
from
openfold.utils.all_atom_multimer
import
get_rc_tensor
from
openfold.utils.tensor_utils
import
(
tree_map
,
tensor_tree_map
,
masked_mean
,
permute_final_dims
,
batched_gather
,
)
import
logging
from
openfold.utils.tensor_utils
import
tensor_tree_map
logger
=
logging
.
getLogger
(
__name__
)
def
softmax_cross_entropy
(
logits
,
labels
):
...
...
@@ -87,6 +87,7 @@ def compute_fape(
target_positions
:
torch
.
Tensor
,
positions_mask
:
torch
.
Tensor
,
length_scale
:
float
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
l1_clamp_distance
:
Optional
[
float
]
=
None
,
eps
=
1e-8
,
)
->
torch
.
Tensor
:
...
...
@@ -108,6 +109,9 @@ def compute_fape(
[*, N_pts] positions mask
length_scale:
Length scale by which the loss is divided
pair_mask:
[*, N_frames, N_pts] mask to use for
separating intra- from inter-chain losses.
l1_clamp_distance:
Cutoff above which distance errors are disregarded
eps:
...
...
@@ -134,6 +138,15 @@ def compute_fape(
normed_error
=
normed_error
*
frames_mask
[...,
None
]
normed_error
=
normed_error
*
positions_mask
[...,
None
,
:]
if
pair_mask
is
not
None
:
normed_error
=
normed_error
*
pair_mask
normed_error
=
torch
.
sum
(
normed_error
,
dim
=
(
-
1
,
-
2
))
mask
=
frames_mask
[...,
None
]
*
positions_mask
[...,
None
,
:]
*
pair_mask
norm_factor
=
torch
.
sum
(
mask
,
dim
=
(
-
2
,
-
1
))
normed_error
=
normed_error
/
(
eps
+
norm_factor
)
else
:
# FP16-friendly averaging. Roughly equivalent to:
#
# norm_factor = (
...
...
@@ -157,13 +170,19 @@ def backbone_loss(
backbone_rigid_tensor
:
torch
.
Tensor
,
backbone_rigid_mask
:
torch
.
Tensor
,
traj
:
torch
.
Tensor
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
use_clamped_fape
:
Optional
[
torch
.
Tensor
]
=
None
,
clamp_distance
:
float
=
10.0
,
loss_unit_distance
:
float
=
10.0
,
eps
:
float
=
1e-4
,
**
kwargs
,
)
->
torch
.
Tensor
:
### need to check if the traj belongs to 4*4 matrix or a tensor_7
if
traj
.
shape
[
-
1
]
==
7
:
pred_aff
=
Rigid
.
from_tensor_7
(
traj
)
elif
traj
.
shape
[
-
1
]
==
4
:
pred_aff
=
Rigid
.
from_tensor_4x4
(
traj
)
pred_aff
=
Rigid
(
Rotation
(
rot_mats
=
pred_aff
.
get_rots
().
get_rot_mats
(),
quats
=
None
),
pred_aff
.
get_trans
(),
...
...
@@ -184,6 +203,7 @@ def backbone_loss(
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_rigid_mask
[
None
],
pair_mask
=
pair_mask
,
l1_clamp_distance
=
clamp_distance
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
...
...
@@ -196,6 +216,7 @@ def backbone_loss(
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_rigid_mask
[
None
],
pair_mask
=
pair_mask
,
l1_clamp_distance
=
None
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
...
...
@@ -253,6 +274,7 @@ def sidechain_loss(
sidechain_atom_pos
,
renamed_atom14_gt_positions
,
renamed_atom14_gt_exists
,
pair_mask
=
None
,
l1_clamp_distance
=
clamp_distance
,
length_scale
=
length_scale
,
eps
=
eps
,
...
...
@@ -266,10 +288,28 @@ def fape_loss(
batch
:
Dict
[
str
,
torch
.
Tensor
],
config
:
ml_collections
.
ConfigDict
,
)
->
torch
.
Tensor
:
traj
=
out
[
"sm"
][
"frames"
]
asym_id
=
batch
.
get
(
"asym_id"
)
if
asym_id
is
not
None
:
intra_chain_mask
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]).
to
(
dtype
=
traj
.
dtype
)
intra_chain_bb_loss
=
backbone_loss
(
traj
=
traj
,
pair_mask
=
intra_chain_mask
,
**
{
**
batch
,
**
config
.
intra_chain_backbone
},
)
interface_bb_loss
=
backbone_loss
(
traj
=
traj
,
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
config
.
interface_backbone
},
)
weighted_bb_loss
=
(
intra_chain_bb_loss
*
config
.
intra_chain_backbone
.
weight
+
interface_bb_loss
*
config
.
interface_backbone
.
weight
)
else
:
bb_loss
=
backbone_loss
(
traj
=
out
[
"sm"
][
"frames"
]
,
traj
=
traj
,
**
{
**
batch
,
**
config
.
backbone
},
)
weighted_bb_loss
=
bb_loss
*
config
.
backbone
.
weight
sc_loss
=
sidechain_loss
(
out
[
"sm"
][
"sidechain_frames"
],
...
...
@@ -277,7 +317,7 @@ def fape_loss(
**
{
**
batch
,
**
config
.
sidechain
},
)
loss
=
config
.
backbone
.
weight
*
bb_loss
+
config
.
sidechain
.
weight
*
sc_loss
loss
=
weight
ed_
bb_loss
+
config
.
sidechain
.
weight
*
sc_loss
# Average over the batch dimension
loss
=
torch
.
mean
(
loss
)
...
...
@@ -452,7 +492,7 @@ def lddt_ca(
ca_pos
=
residue_constants
.
atom_order
[
"CA"
]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
return
lddt
(
all_atom_pred_pos
,
...
...
@@ -482,7 +522,7 @@ def lddt_loss(
ca_pos
=
residue_constants
.
atom_order
[
"CA"
]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
score
=
lddt
(
all_atom_pred_pos
,
...
...
@@ -492,8 +532,11 @@ def lddt_loss(
eps
=
eps
)
score
=
score
.
detach
()
# TODO: Remove after initial pipeline testing
score
=
torch
.
nan_to_num
(
score
,
nan
=
torch
.
nanmean
(
score
))
score
[
score
<
0
]
=
0
score
=
score
.
detach
()
bin_index
=
torch
.
floor
(
score
*
no_bins
).
long
()
bin_index
=
torch
.
clamp
(
bin_index
,
max
=
(
no_bins
-
1
))
lddt_ca_one_hot
=
torch
.
nn
.
functional
.
one_hot
(
...
...
@@ -627,6 +670,8 @@ def compute_predicted_aligned_error(
def
compute_tm
(
logits
:
torch
.
Tensor
,
residue_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
asym_id
:
Optional
[
torch
.
Tensor
]
=
None
,
interface
:
bool
=
False
,
max_bin
:
int
=
31
,
no_bins
:
int
=
64
,
eps
:
float
=
1e-8
,
...
...
@@ -649,7 +694,22 @@ def compute_tm(
    tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

-   normed_residue_mask = residue_weights / (eps + residue_weights.sum())
+   n = residue_weights.shape[-1]
+   pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
+   if interface and (asym_id is not None):
+       if len(asym_id.shape) > 1:
+           assert len(asym_id.shape) <= 2
+           batch_size = asym_id.shape[0]
+           pair_mask = residue_weights.new_ones((batch_size, n, n), dtype=torch.int32)
+       pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
+
+   predicted_tm_term *= pair_mask
+
+   pair_residue_weights = pair_mask * (
+       residue_weights[..., None, :] * residue_weights[..., :, None]
+   )
+   denom = eps + torch.sum(pair_residue_weights, dim=-1, keepdims=True)
+   normed_residue_mask = pair_residue_weights / denom
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    weighted = per_alignment * residue_weights
...
...
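With the new asym_id and interface arguments, the same logits can now yield both pTM and interface pTM (ipTM). A minimal usage sketch (random tensors for illustration only; assumes the merged openfold.utils.loss module is importable):

import torch
from openfold.utils.loss import compute_tm

n = 100
logits = torch.randn(n, n, 64)  # pairwise aligned-error logits
asym_id = torch.cat([torch.zeros(60), torch.ones(40)])  # two chains

ptm = compute_tm(logits)  # standard pTM
iptm = compute_tm(logits, asym_id=asym_id, interface=True)  # interface pTM
# Multimer rankings commonly combine the two, e.g. 0.8 * iptm + 0.2 * ptm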
@@ -671,7 +731,11 @@ def tm_loss(
    eps=1e-8,
    **kwargs,
):
-   pred_affine = Rigid.from_tensor_7(final_affine_tensor)
+   # first check whether this is a tensor_7 or tensor_4x4
+   if final_affine_tensor.shape[-1] == 7:
+       pred_affine = Rigid.from_tensor_7(final_affine_tensor)
+   elif final_affine_tensor.shape[-1] == 4:
+       pred_affine = Rigid.from_tensor_4x4(final_affine_tensor)

    backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)

    def _points(affine):
...
...
@@ -709,7 +773,7 @@ def tm_loss(
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

-   # Average over the loss dimension
+   # Average over the batch dimension
    loss = torch.mean(loss)

    return loss
...
...
@@ -784,6 +848,7 @@ def between_residue_bond_loss(
    ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1]

    c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
    c_n_loss_per_residue = torch.nn.functional.relu(
        c_n_bond_length_error - tolerance_factor_soft * gt_stddev
...
...
@@ -879,6 +944,7 @@ def between_residue_clash_loss(
    atom14_atom_exists: torch.Tensor,
    atom14_atom_radius: torch.Tensor,
    residue_index: torch.Tensor,
+   asym_id: Optional[torch.Tensor] = None,
    overlap_tolerance_soft=1.5,
    overlap_tolerance_hard=1.5,
    eps=1e-10,
...
...
@@ -908,7 +974,6 @@ def between_residue_clash_loss(
            shape (N, 14)
    """
    fp_type = atom14_pred_positions.dtype
-
    # Create the distance matrix.
    # (N, N, 14, 14)
    dists = torch.sqrt(
...
...
@@ -954,9 +1019,13 @@ def between_residue_clash_loss(
    )
    n_one_hot = n_one_hot.type(fp_type)

-   neighbour_mask = (
-       residue_index[..., :, None, None, None] + 1
-   ) == residue_index[..., None, :, None, None]
+   neighbour_mask = (residue_index[..., :, None] + 1) == residue_index[..., None, :]
+   if asym_id is not None:
+       neighbour_mask = neighbour_mask & (asym_id[..., :, None] == asym_id[..., None, :])
+   neighbour_mask = neighbour_mask[..., None, None]

    c_n_bonds = (
        neighbour_mask
        * c_one_hot[..., None, None, :, None]
...
...
@@ -998,7 +1067,7 @@ def between_residue_clash_loss(
    # Compute the per atom loss sum.
    # shape (N, 14)
    per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
-       dists_to_low_error, axis=(-3, -1)
+       dists_to_low_error, dim=(-3, -1)
    )

    # Compute the hard clash mask.
...
...
@@ -1007,17 +1076,20 @@ def between_residue_clash_loss(
        dists < (dists_lower_bound - overlap_tolerance_hard)
    )

+   per_atom_num_clash = torch.sum(clash_mask, dim=(-4, -2)) + torch.sum(
+       clash_mask, dim=(-3, -1)
+   )
+
    # Compute the per atom clash.
    # shape (N, 14)
    per_atom_clash_mask = torch.maximum(
-       torch.amax(clash_mask, axis=(-4, -2)),
-       torch.amax(clash_mask, axis=(-3, -1)),
+       torch.amax(clash_mask, dim=(-4, -2)),
+       torch.amax(clash_mask, dim=(-3, -1)),
    )

    return {
        "mean_loss": mean_loss,  # shape ()
        "per_atom_loss_sum": per_atom_loss_sum,  # shape (N, 14)
        "per_atom_clash_mask": per_atom_clash_mask,  # shape (N, 14)
+       "per_atom_num_clash": per_atom_num_clash,  # shape (N, 14)
    }
...
...
@@ -1097,6 +1169,8 @@ def within_residue_violations(
        (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
    )

+   per_atom_num_clash = torch.sum(violations, dim=-2) + torch.sum(violations, dim=-1)
+
    # Compute the per atom violations.
    per_atom_violations = torch.maximum(
        torch.max(violations, dim=-2)[0], torch.max(violations, axis=-1)[0]
...
...
@@ -1105,6 +1179,7 @@ def within_residue_violations(
    return {
        "per_atom_loss_sum": per_atom_loss_sum,
        "per_atom_violations": per_atom_violations,
+       "per_atom_num_clash": per_atom_num_clash
    }
...
...
@@ -1134,7 +1209,20 @@ def find_structural_violations(
        residue_constants.van_der_waals_radius[name[0]]
        for name in residue_constants.atom_types
    ]
    atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)

+   # TODO: Consolidate monomer/multimer modes
+   asym_id = batch.get("asym_id")
+   if asym_id is not None:
+       residx_atom14_to_atom37 = get_rc_tensor(
+           residue_constants.RESTYPE_ATOM14_TO_ATOM37, batch["aatype"]
+       )
+       atom14_atom_radius = (
+           batch["atom14_atom_exists"]
+           * atomtype_radius[residx_atom14_to_atom37.long()]
+       )
+   else:
+       atom14_atom_radius = (
+           batch["atom14_atom_exists"]
+           * atomtype_radius[batch["residx_atom14_to_atom37"]]
...
...
@@ -1146,6 +1234,7 @@ def find_structural_violations(
        atom14_atom_exists=batch["atom14_atom_exists"],
        atom14_atom_radius=atom14_atom_radius,
        residue_index=batch["residue_index"],
+       asym_id=asym_id,
        overlap_tolerance_soft=clash_overlap_tolerance,
        overlap_tolerance_hard=clash_overlap_tolerance,
    )
...
...
@@ -1208,6 +1297,9 @@ def find_structural_violations(
"clashes_per_atom_clash_mask"
:
between_residue_clashes
[
"per_atom_clash_mask"
],
# (N, 14)
"clashes_per_atom_num_clash"
:
between_residue_clashes
[
"per_atom_num_clash"
],
# (N, 14)
},
"within_residues"
:
{
"per_atom_loss_sum"
:
residue_violations
[
...
...
@@ -1216,6 +1308,9 @@ def find_structural_violations(
"per_atom_violations"
:
residue_violations
[
"per_atom_violations"
],
# (N, 14),
"per_atom_num_clash"
:
residue_violations
[
"per_atom_num_clash"
],
# (N, 14)
},
"total_per_residue_violations_mask"
:
per_residue_violations_mask
,
# (N)
}
...
...
@@ -1337,15 +1432,21 @@ def compute_violation_metrics_np(
def violation_loss(
    violations: Dict[str, torch.Tensor],
    atom14_atom_exists: torch.Tensor,
+   average_clashes: bool = False,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    num_atoms = torch.sum(atom14_atom_exists)

-   l_clash = torch.sum(
-       violations["between_residues"]["clashes_per_atom_loss_sum"]
-       + violations["within_residues"]["per_atom_loss_sum"]
-   )
-   l_clash = l_clash / (eps + num_atoms)
+   per_atom_clash = (
+       violations["between_residues"]["clashes_per_atom_loss_sum"]
+       + violations["within_residues"]["per_atom_loss_sum"]
+   )
+
+   if average_clashes:
+       num_clash = (
+           violations["between_residues"]["clashes_per_atom_num_clash"]
+           + violations["within_residues"]["per_atom_num_clash"]
+       )
+       per_atom_clash = per_atom_clash / (num_clash + eps)
+
+   l_clash = torch.sum(per_atom_clash) / (eps + num_atoms)

    loss = (
        violations["between_residues"]["bonds_c_n_loss_mean"]
        + violations["between_residues"]["angles_ca_c_n_loss_mean"]
...
...
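The new average_clashes flag switches the clash term from a per-atom sum to a per-atom mean over the number of clashing partners, which keeps heavily clashing atoms from dominating the loss. A sketch with hypothetical numbers (not from this commit):

import torch

per_atom_loss_sum = torch.tensor([0.9, 0.3])   # summed clash penalty per atom
per_atom_num_clash = torch.tensor([3.0, 1.0])  # clashing partners per atom
eps = 1e-6

default_term = per_atom_loss_sum                                 # average_clashes=False
averaged_term = per_atom_loss_sum / (per_atom_num_clash + eps)   # average_clashes=True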
@@ -1491,7 +1592,7 @@ def experimentally_resolved_loss(
    return loss


-def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
+def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs):
    """
    Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
...
...
@@ -1503,7 +1604,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
        Masked MSA loss
    """
    errors = softmax_cross_entropy(
-       logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
+       logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
    )

    # FP16-friendly averaging. Equivalent to:
...
...
@@ -1524,13 +1625,75 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
    return loss


+def chain_center_of_mass_loss(
+    all_atom_pred_pos: torch.Tensor,
+    all_atom_positions: torch.Tensor,
+    all_atom_mask: torch.Tensor,
+    asym_id: torch.Tensor,
+    clamp_distance: float = -4.0,
+    weight: float = 0.05,
+    eps: float = 1e-10,
+    **kwargs
+) -> torch.Tensor:
+    """
+    Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
+
+    Args:
+        all_atom_pred_pos:
+            [*, N_pts, 37, 3] All-atom predicted atom positions
+        all_atom_positions:
+            [*, N_pts, 37, 3] Ground truth all-atom positions
+        all_atom_mask:
+            [*, N_pts, 37] All-atom positions mask
+        asym_id:
+            [*, N_pts] Chain asym IDs
+        clamp_distance:
+            Cutoff above which distance errors are disregarded
+        weight:
+            Weight for loss
+        eps:
+            Small value used to regularize denominators
+    Returns:
+        [*] loss tensor
+    """
+    ca_pos = residue_constants.atom_order["CA"]
+    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
+    all_atom_positions = all_atom_positions[..., ca_pos, :]
+    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
+
+    one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype)
+    one_hot = one_hot * all_atom_mask
+    chain_pos_mask = one_hot.transpose(-2, -1)
+    chain_exists = torch.any(chain_pos_mask, dim=-1).to(dtype=all_atom_positions.dtype)
+
+    def get_chain_center_of_mass(pos):
+        center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
+        centers = center_sum / (torch.sum(chain_pos_mask, dim=-1, keepdim=True) + eps)
+        return Vec3Array.from_array(centers)
+
+    pred_centers = get_chain_center_of_mass(all_atom_pred_pos)  # [B, NC, 3]
+    true_centers = get_chain_center_of_mass(all_atom_positions)  # [B, NC, 3]
+
+    pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps)
+    true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps)
+    losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2
+    loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :]
+
+    loss = masked_mean(loss_mask, losses, dim=(-1, -2))
+    return loss
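A sanity check on the sign conventions above: with the default weight of 0.05 and clamp_distance of -4.0, the clamped term is non-zero only for chain pairs whose predicted centres of mass sit more than 4 Å closer together than in the ground truth, i.e. it penalizes chains collapsing onto each other (hypothetical distances, not from this commit):

import torch

weight, clamp_distance = 0.05, -4.0
pred_dists = torch.tensor([10.0, 30.0])  # predicted inter-chain centre distances (Å)
true_dists = torch.tensor([20.0, 28.0])  # ground-truth distances (Å)

losses = torch.clamp(weight * (pred_dists - true_dists - clamp_distance), max=0) ** 2
# losses[0] > 0: chains predicted 10 Å too close, beyond the 4 Å margin
# losses[1] == 0: a 2 Å error in either direction stays within the margin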
class AlphaFoldLoss(nn.Module):
    """Aggregation of the various losses described in the supplement"""

    def __init__(self, config):
        super(AlphaFoldLoss, self).__init__()
        self.config = config

-   def forward(self, out, batch, _return_breakdown=False):
+   def loss(self, out, batch, _return_breakdown=False):
+       """
+       Rename the previous forward() as loss()
+       so that it can be reused in the subclass
+       """
        if "violation" not in out.keys():
            out["violation"] = find_structural_violations(
                batch,
...
...
@@ -1576,31 +1739,36 @@ class AlphaFoldLoss(nn.Module):
        ),
        "violation": lambda: violation_loss(
            out["violation"],
-           **batch,
+           **{**batch, **self.config.violation},
        ),
    }

-   if (self.config.tm.enabled):
+   if self.config.tm.enabled:
        loss_fns["tm"] = lambda: tm_loss(
            logits=out["tm_logits"],
            **{**batch, **out, **self.config.tm},
        )

+   if self.config.chain_center_of_mass.enabled:
+       loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss(
+           all_atom_pred_pos=out["final_atom_positions"],
+           **{**batch, **self.config.chain_center_of_mass},
+       )
+
    cum_loss = 0.
    losses = {}
    for loss_name, loss_fn in loss_fns.items():
        weight = self.config[loss_name].weight
        loss = loss_fn()
-       if (torch.isnan(loss) or torch.isinf(loss)):
-           #for k,v in batch.items():
-           #    if (torch.any(torch.isnan(v)) or torch.any(torch.isinf(v))):
+       if torch.isnan(loss) or torch.isinf(loss):
+           # for k,v in batch.items():
+           #     if torch.any(torch.isnan(v)) or torch.any(torch.isinf(v)):
            #         logging.warning(f"{k}: is nan")
-           #logging.warning(f"{loss_name}: {loss}")
+           # logging.warning(f"{loss_name}: {loss}")
            logging.warning(f"{loss_name} loss is NaN. Skipping...")
            loss = loss.new_tensor(0., requires_grad=True)
        cum_loss = cum_loss + weight * loss
        losses[loss_name] = loss.detach().clone()
    losses["unscaled_loss"] = cum_loss.detach().clone()

    # Scale the loss by the square root of the minimum of the crop size and
...
...
@@ -1611,7 +1779,15 @@ class AlphaFoldLoss(nn.Module):
        losses["loss"] = cum_loss.detach().clone()

-       if (not _return_breakdown):
+       if not _return_breakdown:
            return cum_loss

        return cum_loss, losses

+   def forward(self, out, batch, _return_breakdown=False):
+       if not _return_breakdown:
+           cum_loss = self.loss(out, batch, _return_breakdown)
+           return cum_loss
+       else:
+           cum_loss, losses = self.loss(out, batch, _return_breakdown)
+           return cum_loss, losses
openfold/utils/multi_chain_permutation.py
0 → 100644
View file @ bb3f51e5
import logging
import random

import torch

from openfold.np import residue_constants as rc

logger = logging.getLogger(__name__)


def compute_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
    if atom_mask is not None:
        sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device))
    msd = torch.mean(sq_diff)
    msd = torch.nan_to_num(msd, nan=1e8)
    return torch.sqrt(msd + eps)  # prevent sqrt 0
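A quick check of compute_rmsd on toy coordinates (invented values, not part of this commit): a uniform 1 Å shift along x gives an RMSD of roughly 1.

import torch

true_pos = torch.zeros(5, 3)
pred_pos = true_pos + torch.tensor([1.0, 0.0, 0.0])
mask = torch.ones(5, dtype=torch.bool)
print(compute_rmsd(true_pos, pred_pos, mask))  # ~1.0, plus the small eps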
def kabsch_rotation(P, Q):
    """
    Calculate the best rotation that minimises the RMSD between P and Q.

    The optimal rotation matrix is computed with the Kabsch algorithm:
    https://en.wikipedia.org/wiki/Kabsch_algorithm

    Args:
        P: [N, 3] matrix, where N is the number of atoms and each row holds an atom's x, y, z coordinates
        Q: [N, 3] matrix with the same shape as P
    Returns:
        A 3x3 rotation matrix
    """
    assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])
    # First, compute the SVD of P.T @ Q
    u, _, vt = torch.linalg.svd(torch.matmul(P.to(torch.float32).T, Q.to(torch.float32)))
    # Then construct the s matrix,
    # correcting the rotation matrix to ensure a right-handed coordinate system
    s = torch.eye(P.shape[1], device=P.device)
    s[-1, -1] = torch.sign(torch.linalg.det(torch.matmul(u, vt)))
    # Finally, compute the rotation matrix
    r_opt = torch.matmul(torch.matmul(u, s), vt)
    assert r_opt.shape == torch.Size([3, 3])
    return r_opt.to(device=P.device, dtype=P.dtype)
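A minimal self-check of kabsch_rotation with toy points (illustration only, not part of this commit): rotate a centred point cloud by a known rotation and recover it. Note that the Kabsch algorithm assumes both point sets are centred at the origin.

import math
import torch

theta = math.pi / 6
rot_z = torch.tensor([
    [math.cos(theta), -math.sin(theta), 0.0],
    [math.sin(theta),  math.cos(theta), 0.0],
    [0.0,              0.0,             1.0],
])
P = torch.randn(10, 3)
P = P - P.mean(dim=0, keepdim=True)  # centre at the origin
Q = P @ rot_z.T                      # rotate the cloud
r = kabsch_rotation(P, Q)
assert torch.allclose(P @ r, Q, atol=1e-4)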
def get_optimal_transform(
    src_atoms: torch.Tensor,
    tgt_atoms: torch.Tensor,
    mask: torch.Tensor = None,
):
    """
    src_atoms: predicted CA positions, shape [num_res, 3]
    tgt_atoms: ground-truth CA positions, shape [num_res, 3]
    mask: a vector of boolean values, shape [num_res]
    """
    assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
    assert src_atoms.shape[-1] == 3
    if mask is not None:
        assert len(mask.shape) == 1, "mask should have the shape of [num_res]"
    if torch.isnan(src_atoms).any() or torch.isinf(src_atoms).any():
        # Sometimes using fake test inputs generates NaNs in the predicted atom positions
        logging.warning("src_atom has nan or inf")
        src_atoms = torch.nan_to_num(src_atoms, nan=0.0, posinf=1.0, neginf=1.0)
    if mask is not None:
        assert mask.dtype == torch.bool
        assert mask.shape[-1] == src_atoms.shape[-2]
        if mask.sum() == 0:
            src_atoms = torch.zeros((1, 3), device=src_atoms.device, dtype=src_atoms.dtype)
            tgt_atoms = src_atoms
        else:
            src_atoms = src_atoms[mask, :]
            tgt_atoms = tgt_atoms[mask, :]
    src_center = src_atoms.mean(-2, keepdim=True, dtype=src_atoms.dtype)
    tgt_center = tgt_atoms.mean(-2, keepdim=True, dtype=src_atoms.dtype)

    r = kabsch_rotation(src_atoms, tgt_atoms)
    x = tgt_center - src_center @ r
    return r, x
def get_least_asym_entity_or_longest_length(batch, input_asym_id):
    """
    First check how many subunit(s) each sequence has, and select the subunit that is less
    common; e.g. if the protein is AABBB, select one of the As as the anchor.
    If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
    then choose one of the corresponding subunits as the anchor.

    Args:
        batch: in this function, batch is the full set of ground truth features
        input_asym_id: a list of asym_ids that are in the cropped input features
    Returns:
        anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
        anchor_pred_asym_ids: list(Tensor(int)) a list of all possible predicted anchor candidates
    """
    entity_2_asym_list = get_entity_2_asym_list(batch)
    unique_entity_ids = torch.unique(batch["entity_id"])
    entity_asym_count = {}
    entity_length = {}

    for entity_id in unique_entity_ids:
        asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id])

        # Make sure some asym IDs associated with the ground truth entity ID exist in the cropped prediction
        asym_ids_in_pred = [a for a in asym_ids if a in input_asym_id]
        if not asym_ids_in_pred:
            continue

        entity_asym_count[int(entity_id)] = len(asym_ids)

        # Calculate entity length
        entity_mask = (batch["entity_id"] == entity_id)
        entity_length[int(entity_id)] = entity_mask.sum().item()

    min_asym_count = min(entity_asym_count.values())
    least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]

    # If multiple entities have the least asym_id count, return those with the longest length
    if len(least_asym_entities) > 1:
        max_length = max([entity_length[entity] for entity in least_asym_entities])
        least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]

    # If there are still multiple entities, return a random one
    if len(least_asym_entities) > 1:
        least_asym_entities = [random.choice(least_asym_entities)]

    assert len(least_asym_entities) == 1
    least_asym_entities = least_asym_entities[0]

    anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
    anchor_pred_asym_ids = [
        asym_id for asym_id in entity_2_asym_list[least_asym_entities] if asym_id in input_asym_id
    ]

    return anchor_gt_asym_id, anchor_pred_asym_ids
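A simplified walkthrough of the anchor selection for a hypothetical A2B3 complex (one tensor entry per chain here for brevity; in practice these are per-residue tensors):

import torch

batch = {
    "entity_id": torch.tensor([1, 1, 2, 2, 2]),  # entity 1 ("A") twice, entity 2 ("B") three times
    "asym_id": torch.tensor([1, 2, 3, 4, 5]),
}
input_asym_id = [1, 2, 3, 4, 5]
gt_anchor, pred_anchors = get_least_asym_entity_or_longest_length(batch, input_asym_id)
# Entity 1 has the fewer copies, so the anchor is drawn from asym_ids {1, 2}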
def greedy_align(
    batch,
    per_asym_residue_index,
    entity_2_asym_list,
    pred_ca_pos,
    pred_ca_mask,
    true_ca_poses,
    true_ca_masks,
):
    """
    Implements Algorithm 4 in the Supplementary Information of the AlphaFold-Multimer paper:
    Evans, R. et al., 2022 Protein complex prediction with AlphaFold-Multimer,
    bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
    """
    used = [False for _ in range(len(true_ca_poses))]
    align = []
    unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
    for cur_asym_id in unique_asym_ids:
        i = int(cur_asym_id - 1)
        asym_mask = batch["asym_id"] == cur_asym_id
        cur_entity_ids = batch["entity_id"][asym_mask][0]
        best_rmsd = torch.inf
        best_idx = None
        cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
        cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
        cur_pred_pos = pred_ca_pos[asym_mask]
        cur_pred_mask = pred_ca_mask[asym_mask]
        for next_asym_id in cur_asym_list:
            j = int(next_asym_id - 1)
            if not used[j]:  # possible candidate
                cropped_pos = torch.index_select(true_ca_poses[j], 1, cur_residue_index)
                mask = torch.index_select(true_ca_masks[j], 1, cur_residue_index)
                rmsd = compute_rmsd(
                    torch.squeeze(cropped_pos, 0),
                    torch.squeeze(cur_pred_pos, 0),
                    (cur_pred_mask * mask).bool()
                )
                if rmsd is not None and rmsd < best_rmsd:
                    best_rmsd = rmsd
                    best_idx = j

        assert best_idx is not None
        used[best_idx] = True
        align.append((i, best_idx))

    return align
def pad_features(feature_tensor, nres_pad, pad_dim):
    """Pad the input feature tensor"""
    pad_shape = list(feature_tensor.shape)
    pad_shape[pad_dim] = nres_pad
    padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device)
    return torch.concat((feature_tensor, padding_tensor), dim=pad_dim)
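A toy check of pad_features (invented shapes, not part of this commit): pad a [2, 5] tensor back to 7 residues along dimension 1.

import torch

feat = torch.ones(2, 5)
padded = pad_features(feat, nres_pad=2, pad_dim=1)
assert padded.shape == (2, 7)
assert padded[:, 5:].sum() == 0  # padding is zero-filled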
def merge_labels(per_asym_residue_index, labels, align, original_nres):
    """
    Merge ground truth labels according to the permutation results.

    labels: list of original ground truth feats
    align: list of tuples; each entry specifies the corresponding label of an asym

    Modified based on UniFold:
    https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1
    """
    outs = {}
    for k, v in labels[0].items():
        cur_out = {}
        for i, j in align:
            label = labels[j][k]
            # to 1-based
            cur_residue_index = per_asym_residue_index[i + 1]
            if len(v.shape) <= 1 or "template" in k or "row_mask" in k:
                continue
            else:
                dimension_to_merge = 1
                cur_out[i] = label.index_select(dimension_to_merge, cur_residue_index)
        cur_out = [x[1] for x in sorted(cur_out.items())]
        if len(cur_out) > 0:
            new_v = torch.concat(cur_out, dim=dimension_to_merge)
            # Below, check whether padding is needed
            if new_v.shape[dimension_to_merge] != original_nres:
                nres_pad = original_nres - new_v.shape[dimension_to_merge]
                new_v = pad_features(new_v, nres_pad, pad_dim=dimension_to_merge)
            outs[k] = new_v
    return outs
def split_ground_truth_labels(gt_features):
    """
    Splits ground truth features according to chains.

    Returns:
        a list of feature dictionaries with only the necessary ground truth features
        required to finish multi-chain permutation
    """
    unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True)
    n_res = gt_features["asym_id"].shape[-1]

    def split_dim(shape):
        return next(iter(i for i, size in enumerate(shape) if size == n_res), None)

    labels = list(map(dict, zip(*[
        [(k, v) for v in torch.split(v_all, asym_id_counts.tolist(), dim=split_dim(v_all.shape))]
        for k, v_all in gt_features.items() if n_res in v_all.shape
    ])))

    return labels
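A sketch of what split_ground_truth_labels produces for a two-chain example (invented tensors, not from this commit):

import torch

gt = {
    "asym_id": torch.tensor([1., 1., 2.]),
    "all_atom_positions": torch.zeros(3, 37, 3),
}
labels = split_ground_truth_labels(gt)
assert len(labels) == 2
assert labels[0]["all_atom_positions"].shape[0] == 2  # chain 1 has two residues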
def get_per_asym_residue_index(features):
    unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0]
    per_asym_residue_index = {}
    for cur_asym_id in unique_asym_ids:
        asym_mask = (features["asym_id"] == cur_asym_id).bool()
        per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(features["residue_index"], asym_mask)

    return per_asym_residue_index
def get_entity_2_asym_list(batch):
    """
    Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.

    Args:
        batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
    Returns:
        entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of
        unique asymmetry IDs associated with each entity.
    """
    entity_2_asym_list = {}
    unique_entity_ids = torch.unique(batch["entity_id"])
    for cur_ent_id in unique_entity_ids:
        ent_mask = batch["entity_id"] == cur_ent_id
        cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
        entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
    return entity_2_asym_list
def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue, asym_mask, pred_ca_mask):
    """
    Calculate an input mask for the downstream optimal transformation computation.

    Args:
        true_ca_masks (Tensor): CA masks from the ground truth.
        anchor_gt_idx (Tensor): The index of the selected ground truth anchor.
        asym_mask (Tensor): Boolean tensor indicating which regions are the selected predicted anchor.
        pred_ca_mask (Tensor): CA mask from the predicted structure.
    Returns:
        input_mask (Tensor): A boolean mask
    """
    pred_ca_mask = torch.squeeze(pred_ca_mask, 0)
    asym_mask = torch.squeeze(asym_mask, 0)
    anchor_pred_mask = pred_ca_mask[asym_mask]
    anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_gt_residue)
    input_mask = (anchor_true_mask * anchor_pred_mask).bool()
    return input_mask
def calculate_optimal_transform(
    true_ca_poses,
    anchor_gt_idx,
    anchor_gt_residue,
    true_ca_masks,
    pred_ca_mask,
    asym_mask,
    pred_ca_pos
):
    input_mask = calculate_input_mask(
        true_ca_masks,
        anchor_gt_idx,
        anchor_gt_residue,
        asym_mask,
        pred_ca_mask
    )
    input_mask = torch.squeeze(input_mask, 0)
    pred_ca_pos = torch.squeeze(pred_ca_pos, 0)
    asym_mask = torch.squeeze(asym_mask, 0)
    anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_gt_residue)
    anchor_pred_pos = pred_ca_pos[asym_mask]
    r, x = get_optimal_transform(
        anchor_pred_pos,
        torch.squeeze(anchor_true_pos, 0),
        mask=input_mask
    )
    return r, x
def compute_permutation_alignment(out, features, ground_truth):
    """
    Permutes the chains in the ground truth before the loss is calculated.
    Details are described in Section 7.3 of the Supplementary Information of the AlphaFold-Multimer paper:
    https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
    """
    unique_asym_ids = set(torch.unique(features['asym_id']).tolist())
    unique_asym_ids.discard(0)  # Remove padding asym_id
    is_monomer = len(unique_asym_ids) == 1

    per_asym_residue_index = get_per_asym_residue_index(features)

    if is_monomer:
        best_align = list(enumerate(range(len(per_asym_residue_index))))
        return best_align, per_asym_residue_index

    best_rmsd = float('inf')
    best_align = None

    # First select anchors from the predicted structure and the ground truth
    anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(
        ground_truth, features['asym_id']
    )
    entity_2_asym_list = get_entity_2_asym_list(ground_truth)
    labels = split_ground_truth_labels(ground_truth)
    assert isinstance(labels, list)
    anchor_gt_idx = int(anchor_gt_asym) - 1

    # Then calculate the optimal transform by aligning the anchors
    ca_idx = rc.atom_order["CA"]
    pred_ca_pos = out["final_atom_positions"][..., ca_idx, :]  # [bsz, nres, 3]
    pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype)  # [bsz, nres]
    true_ca_poses = [l["all_atom_positions"][..., ca_idx, :] for l in labels]  # list([nres, 3])
    true_ca_masks = [l["all_atom_mask"][..., ca_idx].long() for l in labels]  # list([nres,])

    for candidate_pred_anchor in anchor_pred_asym_ids:
        asym_mask = (features["asym_id"] == candidate_pred_anchor).bool()
        anchor_gt_residue = per_asym_residue_index[candidate_pred_anchor.item()]
        r, x = calculate_optimal_transform(
            true_ca_poses,
            anchor_gt_idx,
            anchor_gt_residue,
            true_ca_masks,
            pred_ca_mask,
            asym_mask,
            pred_ca_pos
        )
        aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses]  # apply transforms
        align = greedy_align(
            features,
            per_asym_residue_index,
            entity_2_asym_list,
            pred_ca_pos,
            pred_ca_mask,
            aligned_true_ca_poses,
            true_ca_masks,
        )
        merged_labels = merge_labels(
            per_asym_residue_index,
            labels,
            align,
            original_nres=features['aatype'].shape[-1]
        )
        rmsd = compute_rmsd(
            true_atom_pos=merged_labels['all_atom_positions'][..., ca_idx, :].to(r.dtype) @ r + x,
            pred_atom_pos=pred_ca_pos,
            atom_mask=(pred_ca_mask * merged_labels['all_atom_mask'][..., ca_idx].long()).bool()
        )
        if rmsd < best_rmsd:
            best_rmsd = rmsd
            best_align = align

    return best_align, per_asym_residue_index
def multi_chain_permutation_align(out, features, ground_truth):
    """Compute the multi-chain permutation alignment.

    Args:
        out: The output of model.forward()
        features: Input features
        ground_truth: Ground truth features
    """
    labels = split_ground_truth_labels(ground_truth)

    # Permute the ground truth chains before calculating the loss
    align, per_asym_residue_index = compute_permutation_alignment(
        out=out,
        features=features,
        ground_truth=ground_truth
    )

    # Reorder the ground truth labels according to the permutation results
    labels = merge_labels(
        per_asym_residue_index,
        labels,
        align,
        original_nres=features['aatype'].shape[-1]
    )

    features.update(labels)
    return features
openfold/utils/rigid_utils.py
View file @ bb3f51e5
...
...
@@ -978,6 +978,16 @@ class Rigid:
        """
        return self._trans.device

+   @property
+   def dtype(self) -> torch.dtype:
+       """
+       Returns the dtype of the Rigid tensors.
+
+       Returns:
+           The dtype of the Rigid tensors
+       """
+       return self._rots.dtype
+
    def get_rots(self) -> Rotation:
        """
        Getter for the rotation.
...
...
openfold/utils/script_utils.py
View file @ bb3f51e5
...
...
@@ -12,6 +12,7 @@ from openfold.np import residue_constants, protein
from openfold.np.relax import relax
from openfold.utils.import_weights import (
    import_jax_weights_,
+   import_openfold_weights_
)
from pytorch_lightning.utilities.deepspeed import (
...
...
@@ -90,7 +91,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
                ckpt_path,
            )
            d = torch.load(ckpt_path)
-           model.load_state_dict(d["ema"]["params"])
+           import_openfold_weights_(model=model, state_dict=d["ema"]["params"])
        else:
            ckpt_path = path
            d = torch.load(ckpt_path)
...
...
@@ -98,7 +99,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
            if "ema" in d:
                # The public weights have had this done to them already
                d = d["ema"]["params"]
-           model.load_state_dict(d)
+           import_openfold_weights_(model=model, state_dict=d)
        model = model.to(model_device)
        logger.info(
...
...
@@ -122,7 +123,7 @@ def parse_fasta(data):
    ][1:]
    tags, seqs = lines[::2], lines[1::2]
-   tags = [t.split()[0] for t in tags]
+   tags = [re.split('\W| \|', t)[0] for t in tags]

    return tags, seqs
...
...
@@ -219,7 +220,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult
        features=batch,
        result=out,
        b_factors=plddt_b_factors,
        chain_index=chain_index,
        remove_leading_feature_dimension=False,
        remark=remark,
        parents=template_domain_names,
        parents_chain_index=template_chain_index,
...
...
run_pretrained_openfold.py
View file @ bb3f51e5
...
...
@@ -17,24 +17,19 @@ import logging
import math
import numpy as np
import os

-from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
-    update_timings, relax_protein
-
-import pickle
-
-import random
-import time

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)

+import pickle
+import random
+import time
+
import torch

torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
if (
    torch_major_version > 1
    or (torch_major_version == 1 and torch_minor_version >= 12)
):
...
...
@@ -45,16 +40,16 @@ torch.set_grad_enabled(False)
from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline
-from openfold.np import residue_constants, protein
-import openfold.np.relax.relax as relax
-from openfold.utils.tensor_utils import (
-    tensor_tree_map,
-)
+from openfold.data.tools import hhsearch, hmmsearch
+from openfold.np import protein
+from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta,
+                                         run_model, prep_output, relax_protein)
+from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.trace_utils import (
    pad_feature_dict_seq,
    trace_model_,
)
from scripts.precompute_embeddings import EmbeddingGenerator
from scripts.utils import add_data_args
...
...
@@ -69,18 +64,30 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
            fp.write(f">{tag}\n{seq}")

        local_alignment_dir = os.path.join(alignment_dir, tag)
-       if (args.use_precomputed_alignments is None and
-               not os.path.isdir(local_alignment_dir)):
+       if args.use_precomputed_alignments is None:
            logger.info(f"Generating alignments for {tag}...")

-           os.makedirs(local_alignment_dir)
+           os.makedirs(local_alignment_dir, exist_ok=True)

+           if "multimer" in args.config_preset:
+               template_searcher = hmmsearch.Hmmsearch(
+                   binary_path=args.hmmsearch_binary_path,
+                   hmmbuild_binary_path=args.hmmbuild_binary_path,
+                   database_path=args.pdb_seqres_database_path,
+               )
+           else:
+               template_searcher = hhsearch.HHSearch(
+                   binary_path=args.hhsearch_binary_path,
+                   databases=[args.pdb70_database_path],
+               )

            # In seqemb mode, use AlignmentRunner only to generate templates
            if args.use_single_seq_mode:
                alignment_runner = data_pipeline.AlignmentRunner(
                    jackhmmer_binary_path=args.jackhmmer_binary_path,
-                   hhsearch_binary_path=args.hhsearch_binary_path,
                    uniref90_database_path=args.uniref90_database_path,
-                   pdb70_database_path=args.pdb70_database_path,
+                   template_searcher=template_searcher,
                    no_cpus=args.cpus,
                )
                embedding_generator = EmbeddingGenerator()
...
...
@@ -89,14 +96,17 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
            alignment_runner = data_pipeline.AlignmentRunner(
                jackhmmer_binary_path=args.jackhmmer_binary_path,
                hhblits_binary_path=args.hhblits_binary_path,
-               hhsearch_binary_path=args.hhsearch_binary_path,
                uniref90_database_path=args.uniref90_database_path,
                mgnify_database_path=args.mgnify_database_path,
                bfd_database_path=args.bfd_database_path,
+               uniref30_database_path=args.uniref30_database_path,
-               uniclust30_database_path=args.uniclust30_database_path,
-               pdb70_database_path=args.pdb70_database_path,
-               no_cpus=args.cpus,
+               uniprot_database_path=args.uniprot_database_path,
+               template_searcher=template_searcher,
+               use_small_bfd=args.bfd_database_path is None,
+               no_cpus=args.cpus
            )
        alignment_runner.run(tmp_fasta_path, local_alignment_dir)
...
...
@@ -133,6 +143,14 @@ def generate_feature_dict(
            alignment_dir=local_alignment_dir,
            seqemb_mode=args.use_single_seq_mode,
        )
+   elif "multimer" in args.config_preset:
+       with open(tmp_fasta_path, "w") as fp:
+           fp.write(
+               '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
+           )
+       feature_dict = data_processor.process_fasta(
+           fasta_path=tmp_fasta_path,
+           alignment_dir=alignment_dir,
+       )
    else:
        with open(tmp_fasta_path, "w") as fp:
            fp.write(
...
...
@@ -147,6 +165,7 @@ def generate_feature_dict(
    return feature_dict


def list_files_with_extensions(dir, extensions):
    return [f for f in os.listdir(dir) if f.endswith(extensions)]
...
...
@@ -157,15 +176,28 @@ def main(args):
    if args.config_preset.startswith("seq"):
        args.use_single_seq_mode = True

    config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)

-   if (args.trace_model):
-       if (not config.data.predict.fixed_size):
+   if args.trace_model:
+       if not config.data.predict.fixed_size:
            raise ValueError(
                "Tracing requires that fixed_size mode be enabled in the config"
            )

-   template_featurizer = templates.TemplateHitFeaturizer(
+   is_multimer = "multimer" in args.config_preset
+
+   if is_multimer:
+       template_featurizer = templates.HmmsearchHitFeaturizer(
+           mmcif_dir=args.template_mmcif_dir,
+           max_template_date=args.max_template_date,
+           max_hits=config.data.predict.max_templates,
+           kalign_binary_path=args.kalign_binary_path,
+           release_dates_path=args.release_dates_path,
+           obsolete_pdbs_path=args.obsolete_pdbs_path
+       )
+   else:
+       template_featurizer = templates.HhsearchHitFeaturizer(
            mmcif_dir=args.template_mmcif_dir,
            max_template_date=args.max_template_date,
            max_hits=config.data.predict.max_templates,
...
...
@@ -178,10 +210,15 @@ def main(args):
        template_featurizer=template_featurizer,
    )

+   if is_multimer:
+       data_processor = data_pipeline.DataPipelineMultimer(
+           monomer_data_pipeline=data_processor,
+       )
+
    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(2 ** 32)

    np.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)
...
...
@@ -198,10 +235,19 @@ def main(args):
    seq_list = []
    for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
        # Gather input sequences
-       with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
+       fasta_path = os.path.join(args.fasta_dir, fasta_file)
+       with open(fasta_path, "r") as fp:
            data = fp.read()

        tags, seqs = parse_fasta(data)

+       if not is_multimer and len(tags) != 1:
+           print(
+               f"{fasta_path} contains more than one sequence but "
+               f"multimer mode is not enabled. Skipping..."
+           )
+           continue
+
        # assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
        tag = '-'.join(tags)
...
...
@@ -217,6 +263,7 @@ def main(args):
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir)

    for model, output_directory in model_generator:
        cur_tracing_interval = 0
        for (tag, tags), seqs in sorted_targets:
...
...
@@ -228,7 +275,7 @@ def main(args):
            precompute_alignments(tags, seqs, alignment_dir, args)

            feature_dict = feature_dicts.get(tag, None)
-           if (feature_dict is None):
+           if feature_dict is None:
                feature_dict = generate_feature_dict(
                    tags,
                    seqs,
...
...
@@ -237,7 +284,7 @@ def main(args):
                args,
            )

-           if (args.trace_model):
+           if args.trace_model:
                n = feature_dict["aatype"].shape[-2]
                rounded_seqlen = round_up_seqlen(n)
                feature_dict = pad_feature_dict_seq(
...
...
@@ -247,16 +294,16 @@ def main(args):
                feature_dicts[tag] = feature_dict

            processed_feature_dict = feature_processor.process_features(
-               feature_dict, mode='predict',
+               feature_dict, mode='predict', is_multimer=is_multimer
            )

            processed_feature_dict = {
                k: torch.as_tensor(v, device=args.model_device)
                for k, v in processed_feature_dict.items()
            }

-           if (args.trace_model):
-               if (rounded_seqlen > cur_tracing_interval):
+           if args.trace_model:
+               if rounded_seqlen > cur_tracing_interval:
                    logger.info(
                        f"Tracing model at {rounded_seqlen} residues..."
                    )
...
...
@@ -305,7 +352,8 @@ def main(args):
            if not args.skip_relaxation:
                # Relax the prediction.
                logger.info(f"Running relaxation on {unrelaxed_output_path}...")
-               relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, args.cif_output)
+               relax_protein(config, args.model_device, unrelaxed_protein,
+                             output_directory, output_name, args.cif_output)

            if args.save_outputs:
                output_dict_path = os.path.join(
...
...
@@ -407,13 +455,13 @@ if __name__ == "__main__":
    add_data_args(parser)
    args = parser.parse_args()

-   if (args.jax_param_path is None and args.openfold_checkpoint_path is None):
+   if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

-   if (args.model_device == "cpu" and torch.cuda.is_available()):
+   if args.model_device == "cpu" and torch.cuda.is_available():
        logging.warning(
            """The model is being run on CPU. Consider specifying
            --model_device for better performance"""
...
...
scripts/__init__.py
0 → 100644
View file @ bb3f51e5
scripts/convert_of_weights_to_jax.py
View file @ bb3f51e5
...
...
@@ -26,6 +26,7 @@ from openfold.utils.import_weights import (
    ParamType,
    generate_translation_dict,
    process_translation_dict,
+   import_openfold_weights_
)
from openfold.utils.tensor_utils import tree_map
...
...
@@ -63,7 +64,7 @@ def main(args):
    config = model_config(args.config_preset)
    model = AlphaFold(config)
-   model.load_state_dict(d)
+   import_openfold_weights_(model=model, state_dict=d)
    translation = generate_translation_dict(model, args.config_preset)
    translation = process_translation_dict(translation)
...
...
scripts/data_dir_to_fasta.py
View file @ bb3f51e5
import argparse
import logging
import os
import string
from collections import defaultdict

from openfold.data import mmcif_parsing
from openfold.np import protein, residue_constants
...
...
@@ -22,7 +23,7 @@ def main(args):
        if (mmcif.mmcif_object is None):
            logging.warning(f'Failed to parse {fname}...')
            if (args.raise_errors):
-               raise list(mmcif.errors.values())[0]
+               raise Exception(list(mmcif.errors.values())[0])
            else:
                continue
...
...
@@ -31,6 +32,25 @@ def main(args):
                chain_id = '_'.join([basename, chain])
                fasta.append(f">{chain_id}")
                fasta.append(seq)
+       elif (ext == ".pdb"):
+           with open(fpath, 'r') as fp:
+               pdb_str = fp.read()
+
+           protein_object = protein.from_pdb_string(pdb_str)
+           aatype = protein_object.aatype
+           chain_index = protein_object.chain_index
+
+           last_chain_index = chain_index[0]
+           chain_dict = defaultdict(list)
+           for i in range(aatype.shape[0]):
+               chain_dict[chain_index[i]].append(residue_constants.restypes_with_x[aatype[i]])
+
+           chain_tags = string.ascii_uppercase
+           for chain, seq in chain_dict.items():
+               chain_id = '_'.join([basename, chain_tags[chain]])
+               fasta.append(f">{chain_id}")
+               fasta.append(''.join(seq))
        elif (ext == ".core"):
            with open(fpath, 'r') as fp:
                core_str = fp.read()
...
...
scripts/deepspeed_inference_test.py
0 → 100644
View file @ bb3f51e5
import copy
import os

import torch

import deepspeed

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ml = torch.nn.ModuleList()
        for _ in range(4000):
            self.ml.append(torch.nn.Linear(500, 500))

    def forward(self, batch):
        for i, l in enumerate(self.ml):
            # print(f"{i}: {l.weight.device}")
            batch = l(batch)
        return batch


class DummyDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.batch = torch.rand(500, 500)

    def __getitem__(self, idx):
        return copy.deepcopy(self.batch)

    def __len__(self):
        return 1000


dd = DummyDataset()
dl = torch.utils.data.DataLoader(dd)

example = next(iter(dl)).to(f"cuda:{local_rank}")

model = Model()
model = model.to(f"cuda:{local_rank}")
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    checkpoint=None,
    replace_method=None,
    #replace_method="auto"
)

out = model(example)

#if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
#    print(out)
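Usage note (not part of the commit): this script is meant to be started with the DeepSpeed launcher, which populates the LOCAL_RANK and WORLD_SIZE environment variables it reads, e.g. "deepspeed --num_gpus 2 scripts/deepspeed_inference_test.py" on a two-GPU node; exact flags may vary across DeepSpeed versions.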
scripts/download_alphafold_dbs.sh
View file @ bb3f51e5
...
...
@@ -56,10 +56,16 @@ bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB mmCIF files..."
bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"

-echo "Downloading Uniclust30..."
-bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
+echo "Downloading Uniref30..."
+bash "${SCRIPT_DIR}/download_uniref30.sh" "${DOWNLOAD_DIR}"

echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"

+echo "Downloading UniProt..."
+bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
+
+echo "Downloading PDB SeqRes..."
+bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"
+
echo "All data downloaded."