Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
56d5e39c
Commit
56d5e39c
authored
Jun 17, 2023
by
Geoffrey Yu
Browse files
Merge remote-tracking branch 'upstream/multimer' into multimer
parents
56b86074
51556d52
Changes
80
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1187 additions
and
240 deletions
+1187
-240
openfold/utils/geometry/vector.py
openfold/utils/geometry/vector.py
+263
-0
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+332
-140
openfold/utils/loss.py
openfold/utils/loss.py
+194
-43
openfold/utils/rigid_utils.py
openfold/utils/rigid_utils.py
+10
-0
openfold/utils/script_utils.py
openfold/utils/script_utils.py
+1
-1
run_pretrained_openfold.py
run_pretrained_openfold.py
+94
-18
scripts/__init__.py
scripts/__init__.py
+0
-0
scripts/data_dir_to_fasta.py
scripts/data_dir_to_fasta.py
+22
-2
scripts/deepspeed_inference_test.py
scripts/deepspeed_inference_test.py
+54
-0
scripts/download_alphafold_dbs.sh
scripts/download_alphafold_dbs.sh
+8
-2
scripts/download_alphafold_params.sh
scripts/download_alphafold_params.sh
+1
-1
scripts/download_mgnify.sh
scripts/download_mgnify.sh
+2
-2
scripts/download_pdb_seqres.sh
scripts/download_pdb_seqres.sh
+42
-0
scripts/download_uniprot.sh
scripts/download_uniprot.sh
+55
-0
scripts/download_uniref30.sh
scripts/download_uniref30.sh
+8
-3
scripts/flatten_roda.sh
scripts/flatten_roda.sh
+6
-1
scripts/generate_alphafold_feature_dict.py
scripts/generate_alphafold_feature_dict.py
+39
-11
scripts/generate_chain_data_cache.py
scripts/generate_chain_data_cache.py
+21
-13
scripts/precompute_alignments.py
scripts/precompute_alignments.py
+20
-3
scripts/utils.py
scripts/utils.py
+15
-0
No files found.
openfold/utils/geometry/vector.py
0 → 100644
View file @
56d5e39c
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vec3Array Class."""
from
__future__
import
annotations
import
dataclasses
from
typing
import
Union
,
List
import
torch
from
openfold.utils.geometry
import
utils
# Scalar-or-tensor type used throughout this module: quantities that may be
# a plain Python float or a (broadcastable) torch.Tensor.
Float = Union[float, torch.Tensor]
@dataclasses.dataclass(frozen=True)
class Vec3Array:
    """Array of 3D vectors stored as three parallel coordinate tensors.

    ``x``, ``y`` and ``z`` must agree in dtype and shape; every arithmetic
    operation is component-wise, so an instance behaves like a tensor of
    3D vectors with arbitrary leading (batch) dimensions.
    """

    # dtype metadata is informational only (dataclasses never enforce it).
    # The original declared it on `x` alone; it is declared on all three
    # fields here for consistency — this is behavior-neutral.
    x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
    y: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
    z: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})

    def __post_init__(self):
        """Validate that the three components agree in dtype and shape."""
        # The hasattr guard skips validation for components without a
        # `dtype` attribute (e.g. plain numbers or tracing proxies).
        if hasattr(self.x, 'dtype'):
            assert self.x.dtype == self.y.dtype
            assert self.x.dtype == self.z.dtype
            # NOTE(review): zip() truncates to the shorter shape, so ranks
            # are not compared — kept as-is to preserve original behavior.
            assert all([i == j for i, j in zip(self.x.shape, self.y.shape)])
            assert all([i == j for i, j in zip(self.x.shape, self.z.shape)])

    def __add__(self, other: Vec3Array) -> Vec3Array:
        """Component-wise addition."""
        return Vec3Array(
            self.x + other.x,
            self.y + other.y,
            self.z + other.z,
        )

    def __sub__(self, other: Vec3Array) -> Vec3Array:
        """Component-wise subtraction."""
        return Vec3Array(
            self.x - other.x,
            self.y - other.y,
            self.z - other.z,
        )

    def __mul__(self, other: Float) -> Vec3Array:
        """Scale every component by a scalar (or broadcastable tensor)."""
        return Vec3Array(
            self.x * other,
            self.y * other,
            self.z * other,
        )

    def __rmul__(self, other: Float) -> Vec3Array:
        """Right multiplication delegates to __mul__ (scaling commutes)."""
        return self * other

    def __truediv__(self, other: Float) -> Vec3Array:
        """Divide every component by a scalar (or broadcastable tensor)."""
        return Vec3Array(
            self.x / other,
            self.y / other,
            self.z / other,
        )

    def __neg__(self) -> Vec3Array:
        """Negate all components."""
        return self * -1

    def __pos__(self) -> Vec3Array:
        """Unary plus; returns an (elementwise-copied) equivalent array."""
        return self * 1

    def __getitem__(self, index) -> Vec3Array:
        """Index/slice the batch dimensions of all three components."""
        return Vec3Array(
            self.x[index],
            self.y[index],
            self.z[index],
        )

    def __iter__(self):
        """Iterate over the three component tensors: x, y, z."""
        return iter((self.x, self.y, self.z))

    @property
    def shape(self):
        """Shape of the component tensors (all three are identical)."""
        return self.x.shape

    def map_tensor_fn(self, fn) -> Vec3Array:
        """Apply `fn` to each component tensor and rewrap the results."""
        return Vec3Array(
            fn(self.x),
            fn(self.y),
            fn(self.z),
        )

    def cross(self, other: Vec3Array) -> Vec3Array:
        """Compute cross product between 'self' and 'other'."""
        new_x = self.y * other.z - self.z * other.y
        new_y = self.z * other.x - self.x * other.z
        new_z = self.x * other.y - self.y * other.x
        return Vec3Array(new_x, new_y, new_z)

    def dot(self, other: Vec3Array) -> Float:
        """Compute dot product between 'self' and 'other'."""
        return self.x * other.x + self.y * other.y + self.z * other.z

    def norm(self, epsilon: float = 1e-6) -> Float:
        """Compute norm of Vec3Array, clipped from below to epsilon."""
        # To avoid NaN on the backward pass, clamp before the sqrt.
        norm2 = self.dot(self)
        if epsilon:
            norm2 = torch.clamp(norm2, min=epsilon**2)
        return torch.sqrt(norm2)

    def norm2(self):
        """Squared norm (no clipping)."""
        return self.dot(self)

    def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
        """Return unit vector with optional clipping."""
        return self / self.norm(epsilon)

    def clone(self) -> Vec3Array:
        """Deep-copy all three component tensors."""
        return Vec3Array(
            self.x.clone(),
            self.y.clone(),
            self.z.clone(),
        )

    def reshape(self, new_shape) -> Vec3Array:
        """Reshape every component tensor to `new_shape`."""
        x = self.x.reshape(new_shape)
        y = self.y.reshape(new_shape)
        z = self.z.reshape(new_shape)
        return Vec3Array(x, y, z)

    def sum(self, dim: int) -> Vec3Array:
        """Sum each component over dimension `dim`."""
        return Vec3Array(
            torch.sum(self.x, dim=dim),
            torch.sum(self.y, dim=dim),
            torch.sum(self.z, dim=dim),
        )

    def unsqueeze(self, dim: int):
        """Insert a singleton dimension at `dim` in every component."""
        return Vec3Array(
            self.x.unsqueeze(dim),
            self.y.unsqueeze(dim),
            self.z.unsqueeze(dim),
        )

    @classmethod
    def zeros(cls, shape, device="cpu"):
        """Return Vec3Array corresponding to zeros of given shape."""
        return cls(
            torch.zeros(shape, dtype=torch.float32, device=device),
            torch.zeros(shape, dtype=torch.float32, device=device),
            torch.zeros(shape, dtype=torch.float32, device=device)
        )

    def to_tensor(self) -> torch.Tensor:
        """Stack components into one tensor with a trailing dim of size 3."""
        return torch.stack([self.x, self.y, self.z], dim=-1)

    @classmethod
    def from_array(cls, tensor):
        """Inverse of to_tensor: split a [..., 3] tensor into components."""
        return cls(*torch.unbind(tensor, dim=-1))

    @classmethod
    def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
        """Concatenate a list of Vec3Arrays along dimension `dim`."""
        return cls(
            torch.cat([v.x for v in vecs], dim=dim),
            torch.cat([v.y for v in vecs], dim=dim),
            torch.cat([v.z for v in vecs], dim=dim),
        )
def square_euclidean_distance(
    vec1: Vec3Array,
    vec2: Vec3Array,
    epsilon: float = 1e-6
) -> Float:
    """Computes square of euclidean distance between 'vec1' and 'vec2'.

    Args:
        vec1: Vec3Array to compute distance to
        vec2: Vec3Array to compute distance from, should be
            broadcast compatible with 'vec1'
        epsilon: distance is clipped from below to be at least epsilon

    Returns:
        Array of square euclidean distances;
        shape will be result of broadcasting 'vec1' and 'vec2'
    """
    delta = vec1 - vec2
    dist_sq = delta.dot(delta)
    # A falsy epsilon (0 / None) disables the lower clamp entirely.
    return torch.clamp(dist_sq, min=epsilon) if epsilon else dist_sq
def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
    """Free-function form of Vec3Array.dot; delegates to the method."""
    return vector1.dot(vector2)
def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
    """Free-function form of Vec3Array.cross; delegates to the method.

    Note: the return annotation was corrected from `Float` to `Vec3Array`
    — the cross product of two vectors is a vector, and that is what
    `vector1.cross(vector2)` returns.
    """
    return vector1.cross(vector2)
def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
    """Free-function form of Vec3Array.norm; delegates to the method."""
    result = vector.norm(epsilon)
    return result
def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
    """Free-function form of Vec3Array.normalized; delegates to the method."""
    result = vector.normalized(epsilon)
    return result
def euclidean_distance(
    vec1: Vec3Array,
    vec2: Vec3Array,
    epsilon: float = 1e-6
) -> Float:
    """Computes euclidean distance between 'vec1' and 'vec2'.

    Args:
        vec1: Vec3Array to compute euclidean distance to
        vec2: Vec3Array to compute euclidean distance from, should be
            broadcast compatible with 'vec1'
        epsilon: distance is clipped from below to be at least epsilon

    Returns:
        Array of euclidean distances;
        shape will be result of broadcasting 'vec1' and 'vec2'
    """
    # The squared-distance clamp receives epsilon**2 so that the final
    # distance (after sqrt) is clipped from below at epsilon.
    return torch.sqrt(square_euclidean_distance(vec1, vec2, epsilon ** 2))
def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array
                   ) -> Float:
    """Computes torsion angle for a quadruple of points.

    For points (a, b, c, d), this is the angle between the planes defined by
    points (a, b, c) and (b, c, d). It is also known as the dihedral angle.

    Arguments:
        a: A Vec3Array of coordinates.
        b: A Vec3Array of coordinates.
        c: A Vec3Array of coordinates.
        d: A Vec3Array of coordinates.

    Returns:
        A tensor of angles in radians: [-pi, pi].
    """
    # Bond vectors along the chain; note both outer vectors point toward
    # the central b-c axis.
    ab = a - b
    cb = b - c
    cd = d - c
    # Normals of the two planes, plus their cross product, whose
    # projection onto the axis gives the signed sine term for atan2.
    plane1_normal = ab.cross(cb)
    plane2_normal = cd.cross(cb)
    sin_axis = plane2_normal.cross(plane1_normal)
    axis_length = cb.norm()
    return torch.atan2(sin_axis.dot(cb),
                       axis_length * plane1_normal.dot(plane2_normal))
openfold/utils/import_weights.py
View file @
56d5e39c
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
re
from
enum
import
Enum
from
enum
import
Enum
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
functools
import
partial
...
@@ -39,6 +40,13 @@ class ParamType(Enum):
...
@@ -39,6 +40,13 @@ class ParamType(Enum):
LinearWeightOPM
=
partial
(
LinearWeightOPM
=
partial
(
lambda
w
:
w
.
reshape
(
*
w
.
shape
[:
-
3
],
-
1
,
w
.
shape
[
-
1
]).
transpose
(
-
1
,
-
2
)
lambda
w
:
w
.
reshape
(
*
w
.
shape
[:
-
3
],
-
1
,
w
.
shape
[
-
1
]).
transpose
(
-
1
,
-
2
)
)
)
LinearWeightMultimer
=
partial
(
lambda
w
:
w
.
unsqueeze
(
-
1
)
if
len
(
w
.
shape
)
==
1
else
w
.
reshape
(
w
.
shape
[
0
],
-
1
).
transpose
(
-
1
,
-
2
)
)
LinearBiasMultimer
=
partial
(
lambda
w
:
w
.
reshape
(
-
1
)
)
Other
=
partial
(
lambda
w
:
w
)
Other
=
partial
(
lambda
w
:
w
)
def
__init__
(
self
,
fn
):
def
__init__
(
self
,
fn
):
...
@@ -122,26 +130,32 @@ def assign(translation_dict, orig_weights):
...
@@ -122,26 +130,32 @@ def assign(translation_dict, orig_weights):
raise
raise
def
generate_translation_dict
(
model
,
version
):
def
generate_translation_dict
(
model
,
version
,
is_multimer
=
False
):
#######################
#######################
# Some templates
# Some templates
#######################
#######################
LinearWeight
=
lambda
l
:
(
Param
(
l
,
param_type
=
ParamType
.
LinearWeight
))
LinearWeight
=
lambda
l
:
(
Param
(
l
,
param_type
=
ParamType
.
LinearWeight
))
LinearBias
=
lambda
l
:
(
Param
(
l
))
LinearBias
=
lambda
l
:
(
Param
(
l
))
LinearWeightMHA
=
lambda
l
:
(
Param
(
l
,
param_type
=
ParamType
.
LinearWeightMHA
))
LinearWeightMHA
=
lambda
l
:
(
Param
(
l
,
param_type
=
ParamType
.
LinearWeightMHA
))
LinearBiasMHA
=
lambda
b
:
(
Param
(
b
,
param_type
=
ParamType
.
LinearBiasMHA
))
LinearBiasMHA
=
lambda
b
:
(
Param
(
b
,
param_type
=
ParamType
.
LinearBiasMHA
))
LinearWeightOPM
=
lambda
l
:
(
Param
(
l
,
param_type
=
ParamType
.
LinearWeightOPM
))
LinearWeightOPM
=
lambda
l
:
(
Param
(
l
,
param_type
=
ParamType
.
LinearWeightOPM
))
LinearWeightMultimer
=
lambda
l
:
(
Param
(
l
,
param_type
=
ParamType
.
LinearWeightMultimer
)
)
LinearBiasMultimer
=
lambda
l
:
(
Param
(
l
,
param_type
=
ParamType
.
LinearBiasMultimer
)
)
LinearParams
=
lambda
l
:
{
LinearParams
=
lambda
l
:
{
"weights"
:
LinearWeight
(
l
.
weight
),
"weights"
:
LinearWeight
(
l
.
weight
),
"bias"
:
LinearBias
(
l
.
bias
),
"bias"
:
LinearBias
(
l
.
bias
),
}
}
LinearParamsMultimer
=
lambda
l
:
{
"weights"
:
LinearWeightMultimer
(
l
.
weight
),
"bias"
:
LinearBiasMultimer
(
l
.
bias
),
}
LayerNormParams
=
lambda
l
:
{
LayerNormParams
=
lambda
l
:
{
"scale"
:
Param
(
l
.
weight
),
"scale"
:
Param
(
l
.
weight
),
"offset"
:
Param
(
l
.
bias
),
"offset"
:
Param
(
l
.
bias
),
...
@@ -178,31 +192,47 @@ def generate_translation_dict(model, version):
...
@@ -178,31 +192,47 @@ def generate_translation_dict(model, version):
"attention"
:
AttentionGatedParams
(
tri_att
.
mha
),
"attention"
:
AttentionGatedParams
(
tri_att
.
mha
),
}
}
TriMulOutParams
=
lambda
tri_mul
:
{
def
TriMulOutParams
(
tri_mul
,
outgoing
=
True
):
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
version
):
"left_projection"
:
LinearParams
(
tri_mul
.
linear_a_p
),
d
=
{
"right_projection"
:
LinearParams
(
tri_mul
.
linear_b_p
),
"left_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"left_gate"
:
LinearParams
(
tri_mul
.
linear_a_g
),
"projection"
:
LinearParams
(
tri_mul
.
linear_ab_p
),
"right_gate"
:
LinearParams
(
tri_mul
.
linear_b_g
),
"gate"
:
LinearParams
(
tri_mul
.
linear_ab_g
),
"center_layer_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"center_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
}
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_g
),
else
:
}
# see commit b88f8da on the Alphafold repo
# Alphafold swaps the pseudocode's a and b between the incoming/outcoming
# iterations of triangle multiplication, which is confusing and not
# reproduced in our implementation.
if
outgoing
:
left_projection
=
LinearParams
(
tri_mul
.
linear_a_p
)
right_projection
=
LinearParams
(
tri_mul
.
linear_b_p
)
left_gate
=
LinearParams
(
tri_mul
.
linear_a_g
)
right_gate
=
LinearParams
(
tri_mul
.
linear_b_g
)
else
:
left_projection
=
LinearParams
(
tri_mul
.
linear_b_p
)
right_projection
=
LinearParams
(
tri_mul
.
linear_a_p
)
left_gate
=
LinearParams
(
tri_mul
.
linear_b_g
)
right_gate
=
LinearParams
(
tri_mul
.
linear_a_g
)
d
=
{
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
"left_projection"
:
left_projection
,
"right_projection"
:
right_projection
,
"left_gate"
:
left_gate
,
"right_gate"
:
right_gate
,
"center_layer_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
}
# see commit b88f8da on the Alphafold repo
d
.
update
({
# Alphafold swaps the pseudocode's a and b between the incoming/outcoming
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
# iterations of triangle multiplication, which is confusing and not
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_g
),
# reproduced in our implementation.
})
TriMulInParams
=
lambda
tri_mul
:
{
"layer_norm_input"
:
LayerNormParams
(
tri_mul
.
layer_norm_in
),
return
d
"left_projection"
:
LinearParams
(
tri_mul
.
linear_b_p
),
"right_projection"
:
LinearParams
(
tri_mul
.
linear_a_p
),
TriMulInParams
=
partial
(
TriMulOutParams
,
outgoing
=
False
)
"left_gate"
:
LinearParams
(
tri_mul
.
linear_b_g
),
"right_gate"
:
LinearParams
(
tri_mul
.
linear_a_g
),
"center_layer_norm"
:
LayerNormParams
(
tri_mul
.
layer_norm_out
),
"output_projection"
:
LinearParams
(
tri_mul
.
linear_z
),
"gating_linear"
:
LinearParams
(
tri_mul
.
linear_g
),
}
PairTransitionParams
=
lambda
pt
:
{
PairTransitionParams
=
lambda
pt
:
{
"input_layer_norm"
:
LayerNormParams
(
pt
.
layer_norm
),
"input_layer_norm"
:
LayerNormParams
(
pt
.
layer_norm
),
...
@@ -236,8 +266,46 @@ def generate_translation_dict(model, version):
...
@@ -236,8 +266,46 @@ def generate_translation_dict(model, version):
IPAParams
=
lambda
ipa
:
{
IPAParams
=
lambda
ipa
:
{
"q_scalar"
:
LinearParams
(
ipa
.
linear_q
),
"q_scalar"
:
LinearParams
(
ipa
.
linear_q
),
"kv_scalar"
:
LinearParams
(
ipa
.
linear_kv
),
"kv_scalar"
:
LinearParams
(
ipa
.
linear_kv
),
"q_point_local"
:
LinearParams
(
ipa
.
linear_q_points
),
"q_point_local"
:
LinearParams
(
ipa
.
linear_q_points
.
linear
),
"kv_point_local"
:
LinearParams
(
ipa
.
linear_kv_points
),
"kv_point_local"
:
LinearParams
(
ipa
.
linear_kv_points
.
linear
),
"trainable_point_weights"
:
Param
(
param
=
ipa
.
head_weights
,
param_type
=
ParamType
.
Other
),
"attention_2d"
:
LinearParams
(
ipa
.
linear_b
),
"output_projection"
:
LinearParams
(
ipa
.
linear_out
),
}
PointProjectionParams
=
lambda
pp
:
{
"point_projection"
:
LinearParamsMultimer
(
pp
.
linear
,
),
}
IPAParamsMultimer
=
lambda
ipa
:
{
"q_scalar_projection"
:
{
"weights"
:
LinearWeightMultimer
(
ipa
.
linear_q
.
weight
,
),
},
"k_scalar_projection"
:
{
"weights"
:
LinearWeightMultimer
(
ipa
.
linear_k
.
weight
,
),
},
"v_scalar_projection"
:
{
"weights"
:
LinearWeightMultimer
(
ipa
.
linear_v
.
weight
,
),
},
"q_point_projection"
:
PointProjectionParams
(
ipa
.
linear_q_points
),
"k_point_projection"
:
PointProjectionParams
(
ipa
.
linear_k_points
),
"v_point_projection"
:
PointProjectionParams
(
ipa
.
linear_v_points
),
"trainable_point_weights"
:
Param
(
"trainable_point_weights"
:
Param
(
param
=
ipa
.
head_weights
,
param_type
=
ParamType
.
Other
param
=
ipa
.
head_weights
,
param_type
=
ParamType
.
Other
),
),
...
@@ -280,109 +348,183 @@ def generate_translation_dict(model, version):
...
@@ -280,109 +348,183 @@ def generate_translation_dict(model, version):
b
.
msa_att_row
b
.
msa_att_row
),
),
col_att_name
:
msa_col_att_params
,
col_att_name
:
msa_col_att_params
,
"msa_transition"
:
MSATransitionParams
(
b
.
core
.
msa_transition
),
"msa_transition"
:
MSATransitionParams
(
b
.
msa_transition
),
"outer_product_mean"
:
"outer_product_mean"
:
OuterProductMeanParams
(
b
.
core
.
outer_product_mean
),
OuterProductMeanParams
(
b
.
outer_product_mean
),
"triangle_multiplication_outgoing"
:
"triangle_multiplication_outgoing"
:
TriMulOutParams
(
b
.
core
.
tri_mul_out
),
TriMulOutParams
(
b
.
pair_stack
.
tri_mul_out
),
"triangle_multiplication_incoming"
:
"triangle_multiplication_incoming"
:
TriMulInParams
(
b
.
core
.
tri_mul_in
),
TriMulInParams
(
b
.
pair_stack
.
tri_mul_in
),
"triangle_attention_starting_node"
:
"triangle_attention_starting_node"
:
TriAttParams
(
b
.
core
.
tri_att_start
),
TriAttParams
(
b
.
pair_stack
.
tri_att_start
),
"triangle_attention_ending_node"
:
"triangle_attention_ending_node"
:
TriAttParams
(
b
.
core
.
tri_att_end
),
TriAttParams
(
b
.
pair_stack
.
tri_att_end
),
"pair_transition"
:
"pair_transition"
:
PairTransitionParams
(
b
.
core
.
pair_transition
),
PairTransitionParams
(
b
.
pair_stack
.
pair_transition
),
}
}
return
d
return
d
ExtraMSABlockParams
=
partial
(
EvoformerBlockParams
,
is_extra_msa
=
True
)
ExtraMSABlockParams
=
partial
(
EvoformerBlockParams
,
is_extra_msa
=
True
)
FoldIterationParams
=
lambda
sm
:
{
def
FoldIterationParams
(
sm
):
"invariant_point_attention"
:
IPAParams
(
sm
.
ipa
),
d
=
{
"attention_layer_norm"
:
LayerNormParams
(
sm
.
layer_norm_ipa
),
"invariant_point_attention"
:
"transition"
:
LinearParams
(
sm
.
transition
.
layers
[
0
].
linear_1
),
IPAParamsMultimer
(
sm
.
ipa
)
if
is_multimer
else
IPAParams
(
sm
.
ipa
),
"transition_1"
:
LinearParams
(
sm
.
transition
.
layers
[
0
].
linear_2
),
"attention_layer_norm"
:
LayerNormParams
(
sm
.
layer_norm_ipa
),
"transition_2"
:
LinearParams
(
sm
.
transition
.
layers
[
0
].
linear_3
),
"transition"
:
LinearParams
(
sm
.
transition
.
layers
[
0
].
linear_1
),
"transition_layer_norm"
:
LayerNormParams
(
sm
.
transition
.
layer_norm
),
"transition_1"
:
LinearParams
(
sm
.
transition
.
layers
[
0
].
linear_2
),
"affine_update"
:
LinearParams
(
sm
.
bb_update
.
linear
),
"transition_2"
:
LinearParams
(
sm
.
transition
.
layers
[
0
].
linear_3
),
"rigid_sidechain"
:
{
"transition_layer_norm"
:
LayerNormParams
(
sm
.
transition
.
layer_norm
),
"input_projection"
:
LinearParams
(
sm
.
angle_resnet
.
linear_in
),
"affine_update"
:
LinearParams
(
sm
.
bb_update
.
linear
),
"input_projection_1"
:
LinearParams
(
sm
.
angle_resnet
.
linear_initial
),
"rigid_sidechain"
:
{
"resblock1"
:
LinearParams
(
sm
.
angle_resnet
.
layers
[
0
].
linear_1
),
"input_projection"
:
LinearParams
(
sm
.
angle_resnet
.
linear_in
),
"resblock2"
:
LinearParams
(
sm
.
angle_resnet
.
layers
[
0
].
linear_2
),
"input_projection_1"
:
"resblock1_1"
:
LinearParams
(
sm
.
angle_resnet
.
layers
[
1
].
linear_1
),
LinearParams
(
sm
.
angle_resnet
.
linear_initial
),
"resblock2_1"
:
LinearParams
(
sm
.
angle_resnet
.
layers
[
1
].
linear_2
),
"resblock1"
:
LinearParams
(
sm
.
angle_resnet
.
layers
[
0
].
linear_1
),
"unnormalized_angles"
:
LinearParams
(
sm
.
angle_resnet
.
linear_out
),
"resblock2"
:
LinearParams
(
sm
.
angle_resnet
.
layers
[
0
].
linear_2
),
},
"resblock1_1"
:
}
LinearParams
(
sm
.
angle_resnet
.
layers
[
1
].
linear_1
),
"resblock2_1"
:
LinearParams
(
sm
.
angle_resnet
.
layers
[
1
].
linear_2
),
"unnormalized_angles"
:
LinearParams
(
sm
.
angle_resnet
.
linear_out
),
},
}
if
(
is_multimer
):
d
.
pop
(
"affine_update"
)
d
[
"quat_rigid"
]
=
{
"rigid"
:
LinearParams
(
sm
.
bb_update
.
linear
)
}
return
d
############################
############################
# translations dict overflow
# translations dict overflow
############################
############################
ems_blocks
=
model
.
extra_msa_stack
.
blocks
ems_blocks
=
model
.
extra_msa_stack
.
blocks
ems_blocks_params
=
stacked
([
ExtraMSABlockParams
(
b
)
for
b
in
ems_blocks
])
ems_blocks_params
=
stacked
([
ExtraMSABlockParams
(
b
)
for
b
in
ems_blocks
])
evo_blocks
=
model
.
evoformer
.
blocks
evo_blocks
=
model
.
evoformer
.
blocks
evo_blocks_params
=
stacked
([
EvoformerBlockParams
(
b
)
for
b
in
evo_blocks
])
evo_blocks_params
=
stacked
([
EvoformerBlockParams
(
b
)
for
b
in
evo_blocks
])
translations
=
{
if
(
not
is_multimer
):
"evoformer"
:
{
translations
=
{
"preprocess_1d"
:
LinearParams
(
model
.
input_embedder
.
linear_tf_m
),
"evoformer"
:
{
"preprocess_msa"
:
LinearParams
(
model
.
input_embedder
.
linear_msa_m
),
"preprocess_1d"
:
LinearParams
(
model
.
input_embedder
.
linear_tf_m
),
"left_single"
:
LinearParams
(
model
.
input_embedder
.
linear_tf_z_i
),
"preprocess_msa"
:
LinearParams
(
model
.
input_embedder
.
linear_msa_m
),
"right_single"
:
LinearParams
(
model
.
input_embedder
.
linear_tf_z_j
),
"left_single"
:
LinearParams
(
model
.
input_embedder
.
linear_tf_z_i
),
"prev_pos_linear"
:
LinearParams
(
model
.
recycling_embedder
.
linear
),
"right_single"
:
LinearParams
(
model
.
input_embedder
.
linear_tf_z_j
),
"prev_msa_first_row_norm"
:
LayerNormParams
(
"prev_pos_linear"
:
LinearParams
(
model
.
recycling_embedder
.
linear
),
model
.
recycling_embedder
.
layer_norm_m
"prev_msa_first_row_norm"
:
LayerNormParams
(
),
model
.
recycling_embedder
.
layer_norm_m
"prev_pair_norm"
:
LayerNormParams
(
),
model
.
recycling_embedder
.
layer_norm_z
"prev_pair_norm"
:
LayerNormParams
(
),
model
.
recycling_embedder
.
layer_norm_z
"pair_activiations"
:
LinearParams
(
),
model
.
input_embedder
.
linear_relpos
"pair_activiations"
:
LinearParams
(
),
model
.
input_embedder
.
linear_relpos
"extra_msa_activations"
:
LinearParams
(
),
model
.
extra_msa_embedder
.
linear
"extra_msa_activations"
:
LinearParams
(
),
model
.
extra_msa_embedder
.
linear
"extra_msa_stack"
:
ems_blocks_params
,
),
"evoformer_iteration"
:
evo_blocks_params
,
"extra_msa_stack"
:
ems_blocks_params
,
"single_activations"
:
LinearParams
(
model
.
evoformer
.
linear
),
"evoformer_iteration"
:
evo_blocks_params
,
},
"single_activations"
:
LinearParams
(
model
.
evoformer
.
linear
),
"structure_module"
:
{
},
"single_layer_norm"
:
LayerNormParams
(
"structure_module"
:
{
model
.
structure_module
.
layer_norm_s
"single_layer_norm"
:
LayerNormParams
(
),
model
.
structure_module
.
layer_norm_s
"initial_projection"
:
LinearParams
(
),
model
.
structure_module
.
linear_in
"initial_projection"
:
LinearParams
(
),
model
.
structure_module
.
linear_in
"pair_layer_norm"
:
LayerNormParams
(
),
model
.
structure_module
.
layer_norm_z
"pair_layer_norm"
:
LayerNormParams
(
),
model
.
structure_module
.
layer_norm_z
"fold_iteration"
:
FoldIterationParams
(
model
.
structure_module
),
),
},
"fold_iteration"
:
FoldIterationParams
(
model
.
structure_module
),
"predicted_lddt_head"
:
{
},
"input_layer_norm"
:
LayerNormParams
(
"predicted_lddt_head"
:
{
model
.
aux_heads
.
plddt
.
layer_norm
"input_layer_norm"
:
LayerNormParams
(
),
model
.
aux_heads
.
plddt
.
layer_norm
"act_0"
:
LinearParams
(
model
.
aux_heads
.
plddt
.
linear_1
),
),
"act_1"
:
LinearParams
(
model
.
aux_heads
.
plddt
.
linear_2
),
"act_0"
:
LinearParams
(
model
.
aux_heads
.
plddt
.
linear_1
),
"logits"
:
LinearParams
(
model
.
aux_heads
.
plddt
.
linear_3
),
"act_1"
:
LinearParams
(
model
.
aux_heads
.
plddt
.
linear_2
),
},
"logits"
:
LinearParams
(
model
.
aux_heads
.
plddt
.
linear_3
),
"distogram_head"
:
{
},
"half_logits"
:
LinearParams
(
model
.
aux_heads
.
distogram
.
linear
),
"distogram_head"
:
{
},
"half_logits"
:
LinearParams
(
model
.
aux_heads
.
distogram
.
linear
),
"experimentally_resolved_head"
:
{
},
"logits"
:
LinearParams
(
"experimentally_resolved_head"
:
{
model
.
aux_heads
.
experimentally_resolved
.
linear
"logits"
:
LinearParams
(
),
model
.
aux_heads
.
experimentally_resolved
.
linear
},
),
"masked_msa_head"
:
{
},
"logits"
:
LinearParams
(
model
.
aux_heads
.
masked_msa
.
linear
),
"masked_msa_head"
:
{
},
"logits"
:
LinearParams
(
model
.
aux_heads
.
masked_msa
.
linear
),
}
},
}
else
:
translations
=
{
"evoformer"
:
{
"preprocess_1d"
:
LinearParams
(
model
.
input_embedder
.
linear_tf_m
),
"preprocess_msa"
:
LinearParams
(
model
.
input_embedder
.
linear_msa_m
),
"left_single"
:
LinearParams
(
model
.
input_embedder
.
linear_tf_z_i
),
"right_single"
:
LinearParams
(
model
.
input_embedder
.
linear_tf_z_j
),
"prev_pos_linear"
:
LinearParams
(
model
.
recycling_embedder
.
linear
),
"prev_msa_first_row_norm"
:
LayerNormParams
(
model
.
recycling_embedder
.
layer_norm_m
),
"prev_pair_norm"
:
LayerNormParams
(
model
.
recycling_embedder
.
layer_norm_z
),
"~_relative_encoding"
:
{
"position_activations"
:
LinearParams
(
model
.
input_embedder
.
linear_relpos
),
},
"extra_msa_activations"
:
LinearParams
(
model
.
extra_msa_embedder
.
linear
),
"extra_msa_stack"
:
ems_blocks_params
,
"evoformer_iteration"
:
evo_blocks_params
,
"single_activations"
:
LinearParams
(
model
.
evoformer
.
linear
),
},
"structure_module"
:
{
"single_layer_norm"
:
LayerNormParams
(
model
.
structure_module
.
layer_norm_s
),
"initial_projection"
:
LinearParams
(
model
.
structure_module
.
linear_in
),
"pair_layer_norm"
:
LayerNormParams
(
model
.
structure_module
.
layer_norm_z
),
"fold_iteration"
:
FoldIterationParams
(
model
.
structure_module
),
},
"predicted_lddt_head"
:
{
"input_layer_norm"
:
LayerNormParams
(
model
.
aux_heads
.
plddt
.
layer_norm
),
"act_0"
:
LinearParams
(
model
.
aux_heads
.
plddt
.
linear_1
),
"act_1"
:
LinearParams
(
model
.
aux_heads
.
plddt
.
linear_2
),
"logits"
:
LinearParams
(
model
.
aux_heads
.
plddt
.
linear_3
),
},
"distogram_head"
:
{
"half_logits"
:
LinearParams
(
model
.
aux_heads
.
distogram
.
linear
),
},
"experimentally_resolved_head"
:
{
"logits"
:
LinearParams
(
model
.
aux_heads
.
experimentally_resolved
.
linear
),
},
"masked_msa_head"
:
{
"logits"
:
LinearParams
(
model
.
aux_heads
.
masked_msa
.
linear
),
},
}
no_templ
=
[
no_templ
=
[
"model_3"
,
"model_3"
,
...
@@ -394,48 +536,98 @@ def generate_translation_dict(model, version):
...
@@ -394,48 +536,98 @@ def generate_translation_dict(model, version):
]
]
if
version
not
in
no_templ
:
if
version
not
in
no_templ
:
tps_blocks
=
model
.
template_pair_stack
.
blocks
tps_blocks
=
model
.
template_
embedder
.
template_
pair_stack
.
blocks
tps_blocks_params
=
stacked
(
tps_blocks_params
=
stacked
(
[
TemplatePairBlockParams
(
b
)
for
b
in
tps_blocks
]
[
TemplatePairBlockParams
(
b
)
for
b
in
tps_blocks
]
)
)
template_param_dict
=
{
if
(
not
is_multimer
):
"template_embedding"
:
{
template_param_dict
=
{
"single_template_embedding"
:
{
"template_embedding"
:
{
"embedding2d"
:
LinearParams
(
"single_template_embedding"
:
{
model
.
template_pair_embedder
.
linear
"embedding2d"
:
LinearParams
(
),
model
.
template_embedder
.
template_pair_embedder
.
linear
"template_pair_stack"
:
{
),
"__layer_stack_no_state"
:
tps_blocks_params
,
"template_pair_stack"
:
{
"__layer_stack_no_state"
:
tps_blocks_params
,
},
"output_layer_norm"
:
LayerNormParams
(
model
.
template_embedder
.
template_pair_stack
.
layer_norm
),
},
},
"output_layer_norm"
:
LayerNormParams
(
"attention"
:
AttentionParams
(
model
.
template_embedder
.
template_pointwise_att
.
mha
),
model
.
template_pair_stack
.
layer_norm
},
"template_single_embedding"
:
LinearParams
(
model
.
template_embedder
.
template_angle_embedder
.
linear_1
),
"template_projection"
:
LinearParams
(
model
.
template_embedder
.
template_angle_embedder
.
linear_2
),
}
else
:
temp_embedder
=
model
.
template_embedder
template_param_dict
=
{
"template_embedding"
:
{
"single_template_embedding"
:
{
"query_embedding_norm"
:
LayerNormParams
(
temp_embedder
.
template_pair_embedder
.
query_embedding_layer_norm
),
"template_pair_embedding_0"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
dgram_linear
),
"template_pair_embedding_1"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
pseudo_beta_mask_linear
),
"template_pair_embedding_2"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
aatype_linear_1
),
"template_pair_embedding_3"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
aatype_linear_2
),
"template_pair_embedding_4"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
x_linear
),
"template_pair_embedding_5"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
y_linear
),
"template_pair_embedding_6"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
z_linear
),
"template_pair_embedding_7"
:
LinearParamsMultimer
(
temp_embedder
.
template_pair_embedder
.
backbone_mask_linear
),
"template_pair_embedding_8"
:
LinearParams
(
temp_embedder
.
template_pair_embedder
.
query_embedding_linear
),
"template_embedding_iteration"
:
tps_blocks_params
,
"output_layer_norm"
:
LayerNormParams
(
model
.
template_embedder
.
template_pair_stack
.
layer_norm
),
},
"output_linear"
:
LinearParams
(
temp_embedder
.
linear_t
),
),
},
},
"attention"
:
AttentionParams
(
model
.
template_pointwise_att
.
mha
),
"template_projection"
:
LinearParams
(
},
temp_embedder
.
template_single_embedder
.
template_projector
,
"template_single_embedding"
:
LinearParams
(
),
model
.
template_angle_embedder
.
linear_1
"template_single_embedding"
:
LinearParams
(
),
temp_embedder
.
template_single_embedder
.
template_single_embedder
,
"template_projection"
:
LinearParams
(
),
model
.
template_angle_embedder
.
linear_2
}
),
}
translations
[
"evoformer"
].
update
(
template_param_dict
)
translations
[
"evoformer"
].
update
(
template_param_dict
)
if
"_ptm"
in
version
:
if
is_multimer
or
"_ptm"
in
version
:
translations
[
"predicted_aligned_error_head"
]
=
{
translations
[
"predicted_aligned_error_head"
]
=
{
"logits"
:
LinearParams
(
model
.
aux_heads
.
tm
.
linear
)
"logits"
:
LinearParams
(
model
.
aux_heads
.
tm
.
linear
)
}
}
return
translations
return
translations
def
import_jax_weights_
(
model
,
npz_path
,
version
=
"model_1"
):
def
import_jax_weights_
(
model
,
npz_path
,
version
=
"model_1"
):
data
=
np
.
load
(
npz_path
)
data
=
np
.
load
(
npz_path
)
translations
=
generate_translation_dict
(
model
,
version
,
is_multimer
=
(
"multimer"
in
version
))
translations
=
generate_translation_dict
(
model
,
version
)
# Flatten keys and insert missing key prefixes
# Flatten keys and insert missing key prefixes
flat
=
process_translation_dict
(
translations
)
flat
=
process_translation_dict
(
translations
)
...
...
openfold/utils/loss.py
View file @
56d5e39c
...
@@ -25,6 +25,8 @@ from typing import Dict, Optional, Tuple
...
@@ -25,6 +25,8 @@ from typing import Dict, Optional, Tuple
from
openfold.np
import
residue_constants
from
openfold.np
import
residue_constants
from
openfold.utils
import
feats
from
openfold.utils
import
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.vector
import
Vec3Array
,
euclidean_distance
from
openfold.utils.all_atom_multimer
import
get_rc_tensor
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tree_map
,
tree_map
,
tensor_tree_map
,
tensor_tree_map
,
...
@@ -87,6 +89,7 @@ def compute_fape(
...
@@ -87,6 +89,7 @@ def compute_fape(
target_positions
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
positions_mask
:
torch
.
Tensor
,
positions_mask
:
torch
.
Tensor
,
length_scale
:
float
,
length_scale
:
float
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
l1_clamp_distance
:
Optional
[
float
]
=
None
,
l1_clamp_distance
:
Optional
[
float
]
=
None
,
eps
=
1e-8
,
eps
=
1e-8
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -108,6 +111,9 @@ def compute_fape(
...
@@ -108,6 +111,9 @@ def compute_fape(
[*, N_pts] positions mask
[*, N_pts] positions mask
length_scale:
length_scale:
Length scale by which the loss is divided
Length scale by which the loss is divided
pair_mask:
[*, N_frames, N_pts] mask to use for
separating intra- from inter-chain losses.
l1_clamp_distance:
l1_clamp_distance:
Cutoff above which distance errors are disregarded
Cutoff above which distance errors are disregarded
eps:
eps:
...
@@ -134,21 +140,30 @@ def compute_fape(
...
@@ -134,21 +140,30 @@ def compute_fape(
normed_error
=
normed_error
*
frames_mask
[...,
None
]
normed_error
=
normed_error
*
frames_mask
[...,
None
]
normed_error
=
normed_error
*
positions_mask
[...,
None
,
:]
normed_error
=
normed_error
*
positions_mask
[...,
None
,
:]
# FP16-friendly averaging. Roughly equivalent to:
if
pair_mask
is
not
None
:
#
normed_error
=
normed_error
*
pair_mask
# norm_factor = (
normed_error
=
torch
.
sum
(
normed_error
,
dim
=
(
-
1
,
-
2
))
# torch.sum(frames_mask, dim=-1) *
# torch.sum(positions_mask, dim=-1)
mask
=
frames_mask
[...,
None
]
*
positions_mask
[...,
None
,
:]
*
pair_mask
# )
norm_factor
=
torch
.
sum
(
mask
,
dim
=
(
-
2
,
-
1
))
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
normed_error
=
normed_error
/
(
eps
+
norm_factor
)
# ("roughly" because eps is necessarily duplicated in the latter)
else
:
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
# FP16-friendly averaging. Roughly equivalent to:
normed_error
=
(
#
normed_error
/
(
eps
+
torch
.
sum
(
frames_mask
,
dim
=-
1
))[...,
None
]
# norm_factor = (
)
# torch.sum(frames_mask, dim=-1) *
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
# torch.sum(positions_mask, dim=-1)
normed_error
=
normed_error
/
(
eps
+
torch
.
sum
(
positions_mask
,
dim
=-
1
))
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
# ("roughly" because eps is necessarily duplicated in the latter)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
(
normed_error
/
(
eps
+
torch
.
sum
(
frames_mask
,
dim
=-
1
))[...,
None
]
)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
normed_error
/
(
eps
+
torch
.
sum
(
positions_mask
,
dim
=-
1
))
return
normed_error
return
normed_error
...
@@ -157,6 +172,7 @@ def backbone_loss(
...
@@ -157,6 +172,7 @@ def backbone_loss(
backbone_rigid_tensor
:
torch
.
Tensor
,
backbone_rigid_tensor
:
torch
.
Tensor
,
backbone_rigid_mask
:
torch
.
Tensor
,
backbone_rigid_mask
:
torch
.
Tensor
,
traj
:
torch
.
Tensor
,
traj
:
torch
.
Tensor
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
use_clamped_fape
:
Optional
[
torch
.
Tensor
]
=
None
,
use_clamped_fape
:
Optional
[
torch
.
Tensor
]
=
None
,
clamp_distance
:
float
=
10.0
,
clamp_distance
:
float
=
10.0
,
loss_unit_distance
:
float
=
10.0
,
loss_unit_distance
:
float
=
10.0
,
...
@@ -184,6 +200,7 @@ def backbone_loss(
...
@@ -184,6 +200,7 @@ def backbone_loss(
pred_aff
.
get_trans
(),
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_rigid_mask
[
None
],
backbone_rigid_mask
[
None
],
pair_mask
=
pair_mask
,
l1_clamp_distance
=
clamp_distance
,
l1_clamp_distance
=
clamp_distance
,
length_scale
=
loss_unit_distance
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
eps
=
eps
,
...
@@ -196,6 +213,7 @@ def backbone_loss(
...
@@ -196,6 +213,7 @@ def backbone_loss(
pred_aff
.
get_trans
(),
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_rigid_mask
[
None
],
backbone_rigid_mask
[
None
],
pair_mask
=
pair_mask
,
l1_clamp_distance
=
None
,
l1_clamp_distance
=
None
,
length_scale
=
loss_unit_distance
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
eps
=
eps
,
...
@@ -253,6 +271,7 @@ def sidechain_loss(
...
@@ -253,6 +271,7 @@ def sidechain_loss(
sidechain_atom_pos
,
sidechain_atom_pos
,
renamed_atom14_gt_positions
,
renamed_atom14_gt_positions
,
renamed_atom14_gt_exists
,
renamed_atom14_gt_exists
,
pair_mask
=
None
,
l1_clamp_distance
=
clamp_distance
,
l1_clamp_distance
=
clamp_distance
,
length_scale
=
length_scale
,
length_scale
=
length_scale
,
eps
=
eps
,
eps
=
eps
,
...
@@ -266,10 +285,29 @@ def fape_loss(
...
@@ -266,10 +285,29 @@ def fape_loss(
batch
:
Dict
[
str
,
torch
.
Tensor
],
batch
:
Dict
[
str
,
torch
.
Tensor
],
config
:
ml_collections
.
ConfigDict
,
config
:
ml_collections
.
ConfigDict
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
bb_loss
=
backbone_loss
(
traj
=
out
[
"sm"
][
"frames"
],
traj
=
out
[
"sm"
][
"frames"
]
**
{
**
batch
,
**
config
.
backbone
},
asym_id
=
batch
.
get
(
"asym_id"
)
)
if
asym_id
is
not
None
:
intra_chain_mask
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]).
to
(
dtype
=
traj
.
dtype
)
intra_chain_bb_loss
=
backbone_loss
(
traj
=
traj
,
pair_mask
=
intra_chain_mask
,
**
{
**
batch
,
**
config
.
intra_chain_backbone
},
)
interface_bb_loss
=
backbone_loss
(
traj
=
traj
,
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
config
.
interface_backbone
},
)
weighted_bb_loss
=
(
intra_chain_bb_loss
*
config
.
intra_chain_backbone
.
weight
+
interface_bb_loss
*
config
.
interface_backbone
.
weight
)
else
:
bb_loss
=
backbone_loss
(
traj
=
traj
,
**
{
**
batch
,
**
config
.
backbone
},
)
weighted_bb_loss
=
bb_loss
*
config
.
backbone
.
weight
sc_loss
=
sidechain_loss
(
sc_loss
=
sidechain_loss
(
out
[
"sm"
][
"sidechain_frames"
],
out
[
"sm"
][
"sidechain_frames"
],
...
@@ -277,7 +315,7 @@ def fape_loss(
...
@@ -277,7 +315,7 @@ def fape_loss(
**
{
**
batch
,
**
config
.
sidechain
},
**
{
**
batch
,
**
config
.
sidechain
},
)
)
loss
=
config
.
backbone
.
weight
*
bb_loss
+
config
.
sidechain
.
weight
*
sc_loss
loss
=
weight
ed_
bb_loss
+
config
.
sidechain
.
weight
*
sc_loss
# Average over the batch dimension
# Average over the batch dimension
loss
=
torch
.
mean
(
loss
)
loss
=
torch
.
mean
(
loss
)
...
@@ -627,6 +665,8 @@ def compute_predicted_aligned_error(
...
@@ -627,6 +665,8 @@ def compute_predicted_aligned_error(
def
compute_tm
(
def
compute_tm
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
residue_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
residue_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
asym_id
:
Optional
[
torch
.
Tensor
]
=
None
,
interface
:
bool
=
False
,
max_bin
:
int
=
31
,
max_bin
:
int
=
31
,
no_bins
:
int
=
64
,
no_bins
:
int
=
64
,
eps
:
float
=
1e-8
,
eps
:
float
=
1e-8
,
...
@@ -649,15 +689,25 @@ def compute_tm(
...
@@ -649,15 +689,25 @@ def compute_tm(
tm_per_bin
=
1.0
/
(
1
+
(
bin_centers
**
2
)
/
(
d0
**
2
))
tm_per_bin
=
1.0
/
(
1
+
(
bin_centers
**
2
)
/
(
d0
**
2
))
predicted_tm_term
=
torch
.
sum
(
probs
*
tm_per_bin
,
dim
=-
1
)
predicted_tm_term
=
torch
.
sum
(
probs
*
tm_per_bin
,
dim
=-
1
)
normed_residue_mask
=
residue_weights
/
(
eps
+
residue_weights
.
sum
())
n
=
residue_weights
.
shape
[
-
1
]
pair_mask
=
residue_weights
.
new_ones
((
n
,
n
),
dtype
=
torch
.
int32
)
if
interface
:
pair_mask
*=
(
asym_id
[...,
None
]
!=
asym_id
[...,
None
,
:]).
to
(
dtype
=
pair_mask
.
dtype
)
predicted_tm_term
*=
pair_mask
pair_residue_weights
=
pair_mask
*
(
residue_weights
[...,
None
,
:]
*
residue_weights
[...,
:,
None
]
)
denom
=
eps
+
torch
.
sum
(
pair_residue_weights
,
dim
=-
1
,
keepdims
=
True
)
normed_residue_mask
=
pair_residue_weights
/
denom
per_alignment
=
torch
.
sum
(
predicted_tm_term
*
normed_residue_mask
,
dim
=-
1
)
per_alignment
=
torch
.
sum
(
predicted_tm_term
*
normed_residue_mask
,
dim
=-
1
)
weighted
=
per_alignment
*
residue_weights
weighted
=
per_alignment
*
residue_weights
argmax
=
(
weighted
==
torch
.
max
(
weighted
)).
nonzero
()[
0
]
argmax
=
(
weighted
==
torch
.
max
(
weighted
)).
nonzero
()[
0
]
return
per_alignment
[
tuple
(
argmax
)]
return
per_alignment
[
tuple
(
argmax
)]
def
tm_loss
(
def
tm_loss
(
logits
,
logits
,
final_affine_tensor
,
final_affine_tensor
,
...
@@ -709,7 +759,7 @@ def tm_loss(
...
@@ -709,7 +759,7 @@ def tm_loss(
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
)
)
# Average over the
loss
dimension
# Average over the
batch
dimension
loss
=
torch
.
mean
(
loss
)
loss
=
torch
.
mean
(
loss
)
return
loss
return
loss
...
@@ -879,6 +929,7 @@ def between_residue_clash_loss(
...
@@ -879,6 +929,7 @@ def between_residue_clash_loss(
atom14_atom_exists
:
torch
.
Tensor
,
atom14_atom_exists
:
torch
.
Tensor
,
atom14_atom_radius
:
torch
.
Tensor
,
atom14_atom_radius
:
torch
.
Tensor
,
residue_index
:
torch
.
Tensor
,
residue_index
:
torch
.
Tensor
,
asym_id
:
Optional
[
torch
.
Tensor
]
=
None
,
overlap_tolerance_soft
=
1.5
,
overlap_tolerance_soft
=
1.5
,
overlap_tolerance_hard
=
1.5
,
overlap_tolerance_hard
=
1.5
,
eps
=
1e-10
,
eps
=
1e-10
,
...
@@ -954,9 +1005,13 @@ def between_residue_clash_loss(
...
@@ -954,9 +1005,13 @@ def between_residue_clash_loss(
)
)
n_one_hot
=
n_one_hot
.
type
(
fp_type
)
n_one_hot
=
n_one_hot
.
type
(
fp_type
)
neighbour_mask
=
(
neighbour_mask
=
(
residue_index
[...,
:,
None
]
+
1
)
==
residue_index
[...,
None
,
:]
residue_index
[...,
:,
None
,
None
,
None
]
+
1
)
==
residue_index
[...,
None
,
:,
None
,
None
]
if
asym_id
is
not
None
:
neighbour_mask
=
neighbour_mask
&
(
asym_id
[...,
:,
None
]
==
asym_id
[...,
None
,
:])
neighbour_mask
=
neighbour_mask
[...,
None
,
None
]
c_n_bonds
=
(
c_n_bonds
=
(
neighbour_mask
neighbour_mask
*
c_one_hot
[...,
None
,
None
,
:,
None
]
*
c_one_hot
[...,
None
,
None
,
:,
None
]
...
@@ -998,26 +1053,29 @@ def between_residue_clash_loss(
...
@@ -998,26 +1053,29 @@ def between_residue_clash_loss(
# Compute the per atom loss sum.
# Compute the per atom loss sum.
# shape (N, 14)
# shape (N, 14)
per_atom_loss_sum
=
torch
.
sum
(
dists_to_low_error
,
dim
=
(
-
4
,
-
2
))
+
torch
.
sum
(
per_atom_loss_sum
=
torch
.
sum
(
dists_to_low_error
,
dim
=
(
-
4
,
-
2
))
+
torch
.
sum
(
dists_to_low_error
,
axis
=
(
-
3
,
-
1
)
dists_to_low_error
,
dim
=
(
-
3
,
-
1
)
)
)
# Compute the hard clash mask.
# Compute the hard clash mask.
# shape (N, N, 14, 14)
# shape (N, N, 14, 14)
clash_mask
=
dists_mask
*
(
clash_mask
=
dists_mask
*
(
dists
<
(
dists_lower_bound
-
overlap_tolerance_hard
)
dists
<
(
dists_lower_bound
-
overlap_tolerance_hard
)
)
)
per_atom_num_clash
=
torch
.
sum
(
clash_mask
,
dim
=
(
-
4
,
-
2
))
+
torch
.
sum
(
clash_mask
,
dim
=
(
-
3
,
-
1
))
# Compute the per atom clash.
# Compute the per atom clash.
# shape (N, 14)
# shape (N, 14)
per_atom_clash_mask
=
torch
.
maximum
(
per_atom_clash_mask
=
torch
.
maximum
(
torch
.
amax
(
clash_mask
,
axis
=
(
-
4
,
-
2
)),
torch
.
amax
(
clash_mask
,
dim
=
(
-
4
,
-
2
)),
torch
.
amax
(
clash_mask
,
axis
=
(
-
3
,
-
1
)),
torch
.
amax
(
clash_mask
,
dim
=
(
-
3
,
-
1
)),
)
)
return
{
return
{
"mean_loss"
:
mean_loss
,
# shape ()
"mean_loss"
:
mean_loss
,
# shape ()
"per_atom_loss_sum"
:
per_atom_loss_sum
,
# shape (N, 14)
"per_atom_loss_sum"
:
per_atom_loss_sum
,
# shape (N, 14)
"per_atom_clash_mask"
:
per_atom_clash_mask
,
# shape (N, 14)
"per_atom_clash_mask"
:
per_atom_clash_mask
,
# shape (N, 14)
"per_atom_num_clash"
:
per_atom_num_clash
# shape (N, 14)
}
}
...
@@ -1097,6 +1155,8 @@ def within_residue_violations(
...
@@ -1097,6 +1155,8 @@ def within_residue_violations(
(
dists
<
atom14_dists_lower_bound
)
|
(
dists
>
atom14_dists_upper_bound
)
(
dists
<
atom14_dists_lower_bound
)
|
(
dists
>
atom14_dists_upper_bound
)
)
)
per_atom_num_clash
=
torch
.
sum
(
violations
,
dim
=-
2
)
+
torch
.
sum
(
violations
,
dim
=-
1
)
# Compute the per atom violations.
# Compute the per atom violations.
per_atom_violations
=
torch
.
maximum
(
per_atom_violations
=
torch
.
maximum
(
torch
.
max
(
violations
,
dim
=-
2
)[
0
],
torch
.
max
(
violations
,
axis
=-
1
)[
0
]
torch
.
max
(
violations
,
dim
=-
2
)[
0
],
torch
.
max
(
violations
,
axis
=-
1
)[
0
]
...
@@ -1105,6 +1165,7 @@ def within_residue_violations(
...
@@ -1105,6 +1165,7 @@ def within_residue_violations(
return
{
return
{
"per_atom_loss_sum"
:
per_atom_loss_sum
,
"per_atom_loss_sum"
:
per_atom_loss_sum
,
"per_atom_violations"
:
per_atom_violations
,
"per_atom_violations"
:
per_atom_violations
,
"per_atom_num_clash"
:
per_atom_num_clash
}
}
...
@@ -1134,11 +1195,24 @@ def find_structural_violations(
...
@@ -1134,11 +1195,24 @@ def find_structural_violations(
residue_constants
.
van_der_waals_radius
[
name
[
0
]]
residue_constants
.
van_der_waals_radius
[
name
[
0
]]
for
name
in
residue_constants
.
atom_types
for
name
in
residue_constants
.
atom_types
]
]
atomtype_radius
=
atom14_pred_positions
.
new_tensor
(
atomtype_radius
)
atomtype_radius
=
atom14_pred_positions
.
new_tensor
(
atomtype_radius
)
atom14_atom_radius
=
(
batch
[
"atom14_atom_exists"
]
#TODO: Consolidate monomer/multimer modes
*
atomtype_radius
[
batch
[
"residx_atom14_to_atom37"
]]
asym_id
=
batch
.
get
(
"asym_id"
)
)
if
asym_id
is
not
None
:
residx_atom14_to_atom37
=
get_rc_tensor
(
residue_constants
.
RESTYPE_ATOM14_TO_ATOM37
,
batch
[
"aatype"
]
)
atom14_atom_radius
=
(
batch
[
"atom14_atom_exists"
]
*
atomtype_radius
[
residx_atom14_to_atom37
.
long
()]
)
else
:
atom14_atom_radius
=
(
batch
[
"atom14_atom_exists"
]
*
atomtype_radius
[
batch
[
"residx_atom14_to_atom37"
]]
)
# Compute the between residue clash loss.
# Compute the between residue clash loss.
between_residue_clashes
=
between_residue_clash_loss
(
between_residue_clashes
=
between_residue_clash_loss
(
...
@@ -1146,6 +1220,7 @@ def find_structural_violations(
...
@@ -1146,6 +1220,7 @@ def find_structural_violations(
atom14_atom_exists
=
batch
[
"atom14_atom_exists"
],
atom14_atom_exists
=
batch
[
"atom14_atom_exists"
],
atom14_atom_radius
=
atom14_atom_radius
,
atom14_atom_radius
=
atom14_atom_radius
,
residue_index
=
batch
[
"residue_index"
],
residue_index
=
batch
[
"residue_index"
],
asym_id
=
asym_id
,
overlap_tolerance_soft
=
clash_overlap_tolerance
,
overlap_tolerance_soft
=
clash_overlap_tolerance
,
overlap_tolerance_hard
=
clash_overlap_tolerance
,
overlap_tolerance_hard
=
clash_overlap_tolerance
,
)
)
...
@@ -1208,6 +1283,9 @@ def find_structural_violations(
...
@@ -1208,6 +1283,9 @@ def find_structural_violations(
"clashes_per_atom_clash_mask"
:
between_residue_clashes
[
"clashes_per_atom_clash_mask"
:
between_residue_clashes
[
"per_atom_clash_mask"
"per_atom_clash_mask"
],
# (N, 14)
],
# (N, 14)
"clashes_per_atom_num_clash"
:
between_residue_clashes
[
"per_atom_num_clash"
],
# (N, 14)
},
},
"within_residues"
:
{
"within_residues"
:
{
"per_atom_loss_sum"
:
residue_violations
[
"per_atom_loss_sum"
:
residue_violations
[
...
@@ -1216,6 +1294,9 @@ def find_structural_violations(
...
@@ -1216,6 +1294,9 @@ def find_structural_violations(
"per_atom_violations"
:
residue_violations
[
"per_atom_violations"
:
residue_violations
[
"per_atom_violations"
"per_atom_violations"
],
# (N, 14),
],
# (N, 14),
"per_atom_num_clash"
:
residue_violations
[
"per_atom_num_clash"
],
# (N, 14)
},
},
"total_per_residue_violations_mask"
:
per_residue_violations_mask
,
# (N)
"total_per_residue_violations_mask"
:
per_residue_violations_mask
,
# (N)
}
}
...
@@ -1337,15 +1418,21 @@ def compute_violation_metrics_np(
...
@@ -1337,15 +1418,21 @@ def compute_violation_metrics_np(
def
violation_loss
(
def
violation_loss
(
violations
:
Dict
[
str
,
torch
.
Tensor
],
violations
:
Dict
[
str
,
torch
.
Tensor
],
atom14_atom_exists
:
torch
.
Tensor
,
atom14_atom_exists
:
torch
.
Tensor
,
average_clashes
:
bool
=
False
,
eps
=
1e-6
,
eps
=
1e-6
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_atoms
=
torch
.
sum
(
atom14_atom_exists
)
num_atoms
=
torch
.
sum
(
atom14_atom_exists
)
l_clash
=
torch
.
sum
(
violations
[
"between_residues"
][
"clashes_per_atom_loss_sum"
]
per_atom_clash
=
(
violations
[
"between_residues"
][
"clashes_per_atom_loss_sum"
]
+
+
violations
[
"within_residues"
][
"per_atom_loss_sum"
]
violations
[
"within_residues"
][
"per_atom_loss_sum"
])
)
l_clash
=
l_clash
/
(
eps
+
num_atoms
)
if
average_clashes
:
num_clash
=
(
violations
[
"between_residues"
][
"clashes_per_atom_num_clash"
]
+
violations
[
"within_residues"
][
"per_atom_num_clash"
])
per_atom_clash
=
per_atom_clash
/
(
num_clash
+
eps
)
l_clash
=
torch
.
sum
(
per_atom_clash
)
/
(
eps
+
num_atoms
)
loss
=
(
loss
=
(
violations
[
"between_residues"
][
"bonds_c_n_loss_mean"
]
violations
[
"between_residues"
][
"bonds_c_n_loss_mean"
]
+
violations
[
"between_residues"
][
"angles_ca_c_n_loss_mean"
]
+
violations
[
"between_residues"
][
"angles_ca_c_n_loss_mean"
]
...
@@ -1491,7 +1578,7 @@ def experimentally_resolved_loss(
...
@@ -1491,7 +1578,7 @@ def experimentally_resolved_loss(
return
loss
return
loss
def
masked_msa_loss
(
logits
,
true_msa
,
bert_mask
,
eps
=
1e-8
,
**
kwargs
):
def
masked_msa_loss
(
logits
,
true_msa
,
bert_mask
,
num_classes
,
eps
=
1e-8
,
**
kwargs
):
"""
"""
Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
...
@@ -1503,7 +1590,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
...
@@ -1503,7 +1590,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
Masked MSA loss
Masked MSA loss
"""
"""
errors
=
softmax_cross_entropy
(
errors
=
softmax_cross_entropy
(
logits
,
torch
.
nn
.
functional
.
one_hot
(
true_msa
,
num_classes
=
23
)
logits
,
torch
.
nn
.
functional
.
one_hot
(
true_msa
,
num_classes
=
num_classes
)
)
)
# FP16-friendly averaging. Equivalent to:
# FP16-friendly averaging. Equivalent to:
...
@@ -1524,6 +1611,64 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
...
@@ -1524,6 +1611,64 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
return
loss
return
loss
def
chain_center_of_mass_loss
(
all_atom_pred_pos
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
asym_id
:
torch
.
Tensor
,
clamp_distance
:
float
=
-
4.0
,
weight
:
float
=
0.05
,
eps
:
float
=
1e-10
)
->
torch
.
Tensor
:
"""
Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
Args:
all_atom_pred_pos:
[*, N_pts, 37, 3] All-atom predicted atom positions
all_atom_positions:
[*, N_pts, 37, 3] Ground truth all-atom positions
all_atom_mask:
[*, N_pts, 37] All-atom positions mask
asym_id:
[*, N_pts] Chain asym IDs
clamp_distance:
Cutoff above which distance errors are disregarded
weight:
Weight for loss
eps:
Small value used to regularize denominators
Returns:
[*] loss tensor
"""
ca_pos
=
residue_constants
.
atom_order
[
"CA"
]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
chains
,
_
=
asym_id
.
unique
(
return_counts
=
True
)
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
,
num_classes
=
chains
.
shape
[
0
]).
to
(
dtype
=
all_atom_mask
.
dtype
)
one_hot
=
one_hot
*
all_atom_mask
chain_pos_mask
=
one_hot
.
transpose
(
-
2
,
-
1
)
chain_exists
=
torch
.
any
(
chain_pos_mask
,
dim
=-
1
).
float
()
def
get_chain_center_of_mass
(
pos
):
center_sum
=
(
chain_pos_mask
[...,
None
]
*
pos
[...,
None
,
:,
:]).
sum
(
dim
=-
2
)
centers
=
center_sum
/
(
torch
.
sum
(
chain_pos_mask
,
dim
=-
1
,
keepdim
=
True
)
+
eps
)
return
Vec3Array
.
from_array
(
centers
)
pred_centers
=
get_chain_center_of_mass
(
all_atom_pred_pos
)
# [B, NC, 3]
true_centers
=
get_chain_center_of_mass
(
all_atom_positions
)
# [B, NC, 3]
pred_dists
=
euclidean_distance
(
pred_centers
[...,
None
,
:],
pred_centers
[...,
:,
None
],
epsilon
=
eps
)
true_dists
=
euclidean_distance
(
true_centers
[...,
None
,
:],
true_centers
[...,
:,
None
],
epsilon
=
eps
)
losses
=
torch
.
clamp
((
weight
*
(
pred_dists
-
true_dists
-
clamp_distance
)),
max
=
0
)
**
2
loss_mask
=
chain_exists
[...,
:,
None
]
*
chain_exists
[...,
None
,
:]
loss
=
masked_mean
(
loss_mask
,
losses
,
dim
=
(
-
1
,
-
2
))
return
loss
class
AlphaFoldLoss
(
nn
.
Module
):
class
AlphaFoldLoss
(
nn
.
Module
):
"""Aggregation of the various losses described in the supplement"""
"""Aggregation of the various losses described in the supplement"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -1576,7 +1721,7 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1576,7 +1721,7 @@ class AlphaFoldLoss(nn.Module):
),
),
"violation"
:
lambda
:
violation_loss
(
"violation"
:
lambda
:
violation_loss
(
out
[
"violation"
],
out
[
"violation"
],
**
batch
,
**
{
**
batch
,
**
self
.
config
.
violation
},
),
),
}
}
...
@@ -1586,6 +1731,12 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1586,6 +1731,12 @@ class AlphaFoldLoss(nn.Module):
**
{
**
batch
,
**
out
,
**
self
.
config
.
tm
},
**
{
**
batch
,
**
out
,
**
self
.
config
.
tm
},
)
)
if
(
self
.
config
.
chain_center_of_mass
.
enabled
):
loss_fns
[
"chain_center_of_mass"
]
=
lambda
:
chain_center_of_mass_loss
(
all_atom_pred_pos
=
out
[
"final_atom_positions"
],
**
{
**
batch
,
**
self
.
config
.
chain_center_of_mass
},
)
cum_loss
=
0.
cum_loss
=
0.
losses
=
{}
losses
=
{}
for
loss_name
,
loss_fn
in
loss_fns
.
items
():
for
loss_name
,
loss_fn
in
loss_fns
.
items
():
...
...
openfold/utils/rigid_utils.py
View file @
56d5e39c
...
@@ -978,6 +978,16 @@ class Rigid:
...
@@ -978,6 +978,16 @@ class Rigid:
"""
"""
return
self
.
_trans
.
device
return
self
.
_trans
.
device
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
"""
Returns the dtype of the Rigid tensors.
Returns:
The dtype of the Rigid tensors
"""
return
self
.
_rots
.
dtype
def
get_rots
(
self
)
->
Rotation
:
def
get_rots
(
self
)
->
Rotation
:
"""
"""
Getter for the rotation.
Getter for the rotation.
...
...
openfold/utils/script_utils.py
View file @
56d5e39c
...
@@ -219,7 +219,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult
...
@@ -219,7 +219,7 @@ def prep_output(out, batch, feature_dict, feature_processor, config_preset, mult
features
=
batch
,
features
=
batch
,
result
=
out
,
result
=
out
,
b_factors
=
plddt_b_factors
,
b_factors
=
plddt_b_factors
,
chain_index
=
chain_index
,
remove_leading_feature_dimension
=
not
"multimer"
in
config_preset
,
remark
=
remark
,
remark
=
remark
,
parents
=
template_domain_names
,
parents
=
template_domain_names
,
parents_chain_index
=
template_chain_index
,
parents_chain_index
=
template_chain_index
,
...
...
run_pretrained_openfold.py
View file @
56d5e39c
...
@@ -44,6 +44,9 @@ if(
...
@@ -44,6 +44,9 @@ if(
torch
.
set_grad_enabled
(
False
)
torch
.
set_grad_enabled
(
False
)
from
openfold.config
import
model_config
from
openfold.config
import
model_config
from
openfold.data.tools
import
hhsearch
,
hmmsearch
from
openfold.model.model
import
AlphaFold
from
openfold.model.torchscript
import
script_preset_
from
openfold.data
import
templates
,
feature_pipeline
,
data_pipeline
from
openfold.data
import
templates
,
feature_pipeline
,
data_pipeline
from
openfold.np
import
residue_constants
,
protein
from
openfold.np
import
residue_constants
,
protein
import
openfold.np.relax.relax
as
relax
import
openfold.np.relax.relax
as
relax
...
@@ -61,13 +64,19 @@ from scripts.utils import add_data_args
...
@@ -61,13 +64,19 @@ from scripts.utils import add_data_args
TRACING_INTERVAL
=
50
TRACING_INTERVAL
=
50
def
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
):
def
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
,
is_multimer
):
for
tag
,
seq
in
zip
(
tags
,
seqs
):
for
tag
,
seq
in
zip
(
tags
,
seqs
):
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
tmp_fasta_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"tmp_
{
os
.
getpid
()
}
.fasta"
)
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
fp
.
write
(
f
">
{
tag
}
\n
{
seq
}
"
)
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
tag
)
if
is_multimer
:
local_alignment_dir
=
alignment_dir
else
:
local_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
os
.
path
.
join
(
alignment_dir
,
tag
),
)
if
(
args
.
use_precomputed_alignments
is
None
and
not
os
.
path
.
isdir
(
local_alignment_dir
)):
if
(
args
.
use_precomputed_alignments
is
None
and
not
os
.
path
.
isdir
(
local_alignment_dir
)):
logger
.
info
(
f
"Generating alignments for
{
tag
}
..."
)
logger
.
info
(
f
"Generating alignments for
{
tag
}
..."
)
...
@@ -76,12 +85,11 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
...
@@ -76,12 +85,11 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhsearch_binary_path
=
args
.
hhsearch_binary_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
uniref30_database_path
=
args
.
uniref30_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
pdb70_database_path
=
args
.
pdb70_database_path
,
no_cpus
=
args
.
cpus
,
no_cpus
=
args
.
cpus
,
)
)
alignment_runner
.
run
(
alignment_runner
.
run
(
...
@@ -118,6 +126,14 @@ def generate_feature_dict(
...
@@ -118,6 +126,14 @@ def generate_feature_dict(
feature_dict
=
data_processor
.
process_fasta
(
feature_dict
=
data_processor
.
process_fasta
(
fasta_path
=
tmp_fasta_path
,
alignment_dir
=
local_alignment_dir
fasta_path
=
tmp_fasta_path
,
alignment_dir
=
local_alignment_dir
)
)
elif
"multimer"
in
args
.
config_preset
:
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
'
\n
'
.
join
([
f
">
{
tag
}
\n
{
seq
}
"
for
tag
,
seq
in
zip
(
tags
,
seqs
)])
)
feature_dict
=
data_processor
.
process_fasta
(
fasta_path
=
tmp_fasta_path
,
alignment_dir
=
alignment_dir
,
)
else
:
else
:
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
with
open
(
tmp_fasta_path
,
"w"
)
as
fp
:
fp
.
write
(
fp
.
write
(
...
@@ -137,7 +153,7 @@ def list_files_with_extensions(dir, extensions):
...
@@ -137,7 +153,7 @@ def list_files_with_extensions(dir, extensions):
def
main
(
args
):
def
main
(
args
):
# Create the output directory
# Create the output directory
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
config
=
model_config
(
args
.
config_preset
,
long_sequence_inference
=
args
.
long_sequence_inference
)
config
=
model_config
(
args
.
config_preset
,
long_sequence_inference
=
args
.
long_sequence_inference
)
...
@@ -148,19 +164,70 @@ def main(args):
...
@@ -148,19 +164,70 @@ def main(args):
"Tracing requires that fixed_size mode be enabled in the config"
"Tracing requires that fixed_size mode be enabled in the config"
)
)
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
is_multimer
=
"multimer"
in
args
.
config_preset
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
if
(
is_multimer
):
max_hits
=
config
.
data
.
predict
.
max_templates
,
if
(
not
args
.
use_precomputed_alignments
):
kalign_binary_path
=
args
.
kalign_binary_path
,
template_searcher
=
hmmsearch
.
Hmmsearch
(
release_dates_path
=
args
.
release_dates_path
,
binary_path
=
args
.
hmmsearch_binary_path
,
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
hmmbuild_binary_path
=
args
.
hmmbuild_binary_path
,
)
database_path
=
args
.
pdb_seqres_database_path
,
)
else
:
template_searcher
=
None
template_featurizer
=
templates
.
HmmsearchHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_hits
=
config
.
data
.
predict
.
max_templates
,
kalign_binary_path
=
args
.
kalign_binary_path
,
release_dates_path
=
args
.
release_dates_path
,
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
)
else
:
if
(
not
args
.
use_precomputed_alignments
):
template_searcher
=
hhsearch
.
HHSearch
(
binary_path
=
args
.
hhsearch_binary_path
,
databases
=
[
args
.
pdb70_database_path
],
)
else
:
template_searcher
=
None
template_featurizer
=
templates
.
HhsearchHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_hits
=
config
.
data
.
predict
.
max_templates
,
kalign_binary_path
=
args
.
kalign_binary_path
,
release_dates_path
=
args
.
release_dates_path
,
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
)
if
(
not
args
.
use_precomputed_alignments
):
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
uniref30_database_path
=
args
.
uniref30_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
uniprot_database_path
=
args
.
uniprot_database_path
,
template_searcher
=
template_searcher
,
use_small_bfd
=
(
args
.
bfd_database_path
is
None
),
no_cpus
=
args
.
cpus
,
)
else
:
alignment_runner
=
None
data_processor
=
data_pipeline
.
DataPipeline
(
data_processor
=
data_pipeline
.
DataPipeline
(
template_featurizer
=
template_featurizer
,
template_featurizer
=
template_featurizer
,
)
)
if
(
is_multimer
):
data_processor
=
data_pipeline
.
DataPipelineMultimer
(
monomer_data_pipeline
=
data_processor
,
)
output_dir_base
=
args
.
output_dir
output_dir_base
=
args
.
output_dir
random_seed
=
args
.
data_random_seed
random_seed
=
args
.
data_random_seed
if
random_seed
is
None
:
if
random_seed
is
None
:
...
@@ -181,10 +248,19 @@ def main(args):
...
@@ -181,10 +248,19 @@ def main(args):
seq_list
=
[]
seq_list
=
[]
for
fasta_file
in
list_files_with_extensions
(
args
.
fasta_dir
,
(
".fasta"
,
".fa"
)):
for
fasta_file
in
list_files_with_extensions
(
args
.
fasta_dir
,
(
".fasta"
,
".fa"
)):
# Gather input sequences
# Gather input sequences
with
open
(
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
),
"r"
)
as
fp
:
fasta_path
=
os
.
path
.
join
(
args
.
fasta_dir
,
fasta_file
)
with
open
(
fasta_path
,
"r"
)
as
fp
:
data
=
fp
.
read
()
data
=
fp
.
read
()
tags
,
seqs
=
parse_fasta
(
data
)
tags
,
seqs
=
parse_fasta
(
data
)
if
((
not
is_multimer
)
and
len
(
tags
)
!=
1
):
print
(
f
"
{
fasta_path
}
contains more than one sequence but "
f
"multimer mode is not enabled. Skipping..."
)
continue
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag
=
'-'
.
join
(
tags
)
tag
=
'-'
.
join
(
tags
)
...
@@ -208,7 +284,7 @@ def main(args):
...
@@ -208,7 +284,7 @@ def main(args):
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
output_name
=
f
'
{
output_name
}
_
{
args
.
output_postfix
}
'
# Does nothing if the alignments have already been computed
# Does nothing if the alignments have already been computed
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
)
precompute_alignments
(
tags
,
seqs
,
alignment_dir
,
args
,
is_multimer
)
feature_dict
=
feature_dicts
.
get
(
tag
,
None
)
feature_dict
=
feature_dicts
.
get
(
tag
,
None
)
if
(
feature_dict
is
None
):
if
(
feature_dict
is
None
):
...
@@ -230,7 +306,7 @@ def main(args):
...
@@ -230,7 +306,7 @@ def main(args):
feature_dicts
[
tag
]
=
feature_dict
feature_dicts
[
tag
]
=
feature_dict
processed_feature_dict
=
feature_processor
.
process_features
(
processed_feature_dict
=
feature_processor
.
process_features
(
feature_dict
,
mode
=
'predict'
,
feature_dict
,
mode
=
'predict'
,
is_multimer
=
is_multimer
)
)
processed_feature_dict
=
{
processed_feature_dict
=
{
...
@@ -238,8 +314,8 @@ def main(args):
...
@@ -238,8 +314,8 @@ def main(args):
for
k
,
v
in
processed_feature_dict
.
items
()
for
k
,
v
in
processed_feature_dict
.
items
()
}
}
if
(
args
.
trace_model
):
if
(
args
.
trace_model
):
if
(
rounded_seqlen
>
cur_tracing_interval
):
if
(
rounded_seqlen
>
cur_tracing_interval
):
logger
.
info
(
logger
.
info
(
f
"Tracing model at
{
rounded_seqlen
}
residues..."
f
"Tracing model at
{
rounded_seqlen
}
residues..."
)
)
...
...
scripts/__init__.py
0 → 100644
View file @
56d5e39c
scripts/data_dir_to_fasta.py
View file @
56d5e39c
import
argparse
import
argparse
import
logging
import
logging
import
os
import
os
import
string
from
collections
import
defaultdict
from
openfold.data
import
mmcif_parsing
from
openfold.data
import
mmcif_parsing
from
openfold.np
import
protein
,
residue_constants
from
openfold.np
import
protein
,
residue_constants
...
@@ -22,7 +23,7 @@ def main(args):
...
@@ -22,7 +23,7 @@ def main(args):
if
(
mmcif
.
mmcif_object
is
None
):
if
(
mmcif
.
mmcif_object
is
None
):
logging
.
warning
(
f
'Failed to parse
{
fname
}
...'
)
logging
.
warning
(
f
'Failed to parse
{
fname
}
...'
)
if
(
args
.
raise_errors
):
if
(
args
.
raise_errors
):
raise
list
(
mmcif
.
errors
.
values
())[
0
]
raise
Exception
(
list
(
mmcif
.
errors
.
values
())[
0
]
)
else
:
else
:
continue
continue
...
@@ -31,6 +32,25 @@ def main(args):
...
@@ -31,6 +32,25 @@ def main(args):
chain_id
=
'_'
.
join
([
basename
,
chain
])
chain_id
=
'_'
.
join
([
basename
,
chain
])
fasta
.
append
(
f
">
{
chain_id
}
"
)
fasta
.
append
(
f
">
{
chain_id
}
"
)
fasta
.
append
(
seq
)
fasta
.
append
(
seq
)
elif
(
ext
==
".pdb"
):
with
open
(
fpath
,
'r'
)
as
fp
:
pdb_str
=
fp
.
read
()
protein_object
=
protein
.
from_pdb_string
(
pdb_str
)
aatype
=
protein_object
.
aatype
chain_index
=
protein_object
.
chain_index
last_chain_index
=
chain_index
[
0
]
chain_dict
=
defaultdict
(
list
)
for
i
in
range
(
aatype
.
shape
[
0
]):
chain_dict
[
chain_index
[
i
]].
append
(
residue_constants
.
restypes_with_x
[
aatype
[
i
]])
chain_tags
=
string
.
ascii_uppercase
for
chain
,
seq
in
chain_dict
.
items
():
chain_id
=
'_'
.
join
([
basename
,
chain_tags
[
chain
]])
fasta
.
append
(
f
">
{
chain_id
}
"
)
fasta
.
append
(
''
.
join
(
seq
))
elif
(
ext
==
".core"
):
elif
(
ext
==
".core"
):
with
open
(
fpath
,
'r'
)
as
fp
:
with
open
(
fpath
,
'r'
)
as
fp
:
core_str
=
fp
.
read
()
core_str
=
fp
.
read
()
...
...
scripts/deepspeed_inference_test.py
0 → 100644
View file @
56d5e39c
import
copy
import
os
import
torch
import
deepspeed
# Distributed-launch environment: the rank of this process on the local node
# and the total process count. Both fall back to single-process defaults when
# the launcher did not export them.
local_rank = int(os.environ.get('LOCAL_RANK', '0'))
world_size = int(os.environ.get('WORLD_SIZE', '1'))
class Model(torch.nn.Module):
    """Deep stack of identical square linear layers used to stress-test
    DeepSpeed tensor-parallel inference.

    The defaults (4000 layers of 500x500) reproduce the original
    hard-coded configuration; both sizes are parameters so the model can
    also be built cheaply (e.g. in unit tests).
    """

    def __init__(self, num_layers: int = 4000, dim: int = 500):
        """Build ``num_layers`` Linear(dim, dim) layers.

        Args:
            num_layers: number of stacked linear layers.
            dim: input/output feature size of every layer.
        """
        super().__init__()
        # ModuleList (rather than Sequential) keeps each layer individually
        # addressable, which the commented-out per-layer device print in
        # forward() relies on.
        self.ml = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.ml.append(torch.nn.Linear(dim, dim))

    def forward(self, batch):
        """Apply every linear layer in order and return the result."""
        for i, l in enumerate(self.ml):
            # print(f"{i}: {l.weight.device}")
            batch = l(batch)
        return batch
class DummyDataset(torch.utils.data.Dataset):
    """Map-style dataset serving 1000 deep copies of one random 500x500 tensor."""

    def __init__(self):
        # One shared tensor; every item is an independent copy of it.
        self.batch = torch.rand(500, 500)

    def __getitem__(self, idx):
        # Deep-copy so callers never alias the shared tensor.
        return copy.deepcopy(self.batch)

    def __len__(self):
        # Fixed nominal length; the index is ignored by __getitem__.
        return 1000
# Pull a single example batch from the dummy dataset and stage it on this
# rank's GPU.
dd = DummyDataset()
dl = torch.utils.data.DataLoader(dd)
example = next(iter(dl)).to(f"cuda:{local_rank}")

# Place the model on the local GPU, then let DeepSpeed shard it for
# tensor-parallel inference across `world_size` ranks.
model = Model()
model = model.to(f"cuda:{local_rank}")
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    checkpoint=None,
    replace_method=None,
    #replace_method="auto"
)

out = model(example)
#if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
#    print(out)
scripts/download_alphafold_dbs.sh
View file @
56d5e39c
...
@@ -56,10 +56,16 @@ bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
...
@@ -56,10 +56,16 @@ bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
echo
"Downloading PDB mmCIF files..."
echo
"Downloading PDB mmCIF files..."
bash
"
${
SCRIPT_DIR
}
/download_pdb_mmcif.sh"
"
${
DOWNLOAD_DIR
}
"
bash
"
${
SCRIPT_DIR
}
/download_pdb_mmcif.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"Downloading Uni
clust
30..."
echo
"Downloading Uni
ref
30..."
bash
"
${
SCRIPT_DIR
}
/download_uni
clust
30.sh"
"
${
DOWNLOAD_DIR
}
"
bash
"
${
SCRIPT_DIR
}
/download_uni
ref
30.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"Downloading Uniref90..."
echo
"Downloading Uniref90..."
bash
"
${
SCRIPT_DIR
}
/download_uniref90.sh"
"
${
DOWNLOAD_DIR
}
"
bash
"
${
SCRIPT_DIR
}
/download_uniref90.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"Downloading UniProt..."
bash
"
${
SCRIPT_DIR
}
/download_uniprot.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"Downloading PDB SeqRes..."
bash
"
${
SCRIPT_DIR
}
/download_pdb_seqres.sh"
"
${
DOWNLOAD_DIR
}
"
echo
"All data downloaded."
echo
"All data downloaded."
scripts/download_alphafold_params.sh
View file @
56d5e39c
...
@@ -31,7 +31,7 @@ fi
...
@@ -31,7 +31,7 @@ fi
DOWNLOAD_DIR
=
"
$1
"
DOWNLOAD_DIR
=
"
$1
"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
/params"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
/params"
SOURCE_URL
=
"https://storage.googleapis.com/alphafold/alphafold_params_2022-
01-19
.tar"
SOURCE_URL
=
"https://storage.googleapis.com/alphafold/alphafold_params_2022-
12-06
.tar"
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
mkdir
--parents
"
${
ROOT_DIR
}
"
mkdir
--parents
"
${
ROOT_DIR
}
"
...
...
scripts/download_mgnify.sh
View file @
56d5e39c
...
@@ -32,8 +32,8 @@ fi
...
@@ -32,8 +32,8 @@ fi
DOWNLOAD_DIR
=
"
$1
"
DOWNLOAD_DIR
=
"
$1
"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
/mgnify"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
/mgnify"
# Mirror of:
# Mirror of:
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/20
18_12
/mgy_clusters.fa.gz
# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/20
22_05
/mgy_clusters.fa.gz
SOURCE_URL
=
"https://storage.googleapis.com/alphafold-databases/
casp14_versions
/mgy_clusters_20
18_12
.fa.gz"
SOURCE_URL
=
"https://storage.googleapis.com/alphafold-databases/
v2.3
/mgy_clusters_20
22_05
.fa.gz"
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
mkdir
--parents
"
${
ROOT_DIR
}
"
mkdir
--parents
"
${
ROOT_DIR
}
"
...
...
scripts/download_pdb_seqres.sh
0 → 100755
View file @
56d5e39c
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the PDB SeqRes database for AlphaFold.
#
# Usage: bash download_pdb_seqres.sh /path/to/download/directory
set -e

# Require the target download directory as the single positional argument.
if [[ $# -eq 0 ]]; then
    echo "Error: download directory must be provided as an input argument."
    exit 1
fi

# aria2c is the downloader used below; fail early if it is missing.
if ! command -v aria2c &> /dev/null ; then
    echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
    exit 1
fi

DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb_seqres"
SOURCE_URL="ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt"
BASENAME=$(basename "${SOURCE_URL}")

mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"

# Keep only protein sequences: each FASTA header line matching
# 'mol:protein' plus the one sequence line that follows it
# (--after-context=1), with no "--" separators between matches.
grep --after-context=1 --no-group-separator '>.* mol:protein' \
    "${ROOT_DIR}/pdb_seqres.txt" > "${ROOT_DIR}/pdb_seqres_filtered.txt"
# Replace the raw download with the protein-only version in place.
mv "${ROOT_DIR}/pdb_seqres_filtered.txt" "${ROOT_DIR}/pdb_seqres.txt"
scripts/download_uniprot.sh
0 → 100755
View file @
56d5e39c
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads, unzips and merges the SwissProt and TrEMBL databases for
# AlphaFold-Multimer.
#
# Usage: bash download_uniprot.sh /path/to/download/directory
set -e

# Require the target download directory as the single positional argument.
if [[ $# -eq 0 ]]; then
    echo "Error: download directory must be provided as an input argument."
    exit 1
fi

# aria2c is the downloader used below; fail early if it is missing.
if ! command -v aria2c &> /dev/null ; then
    echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
    exit 1
fi

DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/uniprot"

# TrEMBL: the large, automatically-annotated half of UniProt.
TREMBL_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz"
TREMBL_BASENAME=$(basename "${TREMBL_SOURCE_URL}")
# Name of the archive after gunzip strips the .gz suffix.
TREMBL_UNZIPPED_BASENAME="${TREMBL_BASENAME%.gz}"

# SwissProt: the smaller, manually-reviewed half of UniProt.
SPROT_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz"
SPROT_BASENAME=$(basename "${SPROT_SOURCE_URL}")
SPROT_UNZIPPED_BASENAME="${SPROT_BASENAME%.gz}"

mkdir --parents "${ROOT_DIR}"
aria2c "${TREMBL_SOURCE_URL}" --dir="${ROOT_DIR}"
aria2c "${SPROT_SOURCE_URL}" --dir="${ROOT_DIR}"

pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${TREMBL_BASENAME}"
gunzip "${ROOT_DIR}/${SPROT_BASENAME}"

# Concatenate TrEMBL and SwissProt, rename to uniprot and clean up.
cat "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}" >> "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}"
mv "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}" "${ROOT_DIR}/uniprot.fasta"
rm "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}"
popd
scripts/download_uniref30.sh
View file @
56d5e39c
...
@@ -30,10 +30,15 @@ if ! command -v aria2c &> /dev/null ; then
...
@@ -30,10 +30,15 @@ if ! command -v aria2c &> /dev/null ; then
fi
fi
DOWNLOAD_DIR
=
"
$1
"
DOWNLOAD_DIR
=
"
$1
"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
"
ROOT_DIR
=
"
${
DOWNLOAD_DIR
}
/uniref30"
SOURCE_URL
=
"http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz"
# Mirror of:
# https://wwwuser.gwdg.de/~compbiol/uniclust/2021_03/UniRef30_2021_03.tar.gz
SOURCE_URL
=
"https://storage.googleapis.com/alphafold-databases/v2.3/UniRef30_2021_03.tar.gz"
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
BASENAME
=
$(
basename
"
${
SOURCE_URL
}
"
)
mkdir
--parents
"
${
ROOT_DIR
}
"
mkdir
--parents
"
${
ROOT_DIR
}
"
aria2c
"
${
SOURCE_URL
}
"
--dir
=
"
${
ROOT_DIR
}
"
-x
4
--check-certificate
=
false
aria2c
"
${
SOURCE_URL
}
"
--dir
=
"
${
ROOT_DIR
}
"
-x
4
--check-certificate
=
false
gunzip
"
${
ROOT_DIR
}
/
${
BASENAME
}
"
tar
--extract
--verbose
--file
=
"
${
ROOT_DIR
}
/
${
BASENAME
}
"
\
--directory
=
"
${
ROOT_DIR
}
"
rm
"
${
ROOT_DIR
}
/
${
BASENAME
}
"
scripts/flatten_roda.sh
View file @
56d5e39c
...
@@ -26,7 +26,12 @@ mkdir -p "${ALIGNMENT_DIR}"
...
@@ -26,7 +26,12 @@ mkdir -p "${ALIGNMENT_DIR}"
for
chain_dir
in
$(
ls
"
${
RODA_DIR
}
"
)
;
do
for
chain_dir
in
$(
ls
"
${
RODA_DIR
}
"
)
;
do
CHAIN_DIR_PATH
=
"
${
RODA_DIR
}
/
${
chain_dir
}
"
CHAIN_DIR_PATH
=
"
${
RODA_DIR
}
/
${
chain_dir
}
"
for
subdir
in
$(
ls
"
${
CHAIN_DIR_PATH
}
"
)
;
do
for
subdir
in
$(
ls
"
${
CHAIN_DIR_PATH
}
"
)
;
do
if
[[
$subdir
=
"pdb"
]]
||
[[
$subdir
=
"cif"
]]
;
then
if
[[
!
-d
"
$subdir
"
]]
;
then
echo
"
$subdir
is not directory"
continue
elif
[[
-z
$(
ls
"
${
subdir
}
"
)
]]
;
then
continue
elif
[[
$subdir
=
"pdb"
]]
||
[[
$subdir
=
"cif"
]]
;
then
mv
"
${
CHAIN_DIR_PATH
}
/
${
subdir
}
"
/
*
"
${
DATA_DIR
}
"
mv
"
${
CHAIN_DIR_PATH
}
/
${
subdir
}
"
/
*
"
${
DATA_DIR
}
"
else
else
CHAIN_ALIGNMENT_DIR
=
"
${
ALIGNMENT_DIR
}
/
${
chain_dir
}
"
CHAIN_ALIGNMENT_DIR
=
"
${
ALIGNMENT_DIR
}
/
${
chain_dir
}
"
...
...
scripts/generate_alphafold_feature_dict.py
View file @
56d5e39c
...
@@ -2,35 +2,62 @@ import argparse
...
@@ -2,35 +2,62 @@ import argparse
import
os
import
os
import
pickle
import
pickle
from
alphafold.data
import
pipeline
,
templates
from
alphafold.data
import
pipeline
,
pipeline_multimer
,
templates
from
alphafold.data.tools
import
hmmsearch
,
hhsearch
from
scripts.utils
import
add_data_args
from
scripts.utils
import
add_data_args
def
main
(
args
):
def
main
(
args
):
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
if
(
args
.
multimer
):
mmcif_dir
=
args
.
mmcif_dir
,
template_searcher
=
hmmsearch
.
Hmmsearch
(
max_template_date
=
args
.
max_template_date
,
binary_path
=
args
.
hmmsearch_binary_path
,
max_hits
=
20
,
hmmbuild_binary_path
=
args
.
hmmbuild_binary_path
,
kalign_binary_path
=
args
.
kalign_binary_path
,
database_path
=
args
.
pdb_seqres_database_path
,
release_dates_path
=
None
,
)
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
,
)
template_featurizer
=
templates
.
HmmsearchHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_hits
=
20
,
kalign_binary_path
=
args
.
kalign_binary_path
,
release_dates_path
=
args
.
release_dates_path
,
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
)
else
:
template_searcher
=
hhsearch
.
HHSearch
(
binary_path
=
args
.
hhsearch_binary_path
,
databases
=
[
args
.
pdb70_database_path
],
)
template_featurizer
=
templates
.
HhsearchHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_hits
=
20
,
kalign_binary_path
=
args
.
kalign_binary_path
,
release_dates_path
=
None
,
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
)
data_pipeline
=
pipeline
.
DataPipeline
(
data_pipeline
=
pipeline
.
DataPipeline
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhsearch_binary_path
=
args
.
hhsearch_binary_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
pdb70_database_path
=
args
.
pdb70_database_path
,
small_bfd_database_path
=
None
,
small_bfd_database_path
=
None
,
template_featurizer
=
template_featurizer
,
template_featurizer
=
template_featurizer
,
template_searcher
=
template_searcher
,
use_small_bfd
=
False
,
use_small_bfd
=
False
,
)
)
if
(
args
.
multimer
):
data_pipeline
=
pipeline_multimer
.
DataPipeline
(
monomer_data_pipeline
=
data_pipeline
,
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
uniprot_database_path
=
args
.
uniprot_database_path
)
feature_dict
=
data_pipeline
.
process
(
feature_dict
=
data_pipeline
.
process
(
input_fasta_path
=
args
.
fasta_path
,
input_fasta_path
=
args
.
fasta_path
,
msa_output_dir
=
args
.
output_dir
,
msa_output_dir
=
args
.
output_dir
,
...
@@ -44,6 +71,7 @@ if __name__ == "__main__":
...
@@ -44,6 +71,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"fasta_path"
,
type
=
str
)
parser
.
add_argument
(
"fasta_path"
,
type
=
str
)
parser
.
add_argument
(
"mmcif_dir"
,
type
=
str
)
parser
.
add_argument
(
"mmcif_dir"
,
type
=
str
)
parser
.
add_argument
(
"output_dir"
,
type
=
str
)
parser
.
add_argument
(
"output_dir"
,
type
=
str
)
parser
.
add_argument
(
"--multimer"
,
action
=
'store_true'
)
add_data_args
(
parser
)
add_data_args
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
scripts/generate_chain_data_cache.py
View file @
56d5e39c
...
@@ -4,10 +4,11 @@ import json
...
@@ -4,10 +4,11 @@ import json
import
logging
import
logging
from
multiprocessing
import
Pool
from
multiprocessing
import
Pool
import
os
import
os
import
string
import
sys
import
sys
sys
.
path
.
append
(
"."
)
# an innocent hack to get this to run from the top level
sys
.
path
.
append
(
"."
)
# an innocent hack to get this to run from the top level
from
collections
import
defaultdict
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
openfold.data.mmcif_parsing
import
parse
from
openfold.data.mmcif_parsing
import
parse
...
@@ -49,20 +50,27 @@ def parse_file(
...
@@ -49,20 +50,27 @@ def parse_file(
pdb_string
=
fp
.
read
()
pdb_string
=
fp
.
read
()
protein_object
=
protein
.
from_pdb_string
(
pdb_string
,
None
)
protein_object
=
protein
.
from_pdb_string
(
pdb_string
,
None
)
aatype
=
protein_object
.
aatype
chain_index
=
protein_object
.
chain_index
chain_dict
=
{}
chain_dict
=
defaultdict
(
list
)
chain_dict
[
"seq"
]
=
residue_constants
.
aatype_to_str_sequence
(
for
i
in
range
(
aatype
.
shape
[
0
]):
protein_object
.
aatype
,
chain_dict
[
chain_index
[
i
]].
append
(
residue_constants
.
restypes_with_x
[
aatype
[
i
]])
)
chain_dict
[
"resolution"
]
=
0.
if
(
chain_cluster_size_dict
is
not
None
):
cluster_size
=
chain_cluster_size_dict
.
get
(
full_name
.
upper
(),
-
1
)
chain_dict
[
"cluster_size"
]
=
cluster_size
out
=
{
file_id
:
chain_dict
}
out
=
{}
chain_tags
=
string
.
ascii_uppercase
for
chain
,
seq
in
chain_dict
.
items
():
full_name
=
"_"
.
join
([
file_id
,
chain_tags
[
chain
]])
out
[
full_name
]
=
{}
local_data
=
out
[
full_name
]
local_data
[
"resolution"
]
=
0.
local_data
[
"seq"
]
=
''
.
join
(
seq
)
if
(
chain_cluster_size_dict
is
not
None
):
cluster_size
=
chain_cluster_size_dict
.
get
(
full_name
.
upper
(),
-
1
)
local_data
[
"cluster_size"
]
=
cluster_size
return
out
return
out
...
...
scripts/precompute_alignments.py
View file @
56d5e39c
...
@@ -11,6 +11,7 @@ import tempfile
...
@@ -11,6 +11,7 @@ import tempfile
import
openfold.data.mmcif_parsing
as
mmcif_parsing
import
openfold.data.mmcif_parsing
as
mmcif_parsing
from
openfold.data.data_pipeline
import
AlignmentRunner
from
openfold.data.data_pipeline
import
AlignmentRunner
from
openfold.data.parsers
import
parse_fasta
from
openfold.data.parsers
import
parse_fasta
from
openfold.data.tools
import
hhsearch
,
hmmsearch
from
openfold.np
import
protein
,
residue_constants
from
openfold.np
import
protein
,
residue_constants
from
utils
import
add_data_args
from
utils
import
add_data_args
...
@@ -39,7 +40,8 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
...
@@ -39,7 +40,8 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
alignment_runner
.
run
(
alignment_runner
.
run
(
fasta_path
,
alignment_dir
fasta_path
,
alignment_dir
)
)
except
:
except
Exception
as
e
:
logging
.
warning
(
e
)
logging
.
warning
(
f
"Failed to run alignments for
{
first_name
}
. Skipping..."
)
logging
.
warning
(
f
"Failed to run alignments for
{
first_name
}
. Skipping..."
)
os
.
remove
(
fasta_path
)
os
.
remove
(
fasta_path
)
os
.
rmdir
(
alignment_dir
)
os
.
rmdir
(
alignment_dir
)
...
@@ -114,15 +116,30 @@ def parse_and_align(files, alignment_runner, args):
...
@@ -114,15 +116,30 @@ def parse_and_align(files, alignment_runner, args):
def
main
(
args
):
def
main
(
args
):
# Build the alignment tool runner
# Build the alignment tool runner
if
(
args
.
hmmsearch_binary_path
is
not
None
):
template_searcher
=
hmmsearch
.
Hmmsearch
(
binary_path
=
args
.
hmmsearch_binary_path
,
hmmbuild_binary_path
=
args
.
hmmbuild_binary_path
,
database_path
=
args
.
pdb_seqres_database_path
,
)
elif
(
args
.
hhsearch_binary_path
is
not
None
):
template_searcher
=
hhsearch
.
HHSearch
(
binary_path
=
args
.
hhsearch_binary_path
,
databases
=
[
args
.
pdb70_database_path
],
)
else
:
template_searcher
=
None
alignment_runner
=
AlignmentRunner
(
alignment_runner
=
AlignmentRunner
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhblits_binary_path
=
args
.
hhblits_binary_path
,
hhsearch_binary_path
=
args
.
hhsearch_binary_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
uniref90_database_path
=
args
.
uniref90_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
mgnify_database_path
=
args
.
mgnify_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
bfd_database_path
=
args
.
bfd_database_path
,
uniref30_database_path
=
args
.
uniref30_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
uniclust30_database_path
=
args
.
uniclust30_database_path
,
pdb70_database_path
=
args
.
pdb70_database_path
,
uniprot_database_path
=
args
.
uniprot_database_path
,
template_searcher
=
template_searcher
,
use_small_bfd
=
args
.
bfd_database_path
is
None
,
use_small_bfd
=
args
.
bfd_database_path
is
None
,
no_cpus
=
args
.
cpus_per_task
,
no_cpus
=
args
.
cpus_per_task
,
)
)
...
...
scripts/utils.py
View file @
56d5e39c
...
@@ -14,9 +14,18 @@ def add_data_args(parser: argparse.ArgumentParser):
...
@@ -14,9 +14,18 @@ def add_data_args(parser: argparse.ArgumentParser):
parser
.
add_argument
(
parser
.
add_argument
(
'--pdb70_database_path'
,
type
=
str
,
default
=
None
,
'--pdb70_database_path'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
'--pdb_seqres_database_path'
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
'--uniref30_database_path'
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
parser
.
add_argument
(
'--uniclust30_database_path'
,
type
=
str
,
default
=
None
,
'--uniclust30_database_path'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
'--uniprot_database_path'
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
parser
.
add_argument
(
'--bfd_database_path'
,
type
=
str
,
default
=
None
,
'--bfd_database_path'
,
type
=
str
,
default
=
None
,
)
)
...
@@ -29,6 +38,12 @@ def add_data_args(parser: argparse.ArgumentParser):
...
@@ -29,6 +38,12 @@ def add_data_args(parser: argparse.ArgumentParser):
parser
.
add_argument
(
parser
.
add_argument
(
'--hhsearch_binary_path'
,
type
=
str
,
default
=
'/usr/bin/hhsearch'
'--hhsearch_binary_path'
,
type
=
str
,
default
=
'/usr/bin/hhsearch'
)
)
parser
.
add_argument
(
'--hmmsearch_binary_path'
,
type
=
str
,
default
=
'/usr/bin/hmmsearch'
)
parser
.
add_argument
(
'--hmmbuild_binary_path'
,
type
=
str
,
default
=
'/usr/bin/hmmbuild'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--kalign_binary_path'
,
type
=
str
,
default
=
'/usr/bin/kalign'
'--kalign_binary_path'
,
type
=
str
,
default
=
'/usr/bin/kalign'
)
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment