Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9d4c9357
Commit
9d4c9357
authored
Oct 29, 2021
by
Gustaf Ahdritz
Browse files
Speed up template featurizer
parent
e69b2a11
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
131 additions
and
40 deletions
+131
-40
openfold/config.py
openfold/config.py
+2
-1
openfold/data/data_modules.py
openfold/data/data_modules.py
+28
-9
openfold/data/templates.py
openfold/data/templates.py
+74
-19
openfold/model/template.py
openfold/model/template.py
+1
-1
openfold/utils/loss.py
openfold/utils/loss.py
+25
-3
run_pretrained_openfold.py
run_pretrained_openfold.py
+1
-4
scripts/utils.py
scripts/utils.py
+0
-3
No files found.
openfold/config.py
View file @
9d4c9357
...
@@ -211,8 +211,9 @@ config = mlc.ConfigDict(
...
@@ -211,8 +211,9 @@ config = mlc.ConfigDict(
"subsample_templates"
:
True
,
"subsample_templates"
:
True
,
"masked_msa_replace_fraction"
:
0.15
,
"masked_msa_replace_fraction"
:
0.15
,
"max_msa_clusters"
:
128
,
"max_msa_clusters"
:
128
,
"max_template_hits"
:
20
,
"max_template_hits"
:
4
,
"max_templates"
:
4
,
"max_templates"
:
4
,
"shuffle_top_k_prefiltered"
:
20
,
"crop"
:
True
,
"crop"
:
True
,
"crop_size"
:
256
,
"crop_size"
:
256
,
"supervised"
:
True
,
"supervised"
:
True
,
...
...
openfold/data/data_modules.py
View file @
9d4c9357
...
@@ -32,9 +32,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -32,9 +32,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mapping_path
:
Optional
[
str
]
=
None
,
mapping_path
:
Optional
[
str
]
=
None
,
max_template_hits
:
int
=
4
,
max_template_hits
:
int
=
4
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
use_small_bfd
:
bool
=
True
,
shuffle_top_k_prefiltered
:
Optional
[
int
]
=
None
,
output_raw
:
bool
=
False
,
mode
:
str
=
"train"
,
mode
:
str
=
"train"
,
_output_raw
:
bool
=
False
,
):
):
"""
"""
Args:
Args:
...
@@ -48,21 +48,38 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -48,21 +48,38 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files.
files.
template_mmcif_dir:
Path to a directory containing template mmCIF files.
config:
config:
A dataset config object. See openfold.config
A dataset config object. See openfold.config
kalign_binary_path:
Path to kalign binary.
mapping_path:
mapping_path:
A json file containing a mapping from consecutive numerical
A json file containing a mapping from consecutive numerical
ids to sample names (matching the directories in data_dir).
ids to sample names (matching the directories in data_dir).
Samples not in this mapping are ignored. Can be used to
Samples not in this mapping are ignored. Can be used to
implement the various training-time filters described in
implement the various training-time filters described in
the AlphaFold supplement
the AlphaFold supplement.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
from this total quantity.
template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache.
shuffle_top_k_prefiltered:
Whether to uniformly shuffle the top k template hits before
parsing max_template_hits of them. Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
mode:
"train", "val", or "predict"
"""
"""
super
(
OpenFoldSingleDataset
,
self
).
__init__
()
super
(
OpenFoldSingleDataset
,
self
).
__init__
()
self
.
data_dir
=
data_dir
self
.
data_dir
=
data_dir
self
.
alignment_dir
=
alignment_dir
self
.
alignment_dir
=
alignment_dir
self
.
config
=
config
self
.
config
=
config
self
.
output_raw
=
output_raw
self
.
mode
=
mode
self
.
mode
=
mode
self
.
_output_raw
=
_output_raw
valid_modes
=
[
"train"
,
"val"
,
"predict"
]
valid_modes
=
[
"train"
,
"val"
,
"predict"
]
if
(
mode
not
in
valid_modes
):
if
(
mode
not
in
valid_modes
):
...
@@ -90,13 +107,14 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -90,13 +107,14 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
kalign_binary_path
=
kalign_binary_path
,
kalign_binary_path
=
kalign_binary_path
,
release_dates_path
=
template_release_dates_cache_path
,
release_dates_path
=
template_release_dates_cache_path
,
obsolete_pdbs_path
=
None
,
obsolete_pdbs_path
=
None
,
_shuffle_top_k_prefiltered
=
shuffle_top_k_prefiltered
,
)
)
self
.
data_pipeline
=
data_pipeline
.
DataPipeline
(
self
.
data_pipeline
=
data_pipeline
.
DataPipeline
(
template_featurizer
=
template_featurizer
,
template_featurizer
=
template_featurizer
,
)
)
if
(
not
self
.
output_raw
):
if
(
not
self
.
_
output_raw
):
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
def
_parse_mmcif
(
self
,
path
,
file_id
,
chain_id
,
alignment_dir
):
def
_parse_mmcif
(
self
,
path
,
file_id
,
chain_id
,
alignment_dir
):
...
@@ -153,7 +171,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -153,7 +171,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir
=
alignment_dir
,
alignment_dir
=
alignment_dir
,
)
)
if
(
self
.
output_raw
):
if
(
self
.
_
output_raw
):
return
data
return
data
feats
=
self
.
feature_pipeline
.
process_features
(
feats
=
self
.
feature_pipeline
.
process_features
(
...
@@ -357,7 +375,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -357,7 +375,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
kalign_binary_path
=
self
.
kalign_binary_path
,
kalign_binary_path
=
self
.
kalign_binary_path
,
template_release_dates_cache_path
=
template_release_dates_cache_path
=
self
.
template_release_dates_cache_path
,
self
.
template_release_dates_cache_path
,
use_small_bfd
=
self
.
config
.
data_module
.
use_small_bfd
,
)
)
if
(
self
.
training_mode
):
if
(
self
.
training_mode
):
...
@@ -366,8 +383,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -366,8 +383,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
alignment_dir
=
self
.
train_alignment_dir
,
alignment_dir
=
self
.
train_alignment_dir
,
mapping_path
=
self
.
train_mapping_path
,
mapping_path
=
self
.
train_mapping_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
output_raw
=
True
,
shuffle_top_k_prefiltered
=
self
.
config
.
train
.
shuffle_top_k_prefiltered
,
mode
=
"train"
,
mode
=
"train"
,
_output_raw
=
True
,
)
)
if
(
self
.
distillation_data_dir
is
not
None
):
if
(
self
.
distillation_data_dir
is
not
None
):
...
@@ -376,8 +395,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -376,8 +395,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
alignment_dir
=
self
.
distillation_alignment_dir
,
alignment_dir
=
self
.
distillation_alignment_dir
,
mapping_path
=
self
.
distillation_mapping_path
,
mapping_path
=
self
.
distillation_mapping_path
,
max_template_hits
=
self
.
train
.
max_template_hits
,
max_template_hits
=
self
.
train
.
max_template_hits
,
output_raw
=
True
,
mode
=
"train"
,
mode
=
"train"
,
_output_raw
=
True
,
)
)
d_prob
=
self
.
config
.
train
.
distillation_prob
d_prob
=
self
.
config
.
train
.
distillation_prob
...
...
openfold/data/templates.py
View file @
9d4c9357
...
@@ -123,14 +123,15 @@ def _is_after_cutoff(
...
@@ -123,14 +123,15 @@ def _is_after_cutoff(
Returns:
Returns:
True if the template release date is after the cutoff, False otherwise.
True if the template release date is after the cutoff, False otherwise.
"""
"""
pdb_id_upper
=
pdb_id
.
upper
()
if
release_date_cutoff
is
None
:
if
release_date_cutoff
is
None
:
raise
ValueError
(
"The release_date_cutoff must not be None."
)
raise
ValueError
(
"The release_date_cutoff must not be None."
)
if
pdb_id
in
release_dates
:
if
pdb_id
_upper
in
release_dates
:
return
release_dates
[
pdb_id
]
>
release_date_cutoff
return
release_dates
[
pdb_id
_upper
]
>
release_date_cutoff
else
:
else
:
# Since this is just a quick prefilter to reduce the number of mmCIF files
# Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here.
# we need to parse, we don't have to worry about returning True here.
logging
.
info
(
logging
.
warning
(
"Template structure not in release dates dict: %s"
,
pdb_id
"Template structure not in release dates dict: %s"
,
pdb_id
)
)
return
False
return
False
...
@@ -183,7 +184,7 @@ def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
...
@@ -183,7 +184,7 @@ def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
data
=
json
.
load
(
fp
)
data
=
json
.
load
(
fp
)
return
{
return
{
pdb
:
to_date
(
v
)
pdb
.
upper
()
:
to_date
(
v
)
for
pdb
,
d
in
data
.
items
()
for
pdb
,
d
in
data
.
items
()
for
k
,
v
in
d
.
items
()
for
k
,
v
in
d
.
items
()
if
k
==
"release_date"
if
k
==
"release_date"
...
@@ -239,8 +240,9 @@ def _assess_hhsearch_hit(
...
@@ -239,8 +240,9 @@ def _assess_hhsearch_hit(
)
)
if
_is_after_cutoff
(
hit_pdb_code
,
release_dates
,
release_date_cutoff
):
if
_is_after_cutoff
(
hit_pdb_code
,
release_dates
,
release_date_cutoff
):
date
=
release_dates
[
hit_pdb_code
.
upper
()]
raise
DateError
(
raise
DateError
(
f
"Date (
{
release_dates
[
hit_pdb_code
]
}
) > max template date "
f
"Date (
{
date
}
) > max template date "
f
"(
{
release_date_cutoff
}
)."
f
"(
{
release_date_cutoff
}
)."
)
)
...
@@ -735,6 +737,12 @@ def _build_query_to_hit_index_mapping(
...
@@ -735,6 +737,12 @@ def _build_query_to_hit_index_mapping(
return
mapping
return
mapping
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
PrefilterResult
:
valid
:
bool
error
:
Optional
[
str
]
warning
:
Optional
[
str
]
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
SingleHitResult
:
class
SingleHitResult
:
features
:
Optional
[
Mapping
[
str
,
Any
]]
features
:
Optional
[
Mapping
[
str
,
Any
]]
...
@@ -742,18 +750,15 @@ class SingleHitResult:
...
@@ -742,18 +750,15 @@ class SingleHitResult:
warning
:
Optional
[
str
]
warning
:
Optional
[
str
]
def
_pr
ocess_single
_hit
(
def
_pr
efilter
_hit
(
query_sequence
:
str
,
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
query_pdb_code
:
Optional
[
str
],
hit
:
parsers
.
TemplateHit
,
hit
:
parsers
.
TemplateHit
,
mmcif_dir
:
str
,
max_template_date
:
datetime
.
datetime
,
max_template_date
:
datetime
.
datetime
,
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
obsolete_pdbs
:
Mapping
[
str
,
str
],
obsolete_pdbs
:
Mapping
[
str
,
str
],
kalign_binary_path
:
str
,
strict_error_check
:
bool
=
False
,
strict_error_check
:
bool
=
False
,
)
->
SingleHitResult
:
):
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code
,
hit_chain_id
=
_get_pdb_id_and_chain
(
hit
)
hit_pdb_code
,
hit_chain_id
=
_get_pdb_id_and_chain
(
hit
)
...
@@ -761,7 +766,8 @@ def _process_single_hit(
...
@@ -761,7 +766,8 @@ def _process_single_hit(
if
hit_pdb_code
in
obsolete_pdbs
:
if
hit_pdb_code
in
obsolete_pdbs
:
hit_pdb_code
=
obsolete_pdbs
[
hit_pdb_code
]
hit_pdb_code
=
obsolete_pdbs
[
hit_pdb_code
]
# Pass hit_pdb_code since it might have changed due to the pdb being obsolete.
# Pass hit_pdb_code since it might have changed due to the pdb being
# obsolete.
try
:
try
:
_assess_hhsearch_hit
(
_assess_hhsearch_hit
(
hit
=
hit
,
hit
=
hit
,
...
@@ -772,15 +778,32 @@ def _process_single_hit(
...
@@ -772,15 +778,32 @@ def _process_single_hit(
release_date_cutoff
=
max_template_date
,
release_date_cutoff
=
max_template_date
,
)
)
except
PrefilterError
as
e
:
except
PrefilterError
as
e
:
msg
=
f
"hit
{
hit_pdb_code
}
_
{
hit_chain_id
}
did not pass prefilter:
{
str
(
e
)
}
"
hit_name
=
f
"
{
hit_pdb_code
}
_
{
hit_chain_id
}
"
msg
=
f
"hit
{
hit_name
}
did not pass prefilter:
{
str
(
e
)
}
"
logging
.
info
(
"%s: %s"
,
query_pdb_code
,
msg
)
logging
.
info
(
"%s: %s"
,
query_pdb_code
,
msg
)
if
strict_error_check
and
isinstance
(
if
strict_error_check
and
isinstance
(
e
,
(
DateError
,
PdbIdError
,
DuplicateError
)
e
,
(
DateError
,
PdbIdError
,
DuplicateError
)
):
):
# In strict mode we treat some prefilter cases as errors.
# In strict mode we treat some prefilter cases as errors.
return
SingleHitResult
(
features
=
Non
e
,
error
=
msg
,
warning
=
None
)
return
PrefilterResult
(
valid
=
Fals
e
,
error
=
msg
,
warning
=
None
)
return
SingleHitResult
(
features
=
None
,
error
=
None
,
warning
=
None
)
return
PrefilterResult
(
valid
=
False
,
error
=
None
,
warning
=
None
)
return
PrefilterResult
(
valid
=
True
,
error
=
None
,
warning
=
None
)
def
_process_single_hit
(
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
hit
:
parsers
.
TemplateHit
,
mmcif_dir
:
str
,
max_template_date
:
datetime
.
datetime
,
kalign_binary_path
:
str
,
strict_error_check
:
bool
=
False
,
)
->
SingleHitResult
:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code
,
hit_chain_id
=
_get_pdb_id_and_chain
(
hit
)
mapping
=
_build_query_to_hit_index_mapping
(
mapping
=
_build_query_to_hit_index_mapping
(
hit
.
query
,
hit
.
query
,
...
@@ -901,6 +924,7 @@ class TemplateHitFeaturizer:
...
@@ -901,6 +924,7 @@ class TemplateHitFeaturizer:
release_dates_path
:
Optional
[
str
],
release_dates_path
:
Optional
[
str
],
obsolete_pdbs_path
:
Optional
[
str
],
obsolete_pdbs_path
:
Optional
[
str
],
strict_error_check
:
bool
=
False
,
strict_error_check
:
bool
=
False
,
_shuffle_top_k_prefiltered
:
Optional
[
int
]
=
None
,
):
):
"""Initializes the Template Search.
"""Initializes the Template Search.
...
@@ -938,7 +962,7 @@ class TemplateHitFeaturizer:
...
@@ -938,7 +962,7 @@ class TemplateHitFeaturizer:
raise
ValueError
(
raise
ValueError
(
"max_template_date must be set and have format YYYY-MM-DD."
"max_template_date must be set and have format YYYY-MM-DD."
)
)
self
.
_
max_hits
=
max_hits
self
.
max_hits
=
max_hits
self
.
_kalign_binary_path
=
kalign_binary_path
self
.
_kalign_binary_path
=
kalign_binary_path
self
.
_strict_error_check
=
strict_error_check
self
.
_strict_error_check
=
strict_error_check
...
@@ -958,6 +982,8 @@ class TemplateHitFeaturizer:
...
@@ -958,6 +982,8 @@ class TemplateHitFeaturizer:
else
:
else
:
self
.
_obsolete_pdbs
=
{}
self
.
_obsolete_pdbs
=
{}
self
.
_shuffle_top_k_prefiltered
=
_shuffle_top_k_prefiltered
def
get_templates
(
def
get_templates
(
self
,
self
,
query_sequence
:
str
,
query_sequence
:
str
,
...
@@ -986,19 +1012,48 @@ class TemplateHitFeaturizer:
...
@@ -986,19 +1012,48 @@ class TemplateHitFeaturizer:
errors
=
[]
errors
=
[]
warnings
=
[]
warnings
=
[]
for
hit
in
sorted
(
hits
,
key
=
lambda
x
:
x
.
sum_probs
,
reverse
=
True
):
filtered
=
[]
for
hit
in
hits
:
prefilter_result
=
_prefilter_hit
(
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
hit
=
hit
,
max_template_date
=
template_cutoff_date
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
)
if
prefilter_result
.
error
:
errors
.
append
(
prefilter_result
.
error
)
if
prefilter_result
.
warning
:
warnings
.
append
(
prefilter_result
.
warning
)
if
prefilter_result
.
valid
:
filtered
.
append
(
hit
)
filtered
=
list
(
sorted
(
filtered
,
key
=
lambda
x
:
x
.
sum_probs
,
reverse
=
True
)
)
idx
=
list
(
range
(
len
(
filtered
)))
if
(
self
.
_shuffle_top_k_prefiltered
):
stk
=
self
.
_shuffle_top_k_prefiltered
idx
[:
stk
]
=
np
.
random
.
permutation
(
idx
[:
stk
])
for
i
in
idx
:
# We got all the templates we wanted, stop processing hits.
# We got all the templates we wanted, stop processing hits.
if
num_hits
>=
self
.
_
max_hits
:
if
num_hits
>=
self
.
max_hits
:
break
break
hit
=
filtered
[
i
]
result
=
_process_single_hit
(
result
=
_process_single_hit
(
query_sequence
=
query_sequence
,
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
query_pdb_code
=
query_pdb_code
,
hit
=
hit
,
hit
=
hit
,
mmcif_dir
=
self
.
_mmcif_dir
,
mmcif_dir
=
self
.
_mmcif_dir
,
max_template_date
=
template_cutoff_date
,
max_template_date
=
template_cutoff_date
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
strict_error_check
=
self
.
_strict_error_check
,
kalign_binary_path
=
self
.
_kalign_binary_path
,
kalign_binary_path
=
self
.
_kalign_binary_path
,
)
)
...
...
openfold/model/template.py
View file @
9d4c9357
...
@@ -259,7 +259,7 @@ class TemplatePairStack(nn.Module):
...
@@ -259,7 +259,7 @@ class TemplatePairStack(nn.Module):
self
.
blocks_per_ckpt
=
blocks_per_ckpt
self
.
blocks_per_ckpt
=
blocks_per_ckpt
self
.
blocks
=
nn
.
ModuleList
()
self
.
blocks
=
nn
.
ModuleList
()
for
i
in
range
(
no_blocks
):
for
_
in
range
(
no_blocks
):
block
=
TemplatePairStackBlock
(
block
=
TemplatePairStackBlock
(
c_t
=
c_t
,
c_t
=
c_t
,
c_hidden_tri_att
=
c_hidden_tri_att
,
c_hidden_tri_att
=
c_hidden_tri_att
,
...
...
openfold/utils/loss.py
View file @
9d4c9357
...
@@ -90,6 +90,7 @@ def compute_fape(
...
@@ -90,6 +90,7 @@ def compute_fape(
local_target_pos
=
target_frames
.
invert
()[...,
None
].
apply
(
local_target_pos
=
target_frames
.
invert
()[...,
None
].
apply
(
target_positions
[...,
None
,
:,
:],
target_positions
[...,
None
,
:,
:],
)
)
error_dist
=
torch
.
sqrt
(
error_dist
=
torch
.
sqrt
(
torch
.
sum
((
local_pred_pos
-
local_target_pos
)
**
2
,
dim
=-
1
)
+
eps
torch
.
sum
((
local_pred_pos
-
local_target_pos
)
**
2
,
dim
=-
1
)
+
eps
)
)
...
@@ -161,7 +162,9 @@ def backbone_loss(
...
@@ -161,7 +162,9 @@ def backbone_loss(
1
-
use_clamped_fape
1
-
use_clamped_fape
)
)
# Average over the batch dimension
fape_loss
=
torch
.
mean
(
fape_loss
)
fape_loss
=
torch
.
mean
(
fape_loss
)
return
fape_loss
return
fape_loss
...
@@ -231,7 +234,12 @@ def fape_loss(
...
@@ -231,7 +234,12 @@ def fape_loss(
**
{
**
batch
,
**
config
.
sidechain
},
**
{
**
batch
,
**
config
.
sidechain
},
)
)
return
config
.
backbone
.
weight
*
bb_loss
+
config
.
sidechain
.
weight
*
sc_loss
loss
=
config
.
backbone
.
weight
*
bb_loss
+
config
.
sidechain
.
weight
*
sc_loss
# Average over the batch dimension
loss
=
torch
.
mean
(
loss
)
return
loss
def
supervised_chi_loss
(
def
supervised_chi_loss
(
...
@@ -290,6 +298,9 @@ def supervised_chi_loss(
...
@@ -290,6 +298,9 @@ def supervised_chi_loss(
loss
=
loss
+
angle_norm_weight
*
angle_norm_loss
loss
=
loss
+
angle_norm_weight
*
angle_norm_loss
# Average over the batch dimension
loss
=
torch
.
mean
(
loss
)
return
loss
return
loss
...
@@ -388,6 +399,9 @@ def lddt_loss(
...
@@ -388,6 +399,9 @@ def lddt_loss(
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
)
)
# Average over the batch dimension
loss
=
torch
.
mean
(
loss
)
return
loss
return
loss
...
@@ -433,6 +447,9 @@ def distogram_loss(
...
@@ -433,6 +447,9 @@ def distogram_loss(
mean
=
mean
/
denom
[...,
None
]
mean
=
mean
/
denom
[...,
None
]
mean
=
torch
.
sum
(
mean
,
dim
=-
1
)
mean
=
torch
.
sum
(
mean
,
dim
=-
1
)
# Average over the batch dimensions
mean
=
torch
.
mean
(
mean
)
return
mean
return
mean
...
@@ -580,6 +597,9 @@ def tm_loss(
...
@@ -580,6 +597,9 @@ def tm_loss(
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
)
)
# Average over the loss dimension
loss
=
torch
.
mean
(
loss
)
return
loss
return
loss
...
@@ -1351,6 +1371,8 @@ def experimentally_resolved_loss(
...
@@ -1351,6 +1371,8 @@ def experimentally_resolved_loss(
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
)
)
loss
=
torch
.
mean
(
loss
)
return
loss
return
loss
...
@@ -1469,8 +1491,8 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1469,8 +1491,8 @@ class AlphaFoldLoss(nn.Module):
}
}
cum_loss
=
0
cum_loss
=
0
for
k
,
loss_fn
in
loss_fns
.
items
():
for
loss_name
,
loss_fn
in
loss_fns
.
items
():
weight
=
self
.
config
[
k
].
weight
weight
=
self
.
config
[
loss_name
].
weight
if
weight
:
if
weight
:
loss
=
loss_fn
()
loss
=
loss_fn
()
cum_loss
=
cum_loss
+
weight
*
loss
cum_loss
=
cum_loss
+
weight
*
loss
...
...
run_pretrained_openfold.py
View file @
9d4c9357
...
@@ -50,12 +50,10 @@ def main(args):
...
@@ -50,12 +50,10 @@ def main(args):
model
=
model
.
to
(
args
.
model_device
)
model
=
model
.
to
(
args
.
model_device
)
# FEATURE COLLECTION AND PROCESSING
# FEATURE COLLECTION AND PROCESSING
num_ensemble
=
1
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
mmcif_dir
=
args
.
template_mmcif_dir
,
mmcif_dir
=
args
.
template_mmcif_dir
,
max_template_date
=
args
.
max_template_date
,
max_template_date
=
args
.
max_template_date
,
max_hits
=
args
.
max_template
_hit
s
,
max_hits
=
config
.
data
.
predict
.
max_templates
,
kalign_binary_path
=
args
.
kalign_binary_path
,
kalign_binary_path
=
args
.
kalign_binary_path
,
release_dates_path
=
None
,
release_dates_path
=
None
,
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
obsolete_pdbs_path
=
args
.
obsolete_pdbs_path
...
@@ -85,7 +83,6 @@ def main(args):
...
@@ -85,7 +83,6 @@ def main(args):
random_seed
=
args
.
data_random_seed
random_seed
=
args
.
data_random_seed
if
random_seed
is
None
:
if
random_seed
is
None
:
random_seed
=
random
.
randrange
(
sys
.
maxsize
)
random_seed
=
random
.
randrange
(
sys
.
maxsize
)
config
.
data
.
predict
.
num_ensemble
=
num_ensemble
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
feature_processor
=
feature_pipeline
.
FeaturePipeline
(
config
.
data
)
if
not
os
.
path
.
exists
(
output_dir_base
):
if
not
os
.
path
.
exists
(
output_dir_base
):
os
.
makedirs
(
output_dir_base
)
os
.
makedirs
(
output_dir_base
)
...
...
scripts/utils.py
View file @
9d4c9357
...
@@ -40,9 +40,6 @@ def add_data_args(parser: argparse.ArgumentParser):
...
@@ -40,9 +40,6 @@ def add_data_args(parser: argparse.ArgumentParser):
'--max_template_date'
,
type
=
str
,
'--max_template_date'
,
type
=
str
,
default
=
date
.
today
().
strftime
(
"%Y-%m-%d"
),
default
=
date
.
today
().
strftime
(
"%Y-%m-%d"
),
)
)
parser
.
add_argument
(
'--max_template_hits'
,
type
=
int
,
default
=
20
,
)
parser
.
add_argument
(
parser
.
add_argument
(
'--obsolete_pdbs_path'
,
type
=
str
,
default
=
None
'--obsolete_pdbs_path'
,
type
=
str
,
default
=
None
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment