Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
a3c2ae51
"vscode:/vscode.git/clone" did not exist on "04e50aba33109d9f60c609a0f100ebb413ffabad"
Commit
a3c2ae51
authored
Oct 10, 2021
by
Gustaf Ahdritz
Browse files
Add TM calculations, move testing wrappers
parent
9eda0b43
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
122 additions
and
83 deletions
+122
-83
openfold/features/data_transforms.py
openfold/features/data_transforms.py
+8
-0
openfold/model/heads.py
openfold/model/heads.py
+12
-2
openfold/utils/feats.py
openfold/utils/feats.py
+0
-77
openfold/utils/loss.py
openfold/utils/loss.py
+96
-1
tests/test_loss.py
tests/test_loss.py
+2
-1
tests/test_model.py
tests/test_model.py
+2
-1
tests/test_structure_module.py
tests/test_structure_module.py
+2
-1
No files found.
openfold/features/data_transforms.py
View file @
a3c2ae51
...
@@ -7,6 +7,7 @@ from operator import add
...
@@ -7,6 +7,7 @@ from operator import add
from
openfold.config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.np
import
residue_constants
from
openfold.np
import
residue_constants
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
MSA_FEATURE_NAMES
=
[
MSA_FEATURE_NAMES
=
[
'msa'
,
'deletion_matrix'
,
'msa_mask'
,
'msa_row_mask'
,
'bert_mask'
,
'true_msa'
'msa'
,
'deletion_matrix'
,
'msa_mask'
,
'msa_row_mask'
,
'bert_mask'
,
'true_msa'
...
@@ -535,3 +536,10 @@ def make_atom14_masks(protein):
...
@@ -535,3 +536,10 @@ def make_atom14_masks(protein):
protein
[
'atom37_atom_exists'
]
=
residx_atom37_mask
protein
[
'atom37_atom_exists'
]
=
residx_atom37_mask
return
protein
return
protein
def make_atom14_masks_np(batch):
    """Run make_atom14_masks on a batch holding NumPy arrays.

    tree_map is applied to `batch` with np.ndarray as the leaf type so that
    each array is turned into a torch tensor, make_atom14_masks is called on
    the converted batch, and tensor_tree_map converts the tensors of its
    result back with np.array.
    """
    def _as_tensor(array):
        return torch.tensor(array)

    def _as_array(tensor):
        return np.array(tensor)

    torch_batch = tree_map(_as_tensor, batch, np.ndarray)
    masks = make_atom14_masks(torch_batch)
    return tensor_tree_map(_as_array, masks)
openfold/model/heads.py
View file @
a3c2ae51
...
@@ -17,7 +17,11 @@ import torch
...
@@ -17,7 +17,11 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
from
openfold.model.primitives
import
Linear
from
openfold.utils.loss
import
compute_plddt
from
openfold.utils.loss
import
(
compute_plddt
,
compute_tm
,
compute_predicted_aligned_error
,
)
class
AuxiliaryHeads
(
nn
.
Module
):
class
AuxiliaryHeads
(
nn
.
Module
):
...
@@ -71,6 +75,12 @@ class AuxiliaryHeads(nn.Module):
...
@@ -71,6 +75,12 @@ class AuxiliaryHeads(nn.Module):
if
(
self
.
config
.
tm
.
enabled
):
if
(
self
.
config
.
tm
.
enabled
):
tm_logits
=
self
.
tm
(
outputs
[
"pair"
])
tm_logits
=
self
.
tm
(
outputs
[
"pair"
])
aux_out
[
"tm_logits"
]
=
tm_logits
aux_out
[
"tm_logits"
]
=
tm_logits
aux_out
[
"predicted_tm_score"
]
=
compute_tm
(
tm_logits
,
**
self
.
config
.
tm
)
aux_out
.
update
(
compute_predicted_aligned_error
(
tm_logits
,
**
self
.
config
.
tm
,
))
return
aux_out
return
aux_out
...
...
openfold/utils/feats.py
View file @
a3c2ae51
...
@@ -75,83 +75,6 @@ def get_chi_atom_indices():
...
@@ -75,83 +75,6 @@ def get_chi_atom_indices():
return
chi_atom_indices
return
chi_atom_indices
def compute_residx(batch):
    """Build atom14 <-> atom37 index maps and atom-existence masks.

    Per-residue-type lookup tables are assembled from the residue constants
    (`rc`) and then indexed with the residue types found in the batch.

    Args:
        batch: mapping providing
            "aatype": integer tensor of residue-type indices; it is used to
                index the per-restype tables, so it must be a valid index
                dtype (assumes values in [0, 20] -- TODO confirm against
                len(rc.restypes)).
            "seq_mask": tensor whose dtype (and, via new_tensor, device) is
                used for the floating-point masks.

    Returns:
        dict with the keys
            "residx_atom14_to_atom37": atom37 index of every atom14 slot,
            "atom14_atom_exists": mask over atom14 slots,
            "residx_atom37_to_atom14": atom14 index of every atom37 slot,
            "atom37_atom_exists": mask over atom37 slots.
    """
    out = {}
    aatype = batch["aatype"]
    seq_mask = batch["seq_mask"]
    float_type = seq_mask.dtype

    restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
    restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
    restype_atom14_mask = []

    for rt in rc.restypes:
        atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]

        # Empty atom14 slots have a falsy name and are mapped to index 0.
        restype_atom14_to_atom37.append(
            [(rc.atom_order[name] if name else 0) for name in atom_names]
        )

        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        # atom37 types the residue does not contain fall back to index 0.
        restype_atom37_to_atom14.append(
            [atom_name_to_idx14.get(name, 0) for name in rc.atom_types]
        )

        restype_atom14_mask.append(
            [(1. if name else 0.) for name in atom_names]
        )

    # Add dummy mapping for restype 'UNK'
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.] * 14)

    # new_tensor keeps the tables on the same device as the batch tensors.
    restype_atom14_to_atom37 = aatype.new_tensor(restype_atom14_to_atom37)
    restype_atom37_to_atom14 = aatype.new_tensor(restype_atom37_to_atom14)
    restype_atom14_mask = seq_mask.new_tensor(restype_atom14_mask)

    out["residx_atom14_to_atom37"] = restype_atom14_to_atom37[aatype]
    out["atom14_atom_exists"] = restype_atom14_mask[aatype]

    # create the gather indices for mapping back
    out["residx_atom37_to_atom14"] = restype_atom37_to_atom14[aatype]

    # create the corresponding mask; it must live on aatype's device because
    # it is indexed with aatype below (torch.zeros otherwise allocates on the
    # default device and the lookup fails for accelerator-resident batches).
    restype_atom37_mask = torch.zeros(
        [21, 37], dtype=float_type, device=aatype.device
    )
    for restype, restype_letter in enumerate(rc.restypes):
        restype_name = rc.restype_1to3[restype_letter]
        for atom_name in rc.residue_atoms[restype_name]:
            atom_type = rc.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    out["atom37_atom_exists"] = restype_atom37_mask[aatype]

    return out
def compute_residx_np(batch):
    """Run compute_residx on a batch holding NumPy arrays.

    tree_map is applied to `batch` with np.ndarray as the leaf type so that
    each array is turned into a torch tensor, compute_residx is called on the
    converted batch, and tensor_tree_map converts the tensors of its result
    back with np.array.
    """
    def _as_tensor(array):
        return torch.tensor(array)

    def _as_array(tensor):
        return np.array(tensor)

    torch_batch = tree_map(_as_tensor, batch, np.ndarray)
    residx = compute_residx(torch_batch)
    return tensor_tree_map(_as_array, residx)
def
atom14_to_atom37
(
atom14
,
batch
):
def
atom14_to_atom37
(
atom14
,
batch
):
atom37_data
=
batched_gather
(
atom37_data
=
batched_gather
(
atom14
,
atom14
,
...
...
openfold/utils/loss.py
View file @
a3c2ae51
...
@@ -18,7 +18,7 @@ import ml_collections
...
@@ -18,7 +18,7 @@ import ml_collections
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
,
Tuple
from
openfold.np
import
residue_constants
from
openfold.np
import
residue_constants
from
openfold.utils
import
feats
from
openfold.utils
import
feats
...
@@ -28,6 +28,7 @@ from openfold.utils.tensor_utils import (
...
@@ -28,6 +28,7 @@ from openfold.utils.tensor_utils import (
tensor_tree_map
,
tensor_tree_map
,
masked_mean
,
masked_mean
,
permute_final_dims
,
permute_final_dims
,
batched_gather
,
)
)
...
@@ -450,6 +451,100 @@ def distogram_loss(
...
@@ -450,6 +451,100 @@ def distogram_loss(
return
mean
return
mean
def _calculate_bin_centers(boundaries: torch.Tensor):
    """Return the centre of every bin delimited by `boundaries`.

    Args:
        boundaries: 1-D tensor of bin boundaries with at least two entries.
            The spacing is taken from the first two entries and applied to
            all bins, so the boundaries are expected to be evenly spaced.

    Returns:
        1-D tensor with one more element than `boundaries`: each boundary
        shifted up by half a step, followed by one extra centre a full step
        beyond the last one (the bin past the final boundary).
    """
    step = boundaries[1] - boundaries[0]
    bin_centers = boundaries + step / 2
    # torch.cat only accepts tensors, so the extra centre is appended as a
    # one-element tensor rather than a Python list.
    last_center = (bin_centers[-1] + step).unsqueeze(-1)
    bin_centers = torch.cat([bin_centers, last_center], dim=0)
    return bin_centers
def _calculate_expected_aligned_error(
    alignment_confidence_breaks: torch.Tensor,
    aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the expected aligned error from per-bin probabilities.

    Args:
        alignment_confidence_breaks: bin boundaries, forwarded to
            _calculate_bin_centers.
        aligned_distance_error_probs: probabilities whose last dimension
            runs over the bins.

    Returns:
        A pair: the probability-weighted sum of the bin centres over the
        last dimension, and the last (largest-index) bin centre.
    """
    centers = _calculate_bin_centers(alignment_confidence_breaks)
    expected_error = torch.sum(aligned_distance_error_probs * centers, dim=-1)
    max_error = centers[-1]
    return expected_error, max_error
def compute_predicted_aligned_error(
    logits: torch.Tensor,
    max_bin: int = 31,
    no_bins: int = 64,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
        logits: [*, num_res, num_res, num_bins] the logits output from
            PredictedAlignedErrorHead.
        max_bin: Maximum bin value
        no_bins: Number of bins
        **kwargs: Ignored. Accepted so that a whole config section can be
            splatted into the call even when it carries unrelated keys
            (for example an ``enabled`` flag).
    Returns:
        aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
            aligned error probabilities over bins for each residue pair.
        predicted_aligned_error: [*, num_res, num_res] the expected aligned
            distance error for each pair of residues.
        max_predicted_aligned_error: scalar tensor, the largest bin centre,
            i.e. the maximum predicted error possible.
    """
    # no_bins - 1 boundaries delimit no_bins bins.
    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)

    (
        predicted_aligned_error,
        max_predicted_aligned_error,
    ) = _calculate_expected_aligned_error(
        alignment_confidence_breaks=boundaries,
        aligned_distance_error_probs=aligned_confidence_probs,
    )

    return {
        "aligned_confidence_probs": aligned_confidence_probs,
        "predicted_aligned_error": predicted_aligned_error,
        "max_predicted_aligned_error": max_predicted_aligned_error,
    }
def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    no_bins: int = 64,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """Compute a predicted TM-score from aligned-error logits.

    Args:
        logits: tensor whose last dimension runs over the error bins and
            whose second-to-last dimension has one entry per residue
            (assumes an unbatched [num_res, num_res, no_bins] input, since
            the final selection indexes a single best alignment -- TODO
            confirm with callers).
        residue_weights: per-residue weights of length logits.shape[-2].
            Defaults to all ones, created with the dtype and device of
            `logits`.
        max_bin: Maximum bin value.
        no_bins: Number of bins.
        eps: small constant guarding the weight normalisation against a
            zero weight sum.
        **kwargs: Ignored. Accepted so that a whole config section can be
            splatted into the call even when it carries unrelated keys.

    Returns:
        The expected TM term of the alignment whose weighted score is
        largest.
    """
    if residue_weights is None:
        # Stay in torch (same dtype/device as logits); a NumPy array cannot
        # be passed to torch.sum or mixed reliably with torch tensors below.
        residue_weights = logits.new_ones(logits.shape[-2])

    # no_bins - 1 boundaries delimit no_bins bins.
    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    bin_centers = _calculate_bin_centers(boundaries)

    # Effective residue count; clipped from below so that d0 stays positive.
    n = int(torch.sum(residue_weights))
    clipped_n = max(n, 19)

    d0 = 1.24 * (clipped_n - 15) ** (1. / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    # TM contribution of each error bin, then its expectation under probs.
    tm_per_bin = 1. / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    # Pick the first alignment that attains the maximum weighted score.
    weighted = per_alignment * residue_weights
    argmax = (weighted == torch.max(weighted)).nonzero()[0]
    return per_alignment[tuple(argmax)]
def
tm_loss
(
def
tm_loss
(
logits
,
logits
,
final_affine_tensor
,
final_affine_tensor
,
...
...
tests/test_loss.py
View file @
a3c2ae51
...
@@ -19,6 +19,7 @@ import numpy as np
...
@@ -19,6 +19,7 @@ import numpy as np
import
unittest
import
unittest
import
ml_collections
as
mlc
import
ml_collections
as
mlc
from
openfold.features.data_transforms
import
make_atom14_masks
from
openfold.utils.affine_utils
import
T
,
affine_vector_to_4x4
from
openfold.utils.affine_utils
import
T
,
affine_vector_to_4x4
import
openfold.utils.feats
as
feats
import
openfold.utils.feats
as
feats
from
openfold.utils.loss
import
(
from
openfold.utils.loss
import
(
...
@@ -310,7 +311,7 @@ class TestLoss(unittest.TestCase):
...
@@ -310,7 +311,7 @@ class TestLoss(unittest.TestCase):
def
_build_extra_feats_np
():
def
_build_extra_feats_np
():
b
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
),
batch
,
np
.
ndarray
)
b
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
),
batch
,
np
.
ndarray
)
b
.
update
(
feats
.
build_ambiguity_feats
(
b
))
b
.
update
(
feats
.
build_ambiguity_feats
(
b
))
b
.
update
(
feats
.
compute_residx
(
b
))
b
.
update
(
make_atom14_masks
(
b
))
return
tensor_tree_map
(
lambda
t
:
np
.
array
(
t
),
b
)
return
tensor_tree_map
(
lambda
t
:
np
.
array
(
t
),
b
)
batch
=
_build_extra_feats_np
()
batch
=
_build_extra_feats_np
()
...
...
tests/test_model.py
View file @
a3c2ae51
...
@@ -18,6 +18,7 @@ import torch.nn as nn
...
@@ -18,6 +18,7 @@ import torch.nn as nn
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
from
openfold.config
import
model_config
from
openfold.config
import
model_config
from
openfold.features.data_transforms
import
make_atom14_masks
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
import
openfold.utils.feats
as
feats
import
openfold.utils.feats
as
feats
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
from
openfold.utils.tensor_utils
import
tree_map
,
tensor_tree_map
...
@@ -73,7 +74,7 @@ class TestModel(unittest.TestCase):
...
@@ -73,7 +74,7 @@ class TestModel(unittest.TestCase):
batch
[
"seq_mask"
]
=
torch
.
randint
(
batch
[
"seq_mask"
]
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
batch_size
,
n_res
)
low
=
0
,
high
=
2
,
size
=
(
batch_size
,
n_res
)
).
float
()
).
float
()
batch
.
update
(
feats
.
compute_residx
(
batch
))
batch
.
update
(
make_atom14_masks
(
batch
))
add_recycling_dims
=
lambda
t
:
(
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
no_cycles
)
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
no_cycles
)
...
...
tests/test_structure_module.py
View file @
a3c2ae51
...
@@ -16,6 +16,7 @@ import torch
...
@@ -16,6 +16,7 @@ import torch
import
numpy
as
np
import
numpy
as
np
import
unittest
import
unittest
from
openfold.features.data_transforms
import
make_atom14_masks_np
from
openfold.np.residue_constants
import
(
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
restype_atom14_to_rigid_group
,
...
@@ -157,7 +158,7 @@ class TestStructureModule(unittest.TestCase):
...
@@ -157,7 +158,7 @@ class TestStructureModule(unittest.TestCase):
axis
=
0
axis
=
0
)
)
batch
.
update
(
feats
.
compute_residx
_np
(
batch
))
batch
.
update
(
make_atom14_masks
_np
(
batch
))
params
=
compare_utils
.
fetch_alphafold_module_weights
(
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/structure_module"
"alphafold/alphafold_iteration/structure_module"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment