OpenDAS / OpenFold · Commit 0cf1541c

Authored Oct 16, 2023 by Christina Floristean
Parent: 377f854c

Refactoring multimer data pipeline and permutation alignment.
Showing 8 changed files with 660 additions and 789 deletions.

  environment.yml                            +2    -0
  openfold/config.py                         +5    -1
  openfold/data/data_modules.py              +371  -438
  openfold/data/feature_pipeline.py          +5    -18
  openfold/data/input_pipeline_multimer.py   +25   -20
  openfold/utils/loss.py                     +207  -230
  scripts/generate_mmcif_cache.py            +36   -2
  train_openfold.py                          +9    -80
environment.yml

```diff
@@ -19,6 +19,8 @@ dependencies:
   - deepspeed==0.5.10
   - dm-tree==0.1.6
   - ml-collections==0.1.0
+  - jax==0.3.25
+  - pandas==2.0.2
   - numpy==1.21.2
   - PyYAML==5.4.1
   - requests==2.26.0
```
openfold/config.py

```diff
@@ -749,6 +749,9 @@ multimer_config_update = mlc.ConfigDict({
                 "sym_id",
             ]
         },
+        "supervised": {
+            "clamp_prob": 1.
+        },
         # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
         # c.model.input_embedder.num_msa = 508
         # c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
@@ -765,7 +768,8 @@ multimer_config_update = mlc.ConfigDict({
         "max_extra_msa": 2048,
         "crop_size": 640,
         "spatial_crop_prob": 0.5,
-        "interface_threshold": 10.
+        "interface_threshold": 10.,
+        "clamp_prob": 1.,
     },
 },
 "model": {
```
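The new `supervised` entry is what the refactored multimer input pipeline reads to decide whether ground-truth features should be prepared (see `input_pipeline_multimer.py` below). As a hedged aside, here is a minimal sketch of how an `ml_collections` override like `multimer_config_update` merges into a base config; the keys are illustrative, not OpenFold's full config tree:

```python
# Minimal sketch using only ml_collections; keys are illustrative.
import ml_collections as mlc

base = mlc.ConfigDict({
    "supervised": {"clamp_prob": 0.9},
    "crop_size": 256,
})
base.update(mlc.ConfigDict({
    "supervised": {"clamp_prob": 1.},  # value from the hunk above
    "crop_size": 640,
}))
# ConfigDict.update merges nested ConfigDicts recursively.
assert base.supervised.clamp_prob == 1.0
assert base.crop_size == 640
```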
openfold/data/data_modules.py

```diff
@@ -4,7 +4,7 @@ import json
 import logging
 import os
 import pickle
-from typing import Optional, Sequence, Any
+from typing import Optional, Sequence, Any, Union
 import ml_collections as mlc
 import pytorch_lightning as pl
@@ -18,43 +18,31 @@ from openfold.data import (
     templates,
 )
 from openfold.utils.tensor_utils import dict_multimap
-import contextlib
-import tempfile
 from openfold.utils.tensor_utils import (
     tensor_tree_map,
 )
-import random
 
-logging.basicConfig(level=logging.INFO)
-
-@contextlib.contextmanager
-def temp_fasta_file(sequence_str):
-    """function that create temparory fasta file used in multimer datapipeline"""
-    with tempfile.NamedTemporaryFile("w", suffix=".fasta") as fasta_file:
-        fasta_file.write(sequence_str)
-        fasta_file.seek(0)
-        yield fasta_file.name
 
 class OpenFoldSingleDataset(torch.utils.data.Dataset):
     def __init__(self,
         data_dir: str,
         alignment_dir: str,
         template_mmcif_dir: str,
         max_template_date: str,
         config: mlc.ConfigDict,
         chain_data_cache_path: Optional[str] = None,
         kalign_binary_path: str = '/usr/bin/kalign',
         max_template_hits: int = 4,
         obsolete_pdbs_file_path: Optional[str] = None,
         template_release_dates_cache_path: Optional[str] = None,
         shuffle_top_k_prefiltered: Optional[int] = None,
         treat_pdb_as_distillation: bool = True,
         filter_path: Optional[str] = None,
         mode: str = "train",
         alignment_index: Optional[Any] = None,
         _output_raw: bool = False,
         _structure_index: Optional[Any] = None,
     ):
         """
         Args:
             data_dir:
```
```diff
@@ -116,21 +104,21 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         self.supported_exts = [".cif", ".core", ".pdb"]
 
         valid_modes = ["train", "eval", "predict"]
-        if (mode not in valid_modes):
+        if mode not in valid_modes:
             raise ValueError(f'mode must be one of {valid_modes}')
 
-        if (template_release_dates_cache_path is None):
+        if template_release_dates_cache_path is None:
             logging.warning(
                 "Template release dates cache does not exist. Remember to run "
                 "scripts/generate_mmcif_cache.py before running OpenFold"
             )
 
-        if (alignment_index is not None):
+        if alignment_index is not None:
             self._chain_ids = list(alignment_index.keys())
         else:
             self._chain_ids = list(os.listdir(alignment_dir))
 
-        if (filter_path is not None):
+        if filter_path is not None:
             with open(filter_path, "r") as f:
                 chains_to_include = set([l.strip() for l in f.readlines()])
@@ -160,7 +148,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                 len(missing), missing_examples, chain_data_cache_path
             )
 
         self._chain_id_to_idx_dict = {
             chain: i for i, chain in enumerate(self._chain_ids)
         }
@@ -182,7 +170,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             template_featurizer=template_featurizer,
         )
 
-        if (not self._output_raw):
+        if not self._output_raw:
             self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
 
     def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
@@ -195,7 +183,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         # Crash if an error is encountered. Any parsing errors should have
         # been dealt with at the alignment stage.
-        if (mmcif_object.mmcif_object is None):
+        if mmcif_object.mmcif_object is None:
             raise list(mmcif_object.errors.values())[0]
 
         mmcif_object = mmcif_object.mmcif_object
@@ -220,47 +208,46 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         alignment_dir = os.path.join(self.alignment_dir, name)
         alignment_index = None
-        if (self.alignment_index is not None):
+        if self.alignment_index is not None:
             alignment_dir = self.alignment_dir
             alignment_index = self.alignment_index[name]
 
-        if (self.mode == 'train' or self.mode == 'eval'):
+        if self.mode == 'train' or self.mode == 'eval':
             spl = name.rsplit('_', 1)
-            if (len(spl) == 2):
+            if len(spl) == 2:
                 file_id, chain_id = spl
             else:
                 file_id, = spl
                 chain_id = None
 
             path = os.path.join(self.data_dir, file_id)
             structure_index_entry = None
-            if (self._structure_index is not None):
+            if self._structure_index is not None:
                 structure_index_entry = self._structure_index[name]
                 assert (len(structure_index_entry["files"]) == 1)
                 filename, _, _ = structure_index_entry["files"][0]
                 ext = os.path.splitext(filename)[1]
             else:
                 ext = None
                 for e in self.supported_exts:
-                    if (os.path.exists(path + e)):
+                    if os.path.exists(path + e):
                         ext = e
                         break
 
-                if (ext is None):
+                if ext is None:
                     raise ValueError("Invalid file type")
 
             path += ext
-            if (ext == ".cif"):
+            if ext == ".cif":
                 data = self._parse_mmcif(
                     path, file_id, chain_id, alignment_dir, alignment_index,
                 )
-            elif (ext == ".core"):
+            elif ext == ".core":
                 data = self.data_pipeline.process_core(
                     path, alignment_dir, alignment_index,
                 )
-            elif (ext == ".pdb"):
+            elif ext == ".pdb":
                 structure_index = None
-                if (self._structure_index is not None):
+                if self._structure_index is not None:
                     structure_index = self._structure_index[name]
                 data = self.data_pipeline.process_pdb(
                     pdb_path=path,
```
```diff
@@ -271,7 +258,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                     _structure_index=structure_index,
                 )
             else:
                 raise ValueError("Extension branch missing")
         else:
             path = os.path.join(name, name + ".fasta")
             data = self.data_pipeline.process_fasta(
@@ -280,11 +267,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             alignment_index=alignment_index,
         )
 
-        if (self._output_raw):
+        if self._output_raw:
             return data
 
         feats = self.feature_pipeline.process_features(
             data, self.mode
         )
 
         feats["batch_idx"] = torch.tensor(
@@ -295,30 +282,29 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         return feats
 
     def __len__(self):
         return len(self._chain_ids)
 
 class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
     def __init__(self,
         data_dir: str,
         alignment_dir: str,
         template_mmcif_dir: str,
         max_template_date: str,
         config: mlc.ConfigDict,
-        chain_data_cache_path: Optional[str] = None,
         mmcif_data_cache_path: Optional[str] = None,
         kalign_binary_path: str = '/usr/bin/kalign',
         max_template_hits: int = 4,
         obsolete_pdbs_file_path: Optional[str] = None,
         template_release_dates_cache_path: Optional[str] = None,
         shuffle_top_k_prefiltered: Optional[int] = None,
         treat_pdb_as_distillation: bool = True,
         filter_path: Optional[str] = None,
         mode: str = "train",
         alignment_index: Optional[Any] = None,
         _output_raw: bool = False,
         _structure_index: Optional[Any] = None,
     ):
         """
         This class checks each individual PDB ID and returns its chain(s) features/ground truth
         Args:
@@ -336,15 +322,10 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
             Path to a directory containing template mmCIF files.
         config:
             A dataset config object. See openfold.config
-        chain_data_cache_path:
-            Path to cache of data_dir generated by
-            scripts/generate_chain_data_cache.py
         mmcif_data_cache_path:
             Path to cache of all mmcif files generated by
             scripts/generate_mmcif_cache.py. It should be a json file which records
             which PDB ID contains which chain(s)
         kalign_binary_path:
             Path to kalign binary.
         max_template_hits:
@@ -369,17 +350,12 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         """
         super(OpenFoldSingleMultimerDataset, self).__init__()
         self.data_dir = data_dir
-        self.mmcif_data_cache_path = mmcif_data_cache_path
-        self.chain_data_cache = None
-        if chain_data_cache_path is not None:
-            with open(chain_data_cache_path, "r") as fp:
-                self.chain_data_cache = json.load(fp)
-            assert isinstance(self.chain_data_cache, dict)
+        self.mmcif_data_cache_path = mmcif_data_cache_path
         if self.mmcif_data_cache_path is not None:
             with open(self.mmcif_data_cache_path, "r") as infile:
                 self.mmcif_data_cache = json.load(infile)
             assert isinstance(self.mmcif_data_cache, dict)
 
         self.alignment_dir = alignment_dir
         self.config = config
```
```diff
@@ -392,39 +368,36 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         self.supported_exts = [".cif", ".core", ".pdb"]
 
         valid_modes = ["train", "eval", "predict"]
-        if (mode not in valid_modes):
+        if mode not in valid_modes:
             raise ValueError(f'mode must be one of {valid_modes}')
 
-        if (template_release_dates_cache_path is None):
+        if template_release_dates_cache_path is None:
             logging.warning(
                 "Template release dates cache does not exist. Remember to run "
                 "scripts/generate_mmcif_cache.py before running OpenFold"
             )
 
-        if (alignment_index is not None):
-            self._chain_ids = list(alignment_index.keys())
+        if self.mmcif_data_cache_path is not None:
+            self._mmcifs = list(self.mmcif_data_cache.keys())
+        elif self.alignment_index is not None:
+            self._mmcifs = [i.split("_")[0] for i in list(alignment_index.keys())]
+        elif self.alignment_dir is not None:
+            self._mmcifs = [i.split("_")[0] for i in os.listdir(self.alignment_dir)]
         else:
-            self._chain_ids = list(os.listdir(alignment_dir))
+            raise ValueError(
+                "You must provide at least one of the mmcif_data_cache or alignment_dir"
+            )
 
-        if (filter_path is not None):
+        if filter_path is not None:
             with open(filter_path, "r") as f:
-                chains_to_include = set([l.strip() for l in f.readlines()])
-            self._chain_ids = [c for c in self._chain_ids if c in chains_to_include]
+                mmcifs_to_include = set([l.strip() for l in f.readlines()])
+            self._mmcifs = [m for m in self._mmcifs if m in mmcifs_to_include]
 
-        if self.mmcif_data_cache_path is not None:
-            self._mmcifs = list(self.mmcif_data_cache.keys())
-        elif self.mmcif_data_cache_path is None and self.alignment_dir is not None:
-            self._mmcifs = [i.split("_")[0] for i in os.listdir(self.alignment_dir)]
-        else:
-            raise ValueError(
-                "You must provide at least one of the mmcif_data_cache or alignment_dir"
-            )
 
         self._mmcif_id_to_idx_dict = {
             mmcif: i for i, mmcif in enumerate(self._mmcifs)
         }
-        # changed template_featurizer to hmmsearch for now just to run the test
 
         template_featurizer = templates.HmmsearchHitFeaturizer(
             mmcif_dir=template_mmcif_dir,
             max_template_date=max_template_date,
@@ -443,7 +416,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         )
         self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
 
     def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
         with open(path, 'r') as f:
             mmcif_string = f.read()
@@ -453,7 +426,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
         # Crash if an error is encountered. Any parsing errors should have
         # been dealt with at the alignment stage.
-        if (mmcif_object.mmcif_object is None):
+        if mmcif_object.mmcif_object is None:
             raise list(mmcif_object.errors.values())[0]
 
         mmcif_object = mmcif_object.mmcif_object
@@ -462,34 +435,34 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
                 mmcif=mmcif_object,
                 alignment_dir=alignment_dir,
                 alignment_index=alignment_index
             )
         )
         return data
 
-    def mmcif_id_to_idx(self, chain_id):
-        return self._mmcif_id_to_idx_dict[chain_id]
+    def mmcif_id_to_idx(self, mmcif_id):
+        return self._mmcif_id_to_idx_dict[mmcif_id]
 
     def idx_to_mmcif_id(self, idx):
         return self._mmcifs[idx]
 
     def __getitem__(self, idx):
         mmcif_id = self.idx_to_mmcif_id(idx)
         alignment_index = None
 
-        if (self.mode == 'train' or self.mode == 'eval'):
+        if self.mode == 'train' or self.mode == 'eval':
             path = os.path.join(self.data_dir, f"{mmcif_id}")
             ext = None
             for e in self.supported_exts:
-                if (os.path.exists(path + e)):
+                if os.path.exists(path + e):
                     ext = e
                     break
 
-            if (ext is None):
+            if ext is None:
                 raise ValueError("Invalid file type")
-            #TODO: Add pdb and core exts to data_pipeline for multimer
+            # TODO: Add pdb and core exts to data_pipeline for multimer
             path += ext
-            if (ext == ".cif"):
+            if ext == ".cif":
                 data = self._parse_mmcif(
                     path, mmcif_id, self.alignment_dir, alignment_index,
                 )
```
```diff
@@ -502,107 +475,52 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
                 alignment_dir=self.alignment_dir)
 
-        if (self._output_raw):
+        if self._output_raw:
             return data
 
         # process all_chain_features
-        data, ground_truth = self.feature_pipeline.process_features(
+        data = self.feature_pipeline.process_features(
             data, mode=self.mode, is_multimer=True
         )
 
         # if it's inference mode, only need all_chain_features
         data["batch_idx"] = torch.tensor(
             [idx for _ in range(data["aatype"].shape[-1])],
             dtype=torch.int64,
             device=data["aatype"].device)
-        return data, ground_truth
+        return data
 
     def __len__(self):
-        return len(self._chain_ids)
+        return len(self._mmcifs)
 
-def deterministic_train_filter(
-    chain_data_cache_entry: Any,
-    max_resolution: float = 9.,
-    max_single_aa_prop: float = 0.8,
-) -> bool:
-    # Hard filters
-    resolution = chain_data_cache_entry.get("resolution", None)
-    if (resolution is not None and resolution > max_resolution):
-        return False
-
-    seq = chain_data_cache_entry["seq"]
-    counts = {}
-    for aa in seq:
-        counts.setdefault(aa, 0)
-        counts[aa] += 1
-    largest_aa_count = max(counts.values())
-    largest_single_aa_prop = largest_aa_count / len(seq)
-    if (largest_single_aa_prop > max_single_aa_prop):
-        return False
-
-    return True
-
-def deterministic_multimer_train_filter(
-    mmcif_data_cache_entry,
-    max_resolution: float = 9.,
-    max_single_aa_prop: float = 0.8,
-    minimum_number_of_residues: int = 200,
-) -> bool:
-    """
-    Implement multimer training filtering criteria described in
-    https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
-    """
-    # First check resolution
-    resolution = mmcif_data_cache_entry.get("resolution", None)
-    if (resolution is not None and resolution > max_resolution) or (resolution is None):
-        return False
-
-    # Then check if any single amino acid accounts for more than 80% of the complex sequences
-    seqs = mmcif_data_cache_entry["seqs"]
-    counts = {}
-    total_len = sum([len(i) for i in seqs])
-    if total_len < minimum_number_of_residues:
-        # check if the complex has less than 200 residues
-        return False
-    for seq in seqs:
-        for aa in seq:
-            counts.setdefault(aa, 0)
-            counts[aa] += 1
-    largest_aa_count = max(counts.values())
-    largest_single_aa_prop = largest_aa_count / total_len
-    if (largest_single_aa_prop > max_single_aa_prop):
-        return False
-    return True
-
-def get_stochastic_train_filter_prob(
-    chain_data_cache_entry: Any,
-) -> float:
-    # Stochastic filters
-    probabilities = []
-
-    cluster_size = chain_data_cache_entry.get("cluster_size", None)
-    if (cluster_size is not None and cluster_size > 0):
-        probabilities.append(1 / cluster_size)
-
-    chain_length = len(chain_data_cache_entry["seq"])
-    probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
-
-    # Risk of underflow here?
-    out = 1
-    for p in probabilities:
-        out *= p
-
-    return out
+def resolution_filter(resolution: int, max_resolution: float) -> bool:
+    """Check that the resolution is <= max_resolution permitted"""
+    return resolution is not None and resolution <= max_resolution
+
+def aa_count_filter(seqs: list, max_single_aa_prop: float) -> bool:
+    """Check if any single amino acid accounts for more than max_single_aa_prop percent of the sequence(s)"""
+    counts = {}
+    for aa in restypes:
+        counts[aa] = 0
+    for seq in seqs:
+        for aa in seq:
+            if aa not in restypes:
+                return False
+            else:
+                counts[aa] += 1
+    total_len = sum([len(i) for i in seqs])
+    largest_aa_count = max(counts.values())
+    largest_single_aa_prop = largest_aa_count / total_len
+    return largest_single_aa_prop <= max_single_aa_prop
+
+def all_seq_len_filter(seqs: list, minimum_number_of_residues: int) -> bool:
+    """Check if the total combined sequence lengths are >= minimum_number_of_residues"""
+    total_len = sum([len(i) for i in seqs])
+    return total_len >= minimum_number_of_residues
 
 class OpenFoldDataset(torch.utils.data.Dataset):
```
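The three module-level helpers introduced here replace the bodies of the old monolithic filters. A short sketch of how they compose, mirroring the multimer `deterministic_train_filter` defined further down; the cache entry is made up, and `restypes` is assumed to come from OpenFold's residue constants:

```python
# Hypothetical mmcif_data_cache entry; helper names come from the hunk above.
entry = {"resolution": 3.2, "seqs": ["MKVLA" * 60, "GDSSG" * 50]}

keep = all([
    resolution_filter(resolution=entry.get("resolution", None), max_resolution=9.),
    all_seq_len_filter(seqs=entry["seqs"], minimum_number_of_residues=200),
    aa_count_filter(seqs=entry["seqs"], max_single_aa_prop=0.8),
])
# keep is True here: 3.2 A <= 9 A, 550 residues >= 200, and no residue type
# accounts for more than 80% of the combined sequence.
```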
```diff
@@ -612,67 +530,104 @@ class OpenFoldDataset(torch.utils.data.Dataset):
     length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
     and filtered once at initialization.
     """
     def __init__(self,
-        datasets: Sequence[OpenFoldSingleDataset],
+        datasets: Union[Sequence[OpenFoldSingleDataset], Sequence[OpenFoldSingleMultimerDataset]],
         probabilities: Sequence[float],
         epoch_len: int,
         generator: torch.Generator = None,
         _roll_at_init: bool = True,
     ):
         self.datasets = datasets
         self.probabilities = probabilities
         self.epoch_len = epoch_len
         self.generator = generator
 
-        def looped_shuffled_dataset_idx(dataset_len):
-            while True:
-                # Uniformly shuffle each dataset's indices
-                weights = [1. for _ in range(dataset_len)]
-                shuf = torch.multinomial(
-                    torch.tensor(weights),
-                    num_samples=dataset_len,
-                    replacement=False,
-                    generator=self.generator,
-                )
-                for idx in shuf:
-                    yield idx
-
-        def looped_samples(dataset_idx):
-            max_cache_len = int(epoch_len * probabilities[dataset_idx])
-            dataset = self.datasets[dataset_idx]
-            idx_iter = looped_shuffled_dataset_idx(len(dataset))
-            chain_data_cache = dataset.chain_data_cache
-            while True:
-                weights = []
-                idx = []
-                for _ in range(max_cache_len):
-                    candidate_idx = next(idx_iter)
-                    chain_id = dataset.idx_to_chain_id(candidate_idx)
-                    chain_data_cache_entry = chain_data_cache[chain_id]
-                    if (not deterministic_train_filter(chain_data_cache_entry)):
-                        continue
-
-                    p = get_stochastic_train_filter_prob(
-                        chain_data_cache_entry,
-                    )
-                    weights.append([1. - p, p])
-                    idx.append(candidate_idx)
-
-                samples = torch.multinomial(
-                    torch.tensor(weights),
-                    num_samples=1,
-                    generator=self.generator,
-                )
-                samples = samples.squeeze()
-
-                cache = [i for i, s in zip(idx, samples) if s]
-
-                for datapoint_idx in cache:
-                    yield datapoint_idx
-
-        self._samples = [looped_samples(i) for i in range(len(self.datasets))]
-        if (_roll_at_init):
-            self.reroll()
+        self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
+        if _roll_at_init:
+            self.reroll()
+
+    @staticmethod
+    def deterministic_train_filter(
+        cache_entry: Any,
+        max_resolution: float = 9.,
+        max_single_aa_prop: float = 0.8,
+    ) -> bool:
+        # Hard filters
+        resolution = cache_entry.get("resolution", None)
+        seqs = [cache_entry["seq"]]
+        return all([
+            resolution_filter(resolution=resolution, max_resolution=max_resolution),
+            aa_count_filter(seqs=seqs, max_single_aa_prop=max_single_aa_prop)])
+
+    @staticmethod
+    def get_stochastic_train_filter_prob(
+        cache_entry: Any,
+    ) -> float:
+        # Stochastic filters
+        probabilities = []
+
+        cluster_size = cache_entry.get("cluster_size", None)
+        if cluster_size is not None and cluster_size > 0:
+            probabilities.append(1 / cluster_size)
+
+        chain_length = len(cache_entry["seq"])
+        probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
+
+        # Risk of underflow here?
+        out = 1
+        for p in probabilities:
+            out *= p
+
+        return out
+
+    def looped_shuffled_dataset_idx(self, dataset_len):
+        while True:
+            # Uniformly shuffle each dataset's indices
+            weights = [1. for _ in range(dataset_len)]
+            shuf = torch.multinomial(
+                torch.tensor(weights),
+                num_samples=dataset_len,
+                replacement=False,
+                generator=self.generator,
+            )
+            for idx in shuf:
+                yield idx
+
+    def looped_samples(self, dataset_idx):
+        max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
+        dataset = self.datasets[dataset_idx]
+        idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
+        chain_data_cache = dataset.chain_data_cache
+        while True:
+            weights = []
+            idx = []
+            for _ in range(max_cache_len):
+                candidate_idx = next(idx_iter)
+                chain_id = dataset.idx_to_chain_id(candidate_idx)
+                chain_data_cache_entry = chain_data_cache[chain_id]
+                if not self.deterministic_train_filter(chain_data_cache_entry):
+                    continue
+
+                p = self.get_stochastic_train_filter_prob(
+                    chain_data_cache_entry,
+                )
+                weights.append([1. - p, p])
+                idx.append(candidate_idx)
+
+            samples = torch.multinomial(
+                torch.tensor(weights),
+                num_samples=1,
+                generator=self.generator,
+            )
+            samples = samples.squeeze()
+
+            cache = [i for i, s in zip(idx, samples) if s]
+
+            for datapoint_idx in cache:
+                yield datapoint_idx
 
     def __getitem__(self, idx):
         dataset_idx, datapoint_idx = self.datapoints[idx]
```
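For concreteness, the stochastic filter probability computed by `get_stochastic_train_filter_prob` above, evaluated by hand on made-up numbers:

```python
# Made-up cache entry values; the formula is copied from the method above.
cluster_size, chain_length = 20, 300
p = (1 / cluster_size) * (1 / 512) * max(min(chain_length, 512), 256)
# = 0.05 * (300 / 512) ~= 0.0293
# Chains from large sequence clusters are down-weighted, and the length term
# grows linearly between 256 and 512 residues, then saturates.
```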
```diff
@@ -695,71 +650,97 @@ class OpenFoldDataset(torch.utils.data.Dataset):
                 self.datapoints.append((dataset_idx, datapoint_idx))
 
-class OpenFoldMultimerDataset(torch.utils.data.Dataset):
+class OpenFoldMultimerDataset(OpenFoldDataset):
     """
     Create a torch Dataset object for multimer training and
     add filtering steps described in AlphaFold Multimer's paper:
     https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
     """
     def __init__(self,
-        datasets: Sequence[OpenFoldSingleDataset],
+        datasets: Sequence[OpenFoldSingleMultimerDataset],
         probabilities: Sequence[float],
         epoch_len: int,
         generator: torch.Generator = None,
-        _roll_at_init: bool = True,
-    ):
-        self.datasets = datasets
-        self.probabilities = probabilities
-        self.epoch_len = epoch_len
-        self.generator = generator
-        if _roll_at_init:
-            self.reroll()
-
-    def filter_samples(self, dataset_idx):
-        dataset = self.datasets[dataset_idx]
-        mmcif_data_cache = dataset.mmcif_data_cache if hasattr(dataset, "mmcif_data_cache") else None
-        selected_idx = []
-        if mmcif_data_cache is not None:
-            for i in range(len(mmcif_data_cache)):
-                mmcif_id = dataset.idx_to_mmcif_id(i)
-                mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
-                if deterministic_multimer_train_filter(mmcif_data_cache_entry, max_resolution=9):
-                    selected_idx.append(i)
-            logging.info(f"Originally {len(mmcif_data_cache)} mmcifs. After filtering: {len(selected_idx)}")
-        else:
-            selected_idx = list(range(len(dataset._mmcif_id_to_idx_dict)))
-        return selected_idx
-
-    def __getitem__(self, idx):
-        dataset_idx, datapoint_idx = self.datapoints[idx]
-        return self.datasets[dataset_idx][datapoint_idx]
-
-    def __len__(self):
-        return self.epoch_len
-
-    def reroll(self):
-        dataset_choices = torch.multinomial(
-            torch.tensor(self.probabilities),
-            num_samples=len(self.probabilities),
-            replacement=True,
-            generator=self.generator,
-        )
-        self.datapoints = []
-        for dataset_idx in dataset_choices:
-            selected_idx = self.filter_samples(dataset_idx)
-            random.shuffle(selected_idx)
-            if len(selected_idx) < self.epoch_len:
-                self.epoch_len = len(selected_idx)
-                logging.info(f"self.epoch_len is {self.epoch_len}")
-            self.datapoints += [(dataset_idx, selected_idx[i]) for i in range(self.epoch_len)]
+        _roll_at_init: bool = True):
+        super(OpenFoldMultimerDataset).__init__(
+            datasets=datasets,
+            probabilities=probabilities,
+            epoch_len=epoch_len,
+            generator=generator,
+            _roll_at_init=_roll_at_init)
+
+    @staticmethod
+    def deterministic_train_filter(
+        cache_entry: Any,
+        max_resolution: float = 9.,
+        max_single_aa_prop: float = 0.8,
+        minimum_number_of_residues: int = 200,
+    ) -> bool:
+        """
+        Implement multimer training filtering criteria described in
+        https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
+        """
+        resolution = cache_entry.get("resolution", None)
+        seqs = cache_entry["seqs"]
+        return all([
+            resolution_filter(resolution=resolution, max_resolution=max_resolution),
+            all_seq_len_filter(seqs=seqs, minimum_number_of_residues=minimum_number_of_residues),
+            aa_count_filter(seqs=seqs, max_single_aa_prop=max_single_aa_prop)])
+
+    @staticmethod
+    def get_stochastic_train_filter_prob(
+        cache_entry: Any,
+    ) -> float:
+        # Stochastic filters
+        cluster_sizes = cache_entry.get("cluster_sizes", [])
+        chain_probs = [1 / c for c in cluster_sizes if c > 0]
+        if chain_probs:
+            return sum(chain_probs)
+        return 1.
+
+    def looped_samples(self, dataset_idx):
+        max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
+        dataset = self.datasets[dataset_idx]
+        idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
+        mmcif_data_cache = dataset.mmcif_data_cache
+        while True:
+            weights = []
+            idx = []
+            for _ in range(max_cache_len):
+                candidate_idx = next(idx_iter)
+                mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
+                mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
+                if not self.deterministic_train_filter(mmcif_data_cache_entry):
+                    continue
+
+                p = self.get_stochastic_train_filter_prob(
+                    mmcif_data_cache_entry,
+                )
+                weights.append([1. - p, p])
+                idx.append(candidate_idx)
+
+            samples = torch.multinomial(
+                torch.tensor(weights),
+                num_samples=1,
+                generator=self.generator,
+            )
+            samples = samples.squeeze()
+
+            cache = [i for i, s in zip(idx, samples) if s]
+
+            for datapoint_idx in cache:
+                yield datapoint_idx
 
 class OpenFoldBatchCollator:
     def __call__(self, prots):
         stack_fn = partial(torch.stack, dim=0)
         return dict_multimap(stack_fn, prots)
 
 class OpenFoldDataLoader(torch.utils.data.DataLoader):
```
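The multimer variant above swaps the monomer product for a sum of per-chain inverse cluster sizes. A one-line worked example with invented values:

```python
# Invented cluster sizes for a two-chain complex; logic from the staticmethod above.
cluster_sizes = [4, 10]
chain_probs = [1 / c for c in cluster_sizes if c > 0]
p = sum(chain_probs) if chain_probs else 1.   # 0.25 + 0.1 = 0.35
# Complexes built from rarer chains are retained with higher probability.
```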
```diff
@@ -775,8 +756,8 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         stage_cfg = self.config[self.stage]
         max_iters = self.config.common.max_recycling_iters
-        if (stage_cfg.uniform_recycling):
+        if stage_cfg.uniform_recycling:
             recycling_probs = [
                 1. / (max_iters + 1) for _ in range(max_iters + 1)
             ]
@@ -785,15 +766,15 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
                 0. for _ in range(max_iters + 1)
             ]
             recycling_probs[-1] = 1.
 
         keyed_probs.append(
             ("no_recycling_iters", recycling_probs)
         )
 
         keys, probs = zip(*keyed_probs)
         max_len = max([len(p) for p in probs])
         padding = [[0.] * (max_len - len(p)) for p in probs]
 
         self.prop_keys = keys
         self.prop_probs_tensor = torch.tensor(
             [p + pad for p, pad in zip(probs, padding)],
@@ -803,7 +784,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
     def _add_batch_properties(self, batch):
         samples = torch.multinomial(
             self.prop_probs_tensor,
             num_samples=1,  # 1 per row
             replacement=True,
             generator=self.generator
         )
@@ -815,8 +796,8 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         for i, key in enumerate(self.prop_keys):
             sample = int(samples[i][0])
             sample_tensor = torch.tensor(
                 sample,
                 device=aatype.device,
                 requires_grad=False
             )
             orig_shape = sample_tensor.shape
@@ -828,9 +809,9 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
             )
             batch[key] = sample_tensor
 
-            if (key == "no_recycling_iters"):
+            if key == "no_recycling_iters":
                 no_recycling = sample
 
         resample_recycling = lambda t: t[..., :no_recycling + 1]
         batch = tensor_tree_map(resample_recycling, batch)
```
```diff
@@ -846,50 +827,33 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         return _batch_prop_gen(it)
 
-class OpenFoldMultimerDataLoader(torch.utils.data.DataLoader):
-    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
-        super(OpenFoldMultimerDataLoader, self).__init__(*args, **kwargs)
-        self.config = config
-        self.stage = stage
-        self.generator = generator
-
-    def __iter__(self):
-        it = super().__iter__()
-
-        def _batch_prop_gen(iterator):
-            for batch in iterator:
-                yield batch
-
-        return _batch_prop_gen(it)
-
 class OpenFoldDataModule(pl.LightningDataModule):
     def __init__(self,
         config: mlc.ConfigDict,
         template_mmcif_dir: str,
         max_template_date: str,
         train_data_dir: Optional[str] = None,
         train_alignment_dir: Optional[str] = None,
         train_chain_data_cache_path: Optional[str] = None,
         distillation_data_dir: Optional[str] = None,
         distillation_alignment_dir: Optional[str] = None,
         distillation_chain_data_cache_path: Optional[str] = None,
         val_data_dir: Optional[str] = None,
         val_alignment_dir: Optional[str] = None,
         predict_data_dir: Optional[str] = None,
         predict_alignment_dir: Optional[str] = None,
         kalign_binary_path: str = '/usr/bin/kalign',
         train_filter_path: Optional[str] = None,
         distillation_filter_path: Optional[str] = None,
         obsolete_pdbs_file_path: Optional[str] = None,
         template_release_dates_cache_path: Optional[str] = None,
         batch_seed: Optional[int] = None,
         train_epoch_len: int = 50000,
         _distillation_structure_index_path: Optional[str] = None,
         alignment_index_path: Optional[str] = None,
         distillation_alignment_index_path: Optional[str] = None,
         **kwargs
     ):
         super(OpenFoldDataModule, self).__init__()
         self.config = config
@@ -917,7 +881,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.batch_seed = batch_seed
         self.train_epoch_len = train_epoch_len
 
-        if (self.train_data_dir is None and self.predict_data_dir is None):
+        if self.train_data_dir is None and self.predict_data_dir is None:
             raise ValueError(
                 'At least one of train_data_dir or predict_data_dir must be '
                 'specified'
@@ -925,65 +889,61 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.training_mode = self.train_data_dir is not None
 
-        if (self.training_mode and train_alignment_dir is None):
+        if self.training_mode and train_alignment_dir is None:
             raise ValueError(
                 'In training mode, train_alignment_dir must be specified'
             )
-        elif (not self.training_mode and predict_alignment_dir is None):
+        elif not self.training_mode and predict_alignment_dir is None:
             raise ValueError(
                 'In inference mode, predict_alignment_dir must be specified'
             )
-        elif (val_data_dir is not None and val_alignment_dir is None):
+        elif val_data_dir is not None and val_alignment_dir is None:
             raise ValueError(
                 'If val_data_dir is specified, val_alignment_dir must '
                 'be specified as well'
             )
 
         # An ad-hoc measure for our particular filesystem restrictions
         self._distillation_structure_index = None
-        if (_distillation_structure_index_path is not None):
+        if _distillation_structure_index_path is not None:
             with open(_distillation_structure_index_path, "r") as fp:
                 self._distillation_structure_index = json.load(fp)
 
         self.alignment_index = None
-        if (alignment_index_path is not None):
+        if alignment_index_path is not None:
             with open(alignment_index_path, "r") as fp:
                 self.alignment_index = json.load(fp)
 
         self.distillation_alignment_index = None
-        if (distillation_alignment_index_path is not None):
+        if distillation_alignment_index_path is not None:
             with open(distillation_alignment_index_path, "r") as fp:
                 self.distillation_alignment_index = json.load(fp)
 
     def setup(self):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
             template_mmcif_dir=self.template_mmcif_dir,
             max_template_date=self.max_template_date,
             config=self.config,
             kalign_binary_path=self.kalign_binary_path,
             template_release_dates_cache_path=self.template_release_dates_cache_path,
-            obsolete_pdbs_file_path=self.obsolete_pdbs_file_path,
-        )
+            obsolete_pdbs_file_path=self.obsolete_pdbs_file_path)
 
-        if (self.training_mode):
+        if self.training_mode:
             train_dataset = dataset_gen(
                 data_dir=self.train_data_dir,
                 chain_data_cache_path=self.train_chain_data_cache_path,
                 alignment_dir=self.train_alignment_dir,
                 filter_path=self.train_filter_path,
                 max_template_hits=self.config.train.max_template_hits,
                 shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
                 treat_pdb_as_distillation=False,
                 mode="train",
                 alignment_index=self.alignment_index,
             )
 
             distillation_dataset = None
-            if (self.distillation_data_dir is not None):
+            if self.distillation_data_dir is not None:
                 distillation_dataset = dataset_gen(
                     data_dir=self.distillation_data_dir,
                     chain_data_cache_path=self.distillation_chain_data_cache_path,
@@ -997,8 +957,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 )
 
-            d_prob = self.config.train.distillation_prob
-            if (distillation_dataset is not None):
+            if distillation_dataset is not None:
                 datasets = [train_dataset, distillation_dataset]
+                d_prob = self.config.train.distillation_prob
                 probabilities = [1. - d_prob, d_prob]
@@ -1007,10 +967,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 probabilities = [1.]
 
             generator = None
-            if (self.batch_seed is not None):
+            if self.batch_seed is not None:
                 generator = torch.Generator()
                 generator = generator.manual_seed(self.batch_seed + 1)
 
             self.train_dataset = OpenFoldDataset(
                 datasets=datasets,
                 probabilities=probabilities,
@@ -1018,8 +978,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 generator=generator,
                 _roll_at_init=False,
             )
 
-            if (self.val_data_dir is not None):
+            if self.val_data_dir is not None:
                 self.eval_dataset = dataset_gen(
                     data_dir=self.val_data_dir,
                     alignment_dir=self.val_alignment_dir,
@@ -1029,7 +989,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 )
             else:
                 self.eval_dataset = None
         else:
             self.predict_dataset = dataset_gen(
                 data_dir=self.predict_data_dir,
                 alignment_dir=self.predict_alignment_dir,
```
```diff
@@ -1040,18 +1000,17 @@ class OpenFoldDataModule(pl.LightningDataModule):
     def _gen_dataloader(self, stage):
         generator = None
-        if (self.batch_seed is not None):
+        if self.batch_seed is not None:
             generator = torch.Generator()
             generator = generator.manual_seed(self.batch_seed)
 
         dataset = None
-        if (stage == "train"):
+        if stage == "train":
             dataset = self.train_dataset
 
             # Filter the dataset, if necessary
             dataset.reroll()
-        elif (stage == "eval"):
+        elif stage == "eval":
             dataset = self.eval_dataset
-        elif (stage == "predict"):
+        elif stage == "predict":
             dataset = self.predict_dataset
         else:
             raise ValueError("Invalid stage")
@@ -1071,15 +1030,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
         return dl
 
     def train_dataloader(self):
         return self._gen_dataloader("train")
 
     def val_dataloader(self):
-        if (self.eval_dataset is not None):
+        if self.eval_dataset is not None:
             return self._gen_dataloader("eval")
         return None
 
     def predict_dataloader(self):
         return self._gen_dataloader("predict")
 
 class OpenFoldMultimerDataModule(OpenFoldDataModule):
@@ -1091,16 +1050,19 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
     scripts/generate_mmcif_cache.py. mmcif_data_cache_path should be
     a file that records what chain(s) each mmcif file has
     """
     def __init__(self,
         config: mlc.ConfigDict,
         template_mmcif_dir: str,
         max_template_date: str,
+        train_data_dir: Optional[str] = None,
         train_mmcif_data_cache_path: Optional[str] = None,
         val_mmcif_data_cache_path: Optional[str] = None,
         **kwargs):
         super(OpenFoldMultimerDataModule, self).__init__(
             config,
             template_mmcif_dir,
             max_template_date,
             train_data_dir,
             **kwargs)
         self.train_mmcif_data_cache_path = train_mmcif_data_cache_path
+        self.training_mode = self.train_data_dir is not None
         self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
```
```diff
@@ -1108,32 +1070,28 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
     def setup(self):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleMultimerDataset,
             template_mmcif_dir=self.template_mmcif_dir,
             max_template_date=self.max_template_date,
             config=self.config,
             kalign_binary_path=self.kalign_binary_path,
             template_release_dates_cache_path=self.template_release_dates_cache_path,
-            obsolete_pdbs_file_path=self.obsolete_pdbs_file_path,
-        )
+            obsolete_pdbs_file_path=self.obsolete_pdbs_file_path)
 
-        if (self.training_mode):
+        if self.training_mode:
             train_dataset = dataset_gen(
                 data_dir=self.train_data_dir,
                 mmcif_data_cache_path=self.train_mmcif_data_cache_path,
                 alignment_dir=self.train_alignment_dir,
                 filter_path=self.train_filter_path,
                 max_template_hits=self.config.train.max_template_hits,
                 shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
                 treat_pdb_as_distillation=False,
                 mode="train",
                 alignment_index=self.alignment_index,
             )
 
             distillation_dataset = None
-            if (self.distillation_data_dir is not None):
+            if self.distillation_data_dir is not None:
                 distillation_dataset = dataset_gen(
                     data_dir=self.distillation_data_dir,
                     alignment_dir=self.distillation_alignment_dir,
@@ -1146,8 +1104,8 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
                 )
 
-            d_prob = self.config.train.distillation_prob
-            if (distillation_dataset is not None):
+            if distillation_dataset is not None:
                 datasets = [train_dataset, distillation_dataset]
+                d_prob = self.config.train.distillation_prob
                 probabilities = [1. - d_prob, d_prob]
@@ -1156,10 +1114,10 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
                 probabilities = [1.]
 
             generator = None
-            if (self.batch_seed is not None):
+            if self.batch_seed is not None:
                 generator = torch.Generator()
                 generator = generator.manual_seed(self.batch_seed + 1)
 
             self.train_dataset = OpenFoldMultimerDataset(
                 datasets=datasets,
                 probabilities=probabilities,
@@ -1167,8 +1125,8 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
                 generator=generator,
                 _roll_at_init=True,
             )
 
-            if (self.val_data_dir is not None):
+            if self.val_data_dir is not None:
                 self.eval_dataset = dataset_gen(
                     data_dir=self.val_data_dir,
                     alignment_dir=self.val_alignment_dir,
@@ -1179,7 +1137,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
                 )
             else:
                 self.eval_dataset = None
         else:
             self.predict_dataset = dataset_gen(
                 data_dir=self.predict_data_dir,
                 alignment_dir=self.predict_alignment_dir,
@@ -1187,32 +1145,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
                 max_template_hits=self.config.predict.max_template_hits,
                 mode="predict",
             )
 
-    def _gen_dataloader(self, stage):
-        generator = None
-        if (self.batch_seed is not None):
-            generator = torch.Generator()
-            generator = generator.manual_seed(self.batch_seed)
-
-        dataset = None
-        if (stage == "train"):
-            dataset = self.train_dataset
-
-            # Filter the dataset, if necessary
-            dataset.reroll()
-        elif (stage == "eval"):
-            dataset = self.eval_dataset
-        elif (stage == "predict"):
-            dataset = self.predict_dataset
-        else:
-            raise ValueError("Invalid stage")
-
-        dl = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=1,
-            num_workers=self.config.data_module.data_loaders.num_workers,
-        )
-        return dl
-
 class DummyDataset(torch.utils.data.Dataset):
     def __init__(self, batch_path):
```
openfold/data/feature_pipeline.py

```diff
@@ -93,24 +93,11 @@ def np_example_to_features(
     with torch.no_grad():
         if is_multimer:
-            if mode == 'train':
-                features, gt_features = input_pipeline_multimer.process_tensors_from_config(
-                    tensor_dict,
-                    cfg.common,
-                    cfg[mode],
-                    is_training=True
-                )
-                return {k: v for k, v in features.items()}, gt_features
-            else:
-                features = input_pipeline_multimer.process_tensors_from_config(
-                    tensor_dict,
-                    cfg.common,
-                    cfg[mode],
-                    is_training=False
-                )
-                return {k: v for k, v in features.items()}
+            features = input_pipeline_multimer.process_tensors_from_config(
+                tensor_dict,
+                cfg.common,
+                cfg[mode],
+            )
         else:
             features = input_pipeline.process_tensors_from_config(
                 tensor_dict,
```
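A hedged sketch of the caller-side effect of this change: multimer feature processing now has a single return path, so downstream code can treat training and inference uniformly. The names come from the diff; `tensor_dict` and `cfg` are assumed to exist as in the surrounding function:

```python
# Sketch only; tensor_dict, cfg, and mode are assumed as in np_example_to_features.
features = input_pipeline_multimer.process_tensors_from_config(
    tensor_dict, cfg.common, cfg[mode],
)
# Under supervised/training configs the ground-truth tensors travel inside
# the same dict under "gt_features"; otherwise the key is simply absent.
gt = features.get("gt_features", None)
```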
openfold/data/input_pipeline_multimer.py

```diff
@@ -21,16 +21,17 @@ from openfold.data import (
     data_transforms_multimer,
 )
 
-def grountruth_transforms_fns():
+def groundtruth_transforms_fns():
     transforms = [
         data_transforms.make_atom14_masks,
         data_transforms.make_atom14_positions,
         data_transforms.atom37_to_frames,
         data_transforms.atom37_to_torsion_angles(""),
         data_transforms.make_pseudo_beta(""),
         data_transforms.get_backbone_frames,
-        data_transforms.get_chi_angles,
+        data_transforms.get_chi_angles
     ]
     return transforms
 
 def nonensembled_transform_fns():
     """Input pipeline data transformers that are not ensembled."""
@@ -112,20 +113,24 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
     return transforms
 
 def prepare_ground_truth_features(tensors):
     """Prepare ground truth features that are only needed for loss calculation during training"""
-    GROUNDTRUTH_FEATURES = ['all_atom_mask', 'all_atom_positions', 'asym_id', 'sym_id', 'entity_id']
-    gt_tensors = {k: v for k, v in tensors.items() if k in GROUNDTRUTH_FEATURES}
+    gt_features = ['all_atom_mask', 'all_atom_positions', 'asym_id', 'sym_id', 'entity_id']
+    gt_tensors = {k: v for k, v in tensors.items() if k in gt_features}
     gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
-    gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
+    gt_tensors = compose(groundtruth_transforms_fns())(gt_tensors)
     return gt_tensors
 
-def process_tensors_from_config(tensors, common_cfg, mode_cfg, is_training=False):
+def process_tensors_from_config(tensors, common_cfg, mode_cfg):
     """Based on the config, apply filters and transformations to the data."""
-    if is_training:
-        gt_tensors = prepare_ground_truth_features(tensors)
+    process_gt_feats = mode_cfg.supervised
+    gt_tensors = {}
+    if process_gt_feats:
+        gt_tensors = prepare_ground_truth_features(tensors)
 
     ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
 
     tensors['aatype'] = tensors['aatype'].to(torch.long)
@@ -152,10 +157,10 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False)
         lambda x: wrap_ensemble_fn(tensors, x),
         torch.arange(num_recycling + 1)
     )
 
-    if is_training:
-        return tensors, gt_tensors
-    else:
-        return tensors
+    if process_gt_feats:
+        tensors['gt_features'] = gt_tensors
+    return tensors
 
 @data_transforms.curry1
 def compose(x, fs):
```
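To make the selection step concrete, here is `prepare_ground_truth_features`' filtering applied to a toy tensor dict; the key list comes from the hunk above, the tensors are invented:

```python
import torch

# Toy inputs; only whitelisted keys plus a long-cast aatype survive.
tensors = {
    "all_atom_mask": torch.ones(5, 37),
    "asym_id": torch.zeros(5),
    "msa": torch.zeros(128, 5),       # dropped: not a ground-truth feature
    "aatype": torch.zeros(5),
}
gt_features = ['all_atom_mask', 'all_atom_positions', 'asym_id', 'sym_id', 'entity_id']
gt_tensors = {k: v for k, v in tensors.items() if k in gt_features}
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
print(sorted(gt_tensors))  # ['aatype', 'all_atom_mask', 'asym_id']
```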
openfold/utils/loss.py

```diff
@@ -13,35 +13,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 from functools import partial
+import logging
 import ml_collections
 import numpy as np
 import torch
 import torch.nn as nn
 from torch.distributions.bernoulli import Bernoulli
 from typing import Dict, Optional, Tuple
 
 from openfold.np import residue_constants
 from openfold.utils import feats
 from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.geometry.vector import Vec3Array, euclidean_distance
 from openfold.utils.all_atom_multimer import get_rc_tensor
 from openfold.utils.tensor_utils import (
     tree_map,
     tensor_tree_map,
     masked_mean,
     permute_final_dims,
     batched_gather,
 )
-import random
-from openfold.np import residue_constants as rc
-import logging
 import procrustes
-from openfold.utils.tensor_utils import tensor_tree_map
-import gc
 
 logger = logging.getLogger(__name__)
 
 def softmax_cross_entropy(logits, labels):
     loss = -1 * torch.sum(
         labels * torch.nn.functional.log_softmax(logits, dim=-1),
```
```diff
@@ -185,11 +180,10 @@ def backbone_loss(
     eps: float = 1e-4,
     **kwargs,
 ) -> torch.Tensor:
-    ### need to check if the traj belongs to 4*4 matrix or a tensor_7
     if traj.shape[-1] == 7:
         pred_aff = Rigid.from_tensor_7(traj)
     elif traj.shape[-1] == 4:
         pred_aff = Rigid.from_tensor_4x4(traj)
 
     pred_aff = Rigid(
@@ -256,10 +250,10 @@ def sidechain_loss(
     **kwargs,
 ) -> torch.Tensor:
     renamed_gt_frames = (
         1.0 - alt_naming_is_better[..., None, None, None]
     ) * rigidgroups_gt_frames + alt_naming_is_better[
         ..., None, None, None
     ] * rigidgroups_alt_gt_frames
 
     # Steamroll the inputs
     sidechain_frames = sidechain_frames[-1]
@@ -297,7 +291,6 @@ def fape_loss(
     batch: Dict[str, torch.Tensor],
     config: ml_collections.ConfigDict,
 ) -> torch.Tensor:
     traj = out["sm"]["frames"]
     asym_id = batch.get("asym_id")
     if asym_id is not None:
@@ -328,7 +321,7 @@ def fape_loss(
     )
 
     loss = weighted_bb_loss + config.sidechain.weight * sc_loss
 
     # Average over the batch dimension
     loss = torch.mean(loss)
```
```diff
@@ -390,7 +383,7 @@ def supervised_chi_loss(
         (true_chi_shifted - pred_angles) ** 2, dim=-1
     )
     sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
 
     # The ol' switcheroo
     sq_chi_error = sq_chi_error.permute(
         *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
@@ -502,7 +495,7 @@ def lddt_ca(
     ca_pos = residue_constants.atom_order["CA"]
     all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
     all_atom_positions = all_atom_positions[..., ca_pos, :]
     all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
 
     return lddt(
         all_atom_pred_pos,
```
```diff
@@ -532,19 +525,19 @@ def lddt_loss(
     ca_pos = residue_constants.atom_order["CA"]
     all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
     all_atom_positions = all_atom_positions[..., ca_pos, :]
     all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
 
     score = lddt(
         all_atom_pred_pos,
         all_atom_positions,
         all_atom_mask,
         cutoff=cutoff,
         eps=eps
     )
 
+    # TODO: Remove after initial pipeline testing
+    score = torch.nan_to_num(score, nan=torch.nanmean(score))
     score[score < 0] = 0
 
     score = score.detach()
     bin_index = torch.floor(score * no_bins).long()
```
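The newly added guard, written out on a toy tensor (my values; the diff applies it to per-residue lDDT scores):

```python
import torch

score = torch.tensor([0.7, float("nan"), -0.1])
# Replace NaNs with the mean of the finite entries, then clamp at zero,
# mirroring the two lines added above.
score = torch.nan_to_num(score, nan=float(torch.nanmean(score)))
score[score < 0] = 0
print(score)  # tensor([0.7000, 0.3000, 0.0000])
```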
```diff
@@ -586,7 +579,7 @@ def distogram_loss(
         device=logits.device,
     )
     boundaries = boundaries ** 2
 
     dists = torch.sum(
         (pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
         dim=-1,
@@ -707,12 +700,12 @@ def compute_tm(
     n = residue_weights.shape[-1]
     pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
     if interface and (asym_id is not None):
         if len(asym_id.shape) > 1:
             assert len(asym_id.shape) <= 2
             batch_size = asym_id.shape[0]
             pair_mask = residue_weights.new_ones((batch_size, n, n), dtype=torch.int32)
         pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
 
     predicted_tm_term *= pair_mask
 
     pair_residue_weights = pair_mask * (
@@ -727,6 +720,7 @@ def compute_tm(
     argmax = (weighted == torch.max(weighted)).nonzero()[0]
     return per_alignment[tuple(argmax)]
 
+
 def tm_loss(
     logits,
     final_affine_tensor,
@@ -741,9 +735,9 @@ def tm_loss(
     **kwargs,
 ):
     # first check whether this is a tensor_7 or tensor_4*4
     if final_affine_tensor.shape[-1] == 7:
         pred_affine = Rigid.from_tensor_7(final_affine_tensor)
     elif final_affine_tensor.shape[-1] == 4:
         pred_affine = Rigid.from_tensor_4x4(final_affine_tensor)
 
     backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
```
@@ -844,19 +838,19 @@ def between_residue_bond_loss(
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline
=
aatype
[...,
1
:]
==
residue_constants
.
resname_to_idx
[
"PRO"
]
gt_length
=
(
~
next_is_proline
)
*
residue_constants
.
between_res_bond_length_c_n
[
0
]
+
next_is_proline
*
residue_constants
.
between_res_bond_length_c_n
[
1
]
~
next_is_proline
)
*
residue_constants
.
between_res_bond_length_c_n
[
0
]
+
next_is_proline
*
residue_constants
.
between_res_bond_length_c_n
[
1
]
gt_stddev
=
(
~
next_is_proline
)
*
residue_constants
.
between_res_bond_length_stddev_c_n
[
0
]
+
next_is_proline
*
residue_constants
.
between_res_bond_length_stddev_c_n
[
1
]
~
next_is_proline
)
*
residue_constants
.
between_res_bond_length_stddev_c_n
[
0
]
+
next_is_proline
*
residue_constants
.
between_res_bond_length_stddev_c_n
[
1
]
c_n_bond_length_error
=
torch
.
sqrt
(
eps
+
(
c_n_bond_length
-
gt_length
)
**
2
)
c_n_loss_per_residue
=
torch
.
nn
.
functional
.
relu
(
...
...
@@ -1082,7 +1076,7 @@ def between_residue_clash_loss(
# Compute the hard clash mask.
# shape (N, N, 14, 14)
clash_mask
=
dists_mask
*
(
dists
<
(
dists_lower_bound
-
overlap_tolerance_hard
)
dists
<
(
dists_lower_bound
-
overlap_tolerance_hard
)
)
per_atom_num_clash
=
torch
.
sum
(
clash_mask
,
dim
=
(
-
4
,
-
2
))
+
torch
.
sum
(
clash_mask
,
dim
=
(
-
3
,
-
1
))
...
...
@@ -1098,7 +1092,7 @@ def between_residue_clash_loss(
"mean_loss"
:
mean_loss
,
# shape ()
"per_atom_loss_sum"
:
per_atom_loss_sum
,
# shape (N, 14)
"per_atom_clash_mask"
:
per_atom_clash_mask
,
# shape (N, 14)
"per_atom_num_clash"
:
per_atom_num_clash
# shape (N, 14)
"per_atom_num_clash"
:
per_atom_num_clash
# shape (N, 14)
}
...
...
@@ -1221,7 +1215,7 @@ def find_structural_violations(
atomtype_radius
=
atom14_pred_positions
.
new_tensor
(
atomtype_radius
)
#TODO: Consolidate monomer/multimer modes
#
TODO: Consolidate monomer/multimer modes
asym_id
=
batch
.
get
(
"asym_id"
)
if
asym_id
is
not
None
:
residx_atom14_to_atom37
=
get_rc_tensor
(
...
...
@@ -1372,8 +1366,8 @@ def extreme_ca_ca_distance_violations(
eps
+
torch
.
sum
((
this_ca_pos
-
next_ca_pos
)
**
2
,
dim
=-
1
)
)
violations
=
(
ca_ca_distance
-
residue_constants
.
ca_ca
)
>
max_angstrom_tolerance
ca_ca_distance
-
residue_constants
.
ca_ca
)
>
max_angstrom_tolerance
mask
=
this_ca_mask
*
next_ca_mask
*
has_no_gap_mask
mean
=
masked_mean
(
mask
,
violations
,
-
1
)
return
mean
...
...
@@ -1559,16 +1553,16 @@ def compute_renamed_ground_truth(
    alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)

    renamed_atom14_gt_positions = (
        (1.0 - alt_naming_is_better[..., None, None]) * atom14_gt_positions
        + alt_naming_is_better[..., None, None] * atom14_alt_gt_positions
    )
    renamed_atom14_gt_mask = (
        (1.0 - alt_naming_is_better[..., None]) * atom14_gt_exists
        + alt_naming_is_better[..., None] * batch["atom14_alt_gt_exists"]
    )

    return {
        "alt_naming_is_better": alt_naming_is_better,
...
...
@@ -1591,13 +1585,13 @@ def experimentally_resolved_loss(
    loss = torch.sum(errors * atom37_atom_exists, dim=-1)
    loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)).unsqueeze(-1))
    loss = torch.sum(loss, dim=-1)
    loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution))
    loss = torch.mean(loss)
    return loss
...
...
@@ -1701,20 +1695,17 @@ def compute_rmsd(
    eps: float = 1e-6,
) -> torch.Tensor:
    sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
    if atom_mask is not None:
        sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device))
    msd = torch.mean(sq_diff)
    msd = torch.nan_to_num(msd, nan=1e8)
    return torch.sqrt(msd + eps)  # prevent sqrt 0
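# Toy check (illustrative, assuming compute_rmsd as reconstructed above):
# the mask keeps padded atoms out of the mean, and nan_to_num stops a
# degenerate (empty) selection from poisoning the loss.
import torch

true_pos = torch.zeros(4, 3)
pred_pos = torch.tensor([[3., 4., 0.]]).repeat(4, 1)  # every atom is 5 A off
mask = torch.tensor([True, True, False, False])

print(compute_rmsd(true_pos, pred_pos, atom_mask=mask))  # tensor(5.0000)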
def kabsch_rotation(P, Q):
    """
    Use the procrustes package to calculate the best rotation that minimises
    the RMSD between P and Q

    The optimal rotation matrix is calculated using
    the rotational() function from the procrustes package. Details can be found here:
...
...
@@ -1728,12 +1719,12 @@ def kabsch_rotation(P, Q):
        A 3*3 rotation matrix
    """
    assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])
    rotation = procrustes.rotational(
        P.detach().cpu().float().numpy(),
        Q.detach().cpu().float().numpy(),
        translate=False,
        scale=False,
    )
    # rotation.t doesn't mean transpose; t just retrieves the matrix from the procrustes result object
    rotation = torch.tensor(rotation.t, dtype=torch.float)
    assert rotation.shape == torch.Size([3, 3])
    return rotation.to(device=P.device, dtype=P.dtype)
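# Standalone sketch (not part of the commit) of the underlying call, assuming
# the qc-procrustes package's rotational() API, whose result object exposes
# the fitted matrix as .t: it recovers the rotation that maps P onto Q.
import numpy as np
import procrustes

rng = np.random.default_rng(0)
P = rng.standard_normal((10, 3))
c, s = np.cos(np.pi / 6), np.sin(np.pi / 6)
R_true = np.array([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]])
Q = P @ R_true  # Q is P rotated 30 degrees about z

result = procrustes.rotational(P, Q, translate=False, scale=False)
assert np.allclose(result.t, R_true, atol=1e-6)  # .t is the 3x3 rotation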
...
...
@@ -1756,7 +1747,7 @@ def get_optimal_transform(
        # sometimes using fake test inputs generates NaN in the predicted atom positions
        # logging.warning("src_atom has nan or inf")
        src_atoms = torch.nan_to_num(src_atoms, nan=0.0, posinf=1.0, neginf=1.0)
    if mask is not None:
        assert mask.dtype == torch.bool
...
...
@@ -1767,21 +1758,15 @@ def get_optimal_transform(
    else:
        src_atoms = src_atoms[mask, :]
        tgt_atoms = tgt_atoms[mask, :]
    src_center = src_atoms.mean(-2, keepdim=True, dtype=src_atoms.dtype)
    tgt_center = tgt_atoms.mean(-2, keepdim=True, dtype=src_atoms.dtype)
    r = kabsch_rotation(src_atoms, tgt_atoms)
    x = tgt_center - src_center @ r
    return r, x
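# Usage sketch (illustrative): the returned pair is applied as `pos @ r + x`,
# mapping the source frame onto the target frame. The toy below uses a
# centred cloud so that the rotation is exactly identifiable.
import math
import torch

src = torch.randn(8, 3, dtype=torch.float64)
src = src - src.mean(0, keepdim=True)  # centre the cloud
c, s = math.cos(0.5), math.sin(0.5)
r_true = torch.tensor([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]], dtype=torch.float64)
shift = torch.tensor([1., 2., 3.], dtype=torch.float64)
tgt = src @ r_true + shift

r, x = get_optimal_transform(src, tgt)
assert torch.allclose(src @ r + x, tgt, atol=1e-3)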
def get_least_asym_entity_or_longest_length(batch, input_asym_id):
    """
    First check how many subunits each sequence has; if there is no tie (e.g. AABBB),
    then select one of the A chains as the anchor
...
...
@@ -1805,7 +1790,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
    for entity_id in unique_entity_ids:
        asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id])
        entity_asym_count[int(entity_id)] = len(asym_ids)
        # Calculate entity length
        entity_mask = batch["entity_id"] == entity_id
        entity_length[int(entity_id)] = entity_mask.sum().item()
...
...
@@ -1821,13 +1806,14 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
    # If still multiple entities, select a random one
    if len(least_asym_entities) > 1:
        least_asym_entities = [random.choice(least_asym_entities)]
    assert len(least_asym_entities) == 1
    least_asym_entities = least_asym_entities[0]
    anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
    anchor_pred_asym_ids = [
        id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id
    ]
    return anchor_gt_asym_id, anchor_pred_asym_ids
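# Toy illustration (hedged: the folded-away middle of the function is assumed
# to keep the entities with the fewest chain copies and break ties by entity
# length, as the surrounding code suggests).
import torch

batch = {
    "entity_id": torch.tensor([1] * 10 + [2] * 6),  # entity 1 is longer
    "asym_id": torch.tensor([1] * 5 + [2] * 5 + [3] * 3 + [4] * 3),
}
anchor_gt, anchor_pred = get_least_asym_entity_or_longest_length(
    batch, input_asym_id=batch["asym_id"]
)
# Both entities have two copies, so entity 1 wins the length tie-break and
# the anchor is one of its chains: anchor_gt in {1, 2}, anchor_pred == [1, 2].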
def greedy_align(
    batch,
    per_asym_residue_index,
...
...
@@ -1843,7 +1829,7 @@ def greedy_align(
    """
    used = [False for _ in range(len(true_ca_poses))]
    align = []
    unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
    for cur_asym_id in unique_asym_ids:
        i = int(cur_asym_id - 1)
        asym_mask = batch["asym_id"] == cur_asym_id
...
...
@@ -1857,13 +1843,13 @@ def greedy_align(
        for next_asym_id in cur_asym_list:
            j = int(next_asym_id - 1)
            if not used[j]:  # possible candidate
                cropped_pos = torch.index_select(true_ca_poses[j], 1, cur_residue_index)
                mask = torch.index_select(true_ca_masks[j], 1, cur_residue_index)
                rmsd = compute_rmsd(
                    torch.squeeze(cropped_pos, 0),
                    torch.squeeze(cur_pred_pos, 0),
                    (cur_pred_mask * mask).bool(),
                )
                if rmsd is not None and rmsd < best_rmsd:
                    best_rmsd = rmsd
                    best_idx = j
...
...
@@ -1873,15 +1859,15 @@ def greedy_align(
    return align


def pad_features(feature_tensor, nres_pad, pad_dim):
    """Pad input feature tensor"""
    pad_shape = list(feature_tensor.shape)
    pad_shape[pad_dim] = nres_pad
    padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device)
    return torch.concat((feature_tensor, padding_tensor), dim=pad_dim)
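# pad_features in action (toy shapes, illustrative only): five zero rows are
# appended along the requested dimension.
import torch

t = torch.ones(2, 3, 8)
padded = pad_features(t, nres_pad=5, pad_dim=1)
print(padded.shape)          # torch.Size([2, 8, 8])
print(padded[:, 3:].sum())   # tensor(0.) -- the padding rows are zeros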
def merge_labels(per_asym_residue_index, labels, align, original_nres):
    """
    Merge ground truth labels according to the permutation results
...
...
@@ -1898,24 +1884,25 @@ def merge_labels(per_asym_residue_index, labels, align, original_nres):
            label = labels[j][k]
            # to 1-based
            cur_residue_index = per_asym_residue_index[i + 1]
            if len(v.shape) <= 1 or "template" in k or "row_mask" in k:
                continue
            else:
                dimension_to_merge = 1
                cur_out[i] = label.index_select(dimension_to_merge, cur_residue_index)
        cur_out = [x[1] for x in sorted(cur_out.items())]
        if len(cur_out) > 0:
            new_v = torch.concat(cur_out, dim=dimension_to_merge)
            # below check whether padding is needed
            if new_v.shape[dimension_to_merge] != original_nres:
                nres_pad = original_nres - new_v.shape[dimension_to_merge]
                new_v = pad_features(new_v, nres_pad, pad_dim=dimension_to_merge)
            outs[k] = new_v
    return outs
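# Hedged sketch of the merge (the loop header over label keys is folded away
# above, so this assumes the usual structure): two equal-length chains
# swapped by `align` come back concatenated in predicted-chain order.
import torch

labels = [
    {"all_atom_positions": torch.zeros(1, 4, 37, 3)},  # ground-truth chain 0
    {"all_atom_positions": torch.ones(1, 4, 37, 3)},   # ground-truth chain 1
]
per_asym_residue_index = {1: torch.arange(4), 2: torch.arange(4)}
align = [(0, 1), (1, 0)]  # predicted chain i takes ground-truth chain j

merged = merge_labels(per_asym_residue_index, labels, align, original_nres=8)
print(merged["all_atom_positions"].shape)  # torch.Size([1, 8, 37, 3])
# The first four residues now come from ground-truth chain 1.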
class AlphaFoldLoss(nn.Module):
    """Aggregation of the various losses described in the supplement"""
    def __init__(self, config):
        super(AlphaFoldLoss, self).__init__()
        self.config = config
...
...
@@ -1974,13 +1961,13 @@ class AlphaFoldLoss(nn.Module):
            ),
        }

        if self.config.tm.enabled:
            loss_fns["tm"] = lambda: tm_loss(
                logits=out["tm_logits"],
                **{**batch, **out, **self.config.tm},
            )

        if self.config.chain_center_of_mass.enabled:
            loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss(
                all_atom_pred_pos=out["final_atom_positions"],
                **{**batch, **self.config.chain_center_of_mass},
...
...
@@ -1991,11 +1978,11 @@ class AlphaFoldLoss(nn.Module):
        for loss_name, loss_fn in loss_fns.items():
            weight = self.config[loss_name].weight
            loss = loss_fn()
            if torch.isnan(loss) or torch.isinf(loss):
                # for k,v in batch.items():
                #     if torch.any(torch.isnan(v)) or torch.any(torch.isinf(v)):
                #         logging.warning(f"{k}: is nan")
                # logging.warning(f"{loss_name}: {loss}")
                logging.warning(f"{loss_name} loss is NaN. Skipping...")
                loss = loss.new_tensor(0., requires_grad=True)
            cum_loss = cum_loss + weight * loss
...
...
@@ -2010,18 +1997,18 @@ class AlphaFoldLoss(nn.Module):
        losses["loss"] = cum_loss.detach().clone()

        if not _return_breakdown:
            return cum_loss

        return cum_loss, losses

    def forward(self, out, batch, _return_breakdown=False):
        if not _return_breakdown:
            cum_loss = self.loss(out, batch, _return_breakdown)
            return cum_loss
        else:
            cum_loss, losses = self.loss(out, batch, _return_breakdown)
            return cum_loss, losses
class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
...
@@ -2029,12 +2016,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
    Add multi-chain permutation on top of AlphaFoldLoss
    """
    def __init__(self, config):
        super(AlphaFoldMultimerLoss, self).__init__(config)
        self.config = config

    @staticmethod
    def split_ground_truth_labels(gt_features):
        """
        Splits ground truth features according to chains
...
...
@@ -2042,26 +2030,26 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
        a list of feature dictionaries with only the necessary ground truth features
        required to finish multi-chain permutation
        """
        unique_asym_ids, asym_id_counts = torch.unique(
            gt_features["asym_id"], sorted=True, return_counts=True
        )
        n_res = gt_features["asym_id"].shape[-1]

        def split_dim(shape):
            # find the dimension that spans the residues, i.e. the one of size n_res
            return next(iter(i for i, size in enumerate(shape) if size == n_res), None)

        labels = list(
            map(
                dict,
                zip(
                    *[
                        [
                            (k, v)
                            for v in torch.split(
                                v_all, asym_id_counts.tolist(), dim=split_dim(v_all.shape)
                            )
                        ]
                        for k, v_all in gt_features.items()
                        if n_res in v_all.shape
                    ]
                ),
            )
        )
        return labels
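        # Toy split (illustrative, assuming the method as reconstructed above):
        # every tensor with a residue-sized dimension is cut per chain along
        # that dimension, in sorted asym_id order.
        #
        #     gt = {
        #         "asym_id": torch.tensor([1, 1, 1, 2, 2]),
        #         "all_atom_positions": torch.randn(5, 37, 3),
        #         "resolution": torch.tensor([2.5]),  # no residue-sized dim -> left out
        #     }
        #     labels = AlphaFoldMultimerLoss.split_ground_truth_labels(gt)
        #     len(labels)                            # 2
        #     labels[0]["all_atom_positions"].shape  # torch.Size([3, 37, 3])
        #     labels[1]["all_atom_positions"].shape  # torch.Size([2, 37, 3])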
    @staticmethod
    def get_per_asym_residue_index(features):
        unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0]
        per_asym_residue_index = {}
        for cur_asym_id in unique_asym_ids:
            asym_mask = (features["asym_id"] == cur_asym_id).bool()
            per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(
                features["residue_index"], asym_mask
            )
        return per_asym_residue_index

    @staticmethod
...
...
@@ -2083,10 +2071,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
            cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
            entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
        return entity_2_asym_list

    @staticmethod
    def calculate_input_mask(
        true_ca_masks,
        anchor_gt_idx,
        anchor_gt_residue,
        asym_mask,
        pred_ca_mask,
    ):
        """
        Calculate an input mask for downstream optimal transformation computation
...
...
@@ -2099,37 +2087,37 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
        Returns:
            input_mask (Tensor): A boolean mask
        """
        pred_ca_mask = torch.squeeze(pred_ca_mask, 0)
        asym_mask = torch.squeeze(asym_mask, 0)
        anchor_pred_mask = pred_ca_mask[asym_mask]
        anchor_true_mask = torch.index_select(
            true_ca_masks[anchor_gt_idx], 1, anchor_gt_residue
        )
        input_mask = (anchor_true_mask * anchor_pred_mask).bool()
        return input_mask

    @staticmethod
    def calculate_optimal_transform(
        true_ca_poses,
        anchor_gt_idx,
        anchor_gt_residue,
        true_ca_masks,
        pred_ca_mask,
        asym_mask,
        pred_ca_pos,
    ):
        input_mask = AlphaFoldMultimerLoss.calculate_input_mask(
            true_ca_masks,
            anchor_gt_idx,
            anchor_gt_residue,
            asym_mask,
            pred_ca_mask,
        )
        input_mask = torch.squeeze(input_mask, 0)
        pred_ca_pos = torch.squeeze(pred_ca_pos, 0)
        asym_mask = torch.squeeze(asym_mask, 0)
        anchor_true_pos = torch.index_select(
            true_ca_poses[anchor_gt_idx], 1, anchor_gt_residue
        )
        anchor_pred_pos = pred_ca_pos[asym_mask]
        r, x = get_optimal_transform(
            anchor_pred_pos,
            torch.squeeze(anchor_true_pos, 0),
            mask=input_mask,
        )
        return r, x

    @staticmethod
    def multi_chain_perm_align(out, features, ground_truth):
        """
        A method that first permutes the ground truth chains
        before calculating the loss.
...
...
@@ -2137,71 +2125,68 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
        Details are described in Section 7.3 of the Supplementary of the AlphaFold-Multimer paper:
        https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
        """
        unique_asym_ids = set(torch.unique(features['asym_id']).tolist())
        unique_asym_ids.discard(0)  # Remove padding asym_id
        is_monomer = len(unique_asym_ids) == 1
        per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(features)
        if is_monomer:
            best_align = list(enumerate(range(len(per_asym_residue_index))))
            return best_align, per_asym_residue_index

        best_rmsd = float('inf')
        best_align = None
        # First select anchors from predicted structures and ground truths
        anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(
            ground_truth, features['asym_id']
        )
        entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth)
        labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth)
        assert isinstance(labels, list)
        anchor_gt_idx = int(anchor_gt_asym) - 1

        # Then calculate optimal transform by aligning anchors
        ca_idx = rc.atom_order["CA"]
        pred_ca_pos = out["final_atom_positions"][..., ca_idx, :]  # [bsz, nres, 3]
        pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype)  # [bsz, nres]
        true_ca_poses = [
            l["all_atom_positions"][..., ca_idx, :] for l in labels
        ]  # list([nres, 3])
        true_ca_masks = [
            l["all_atom_mask"][..., ca_idx].long() for l in labels
        ]  # list([nres,])

        for candidate_pred_anchor in anchor_pred_asym_ids:
            asym_mask = (features["asym_id"] == candidate_pred_anchor).bool()
            anchor_gt_residue = per_asym_residue_index[candidate_pred_anchor.item()]
            r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(
                true_ca_poses,
                anchor_gt_idx,
                anchor_gt_residue,
                true_ca_masks,
                pred_ca_mask,
                asym_mask,
                pred_ca_pos,
            )
            # apply transforms
            aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses]
            align = greedy_align(
                features,
                per_asym_residue_index,
                entity_2_asym_list,
                pred_ca_pos,
                pred_ca_mask,
                aligned_true_ca_poses,
                true_ca_masks,
            )
            merged_labels = merge_labels(
                per_asym_residue_index,
                labels,
                align,
                original_nres=features['aatype'].shape[-1],
            )
            rmsd = compute_rmsd(
                true_atom_pos=merged_labels['all_atom_positions'][..., ca_idx, :].to(r.dtype) @ r + x,
                pred_atom_pos=pred_ca_pos,
                atom_mask=(pred_ca_mask * merged_labels['all_atom_mask'][..., ca_idx].long()).bool(),
            )
            if rmsd < best_rmsd:
                best_rmsd = rmsd
                best_align = align

        return best_align, per_asym_residue_index
    def forward(self, out, batch, _return_breakdown=False):
        """
        Overwrite the AlphaFoldLoss forward function so that
        it first computes the multi-chain permutation
...
...
@@ -2210,32 +2195,24 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
            out: the output of model.forward()
            batch: a pair of input features and its corresponding ground truth structure
        """
        ground_truth = batch.pop('gt_features')
        labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth)
        # Then permute ground truth chains before calculating the loss
        align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(
            out=out,
            features=batch,
            ground_truth=ground_truth,
        )
        # reorder ground truth labels according to permutation results
        labels = merge_labels(
            per_asym_residue_index,
            labels,
            align,
            original_nres=batch['aatype'].shape[-1],
        )
        batch.update(labels)

        if not _return_breakdown:
            cum_loss = self.loss(out, batch, _return_breakdown)
            print(f"cum_loss: {cum_loss}")
            return cum_loss
        else:
            cum_loss, losses = self.loss(out, batch, _return_breakdown)
            print(f"cum_loss: {cum_loss} losses: {losses}")
            return cum_loss, losses
scripts/generate_mmcif_cache.py
View file @
0cf1541c
...
...
@@ -13,7 +13,7 @@ from tqdm import tqdm
from openfold.data.mmcif_parsing import parse


def parse_file(f, args, chain_cluster_size_dict=None):
    with open(os.path.join(args.mmcif_dir, f), "r") as fp:
        mmcif_string = fp.read()
    file_id = os.path.splitext(f)[0]
...
...
@@ -28,6 +28,18 @@ def parse_file(f, args):
    local_data["release_date"] = mmcif.header["release_date"]

    chain_ids, seqs = list(zip(*mmcif.chain_to_seqres.items()))
    if chain_cluster_size_dict is not None:
        cluster_sizes = []
        for chain_id in chain_ids:
            full_name = "_".join([file_id, chain_id])
            cluster_size = chain_cluster_size_dict.get(full_name.upper(), -1)
            cluster_sizes.append(cluster_size)
        local_data["cluster_sizes"] = cluster_sizes
    local_data["chain_ids"] = chain_ids
    local_data["seqs"] = seqs
    local_data["no_chains"] = len(chain_ids)
...
...
@@ -38,8 +50,21 @@ def parse_file(f, args):
def main(args):
    chain_cluster_size_dict = None
    if args.cluster_file is not None:
        chain_cluster_size_dict = {}
        with open(args.cluster_file, "r") as fp:
            clusters = [l.strip() for l in fp.readlines()]
        for cluster in clusters:
            chain_ids = cluster.split()
            cluster_len = len(chain_ids)
            for chain_id in chain_ids:
                chain_id = chain_id.upper()
                chain_cluster_size_dict[chain_id] = cluster_len

    files = [f for f in os.listdir(args.mmcif_dir) if ".cif" in f]
    fn = partial(parse_file, args=args, chain_cluster_size_dict=chain_cluster_size_dict)
    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
...
...
@@ -63,6 +88,15 @@ if __name__ == "__main__":
        "--no_workers", type=int, default=4,
        help="Number of workers to use for parsing"
    )
    parser.add_argument(
        "--cluster_file", type=str, default=None,
        help=(
            "Path to a cluster file (e.g. PDB40), one cluster "
            "({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
            "Chains not in this cluster file will NOT be filtered by cluster "
            "size."
        )
    )
    parser.add_argument(
        "--chunksize", type=int, default=10,
        help="How many files should be distributed to each worker at a time"
...
...
train_openfold.py
View file @
0cf1541c
import argparse
import logging
import os
import random
import sys
import time

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch

from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
...
...
@@ -53,7 +46,12 @@ class OpenFoldWrapper(pl.LightningModule):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        if self.config.globals.is_multimer:
            self.loss = AlphaFoldMultimerLoss(config.loss)
        else:
            self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(model=self.model, decay=config.ema.decay)
...
...
@@ -256,72 +254,6 @@ class OpenFoldWrapper(pl.LightningModule):
        )


class OpenFoldMultimerWrapper(OpenFoldWrapper):
    def __init__(self, config):
        super(OpenFoldMultimerWrapper, self).__init__(config)
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldMultimerLoss(config.loss)
        self.ema = ExponentialMovingAverage(model=self.model, decay=config.ema.decay)
        self.cached_weights = None
        self.last_lr_step = -1

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        features, gt_features = batch
        if self.ema.device != features["aatype"].device:
            self.ema.to(features["aatype"].device)

        # Run the model
        outputs = self(features)

        # Remove the recycling dimension
        features = tensor_tree_map(lambda t: t[..., -1], features)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, (features, gt_features), _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, features, outputs)

        return loss

    def validation_step(self, batch, batch_idx):
        features, gt_features = batch
        # At the start of validation, load the EMA weights
        if self.cached_weights is None:
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        # Run the model
        outputs = self(features)

        # Compute loss and other metrics
        features["use_clamped_fape"] = 0.
        _, loss_breakdown = self.loss(
            outputs, (features, gt_features), _return_breakdown=True
        )

        self._log(loss_breakdown, features, outputs, train=False)

    def validation_epoch_end(self, _):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None


def main(args):
    if args.seed is not None:
        seed_everything(args.seed)
...
...
@@ -331,10 +263,8 @@ def main(args):
        train=True,
        low_prec=(str(args.precision) == "16"),
    )
    model_module = OpenFoldWrapper(config)
    if args.resume_from_ckpt:
        if os.path.isdir(args.resume_from_ckpt):
            last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
...
...
@@ -359,7 +289,6 @@ def main(args):
    if args.script_modules:
        script_preset_(model_module)

    # data_module = DummyDataLoader("new_batch.pickle")
    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
            config=config.data,
...
...