Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
6ce8cfe3
"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "aa753ec0e941ddb117654810b7e6c16f2efec2f9"
Commit
6ce8cfe3
authored
Jan 18, 2022
by
Gustaf Ahdritz
Browse files
Fixes
parent
1df4991d
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
382 additions
and
236 deletions
+382
-236
openfold/data/data_modules.py
openfold/data/data_modules.py
+208
-74
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+76
-27
openfold/model/evoformer.py
openfold/model/evoformer.py
+2
-1
openfold/model/primitives.py
openfold/model/primitives.py
+2
-2
openfold/model/triangular_attention.py
openfold/model/triangular_attention.py
+1
-2
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+14
-8
openfold/utils/loss.py
openfold/utils/loss.py
+0
-1
scripts/generate_mmcif_cache.py
scripts/generate_mmcif_cache.py
+0
-66
scripts/precompute_alignments.py
scripts/precompute_alignments.py
+29
-28
scripts/utils.py
scripts/utils.py
+5
-5
tests/test_evoformer.py
tests/test_evoformer.py
+9
-6
tests/test_model.py
tests/test_model.py
+1
-0
tests/test_msa.py
tests/test_msa.py
+7
-10
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+3
-3
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+2
-2
train_openfold.py
train_openfold.py
+23
-1
No files found.
openfold/data/data_modules.py
View file @
6ce8cfe3
...
@@ -4,7 +4,7 @@ import json
...
@@ -4,7 +4,7 @@ import json
import
logging
import
logging
import
os
import
os
import
pickle
import
pickle
from
typing
import
Optional
,
Sequence
from
typing
import
Optional
,
Sequence
,
List
,
Any
import
ml_collections
as
mlc
import
ml_collections
as
mlc
import
numpy
as
np
import
numpy
as
np
...
@@ -29,14 +29,15 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -29,14 +29,15 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
max_template_date
:
str
,
max_template_date
:
str
,
config
:
mlc
.
ConfigDict
,
config
:
mlc
.
ConfigDict
,
kalign_binary_path
:
str
=
'/usr/bin/kalign'
,
kalign_binary_path
:
str
=
'/usr/bin/kalign'
,
mapping_path
:
Optional
[
str
]
=
None
,
max_template_hits
:
int
=
4
,
max_template_hits
:
int
=
4
,
obsolete_pdbs_file_path
:
Optional
[
str
]
=
None
,
obsolete_pdbs_file_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
shuffle_top_k_prefiltered
:
Optional
[
int
]
=
None
,
shuffle_top_k_prefiltered
:
Optional
[
int
]
=
None
,
treat_pdb_as_distillation
:
bool
=
True
,
treat_pdb_as_distillation
:
bool
=
True
,
mapping_path
:
Optional
[
str
]
=
None
,
mode
:
str
=
"train"
,
mode
:
str
=
"train"
,
_output_raw
:
bool
=
False
,
_output_raw
:
bool
=
False
,
_alignment_index
:
Optional
[
Any
]
=
None
):
):
"""
"""
Args:
Args:
...
@@ -56,12 +57,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -56,12 +57,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
A dataset config object. See openfold.config
A dataset config object. See openfold.config
kalign_binary_path:
kalign_binary_path:
Path to kalign binary.
Path to kalign binary.
mapping_path:
A json file containing a mapping from consecutive numerical
ids to sample names (matching the directories in data_dir).
Samples not in this mapping are ignored. Can be used to
implement the various training-time filters described in
the AlphaFold supplement.
max_template_hits:
max_template_hits:
An upper bound on how many templates are considered. During
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
training, the templates ultimately used are subsampled
...
@@ -89,26 +84,28 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -89,26 +84,28 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
self
.
mode
=
mode
self
.
mode
=
mode
self
.
_output_raw
=
_output_raw
self
.
_output_raw
=
_output_raw
self
.
_alignment_index
=
_alignment_index
valid_modes
=
[
"train"
,
"eval"
,
"predict"
]
valid_modes
=
[
"train"
,
"eval"
,
"predict"
]
if
(
mode
not
in
valid_modes
):
if
(
mode
not
in
valid_modes
):
raise
ValueError
(
f
'mode must be one of
{
valid_modes
}
'
)
raise
ValueError
(
f
'mode must be one of
{
valid_modes
}
'
)
if
(
mapping_path
is
None
):
self
.
mapping
=
{
str
(
i
):
os
.
path
.
splitext
(
name
)[
0
]
for
i
,
name
in
enumerate
(
os
.
listdir
(
alignment_dir
))
}
else
:
with
open
(
mapping_path
,
'r'
)
as
fp
:
self
.
mapping
=
json
.
load
(
fp
)
if
(
template_release_dates_cache_path
is
None
):
if
(
template_release_dates_cache_path
is
None
):
logging
.
warning
(
logging
.
warning
(
"Template release dates cache does not exist. Remember to run "
"Template release dates cache does not exist. Remember to run "
"scripts/generate_mmcif_cache.py before running OpenFold"
"scripts/generate_mmcif_cache.py before running OpenFold"
)
)
if
(
mapping_path
is
None
):
self
.
_chain_ids
=
list
(
os
.
listdir
(
alignment_dir
))
else
:
with
open
(
mapping_path
,
"r"
)
as
f
:
self
.
_chain_ids
=
[
l
.
strip
()
for
l
in
f
.
readlines
()]
self
.
_chain_id_to_idx_dict
=
{
chain
:
i
for
i
,
chain
in
enumerate
(
self
.
_chain_ids
)
}
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
template_featurizer
=
templates
.
TemplateHitFeaturizer
(
mmcif_dir
=
template_mmcif_dir
,
mmcif_dir
=
template_mmcif_dir
,
max_template_date
=
max_template_date
,
max_template_date
=
max_template_date
,
...
@@ -126,7 +123,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -126,7 +123,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if
(
not
self
.
_output_raw
):
if
(
not
self
.
_output_raw
):
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
def
_parse_mmcif
(
self
,
path
,
file_id
,
chain_id
,
alignment_dir
):
def
_parse_mmcif
(
self
,
path
,
file_id
,
chain_id
,
alignment_dir
,
_alignment_index
):
with
open
(
path
,
'r'
)
as
f
:
with
open
(
path
,
'r'
)
as
f
:
mmcif_string
=
f
.
read
()
mmcif_string
=
f
.
read
()
...
@@ -145,14 +142,25 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -145,14 +142,25 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif
=
mmcif_object
,
mmcif
=
mmcif_object
,
alignment_dir
=
alignment_dir
,
alignment_dir
=
alignment_dir
,
chain_id
=
chain_id
,
chain_id
=
chain_id
,
_alignment_index
=
_alignment_index
)
)
return
data
return
data
def
chain_id_to_idx
(
self
,
chain_id
):
return
self
.
_chain_id_to_idx_dict
[
chain_id
]
def
idx_to_chain_id
(
self
,
idx
):
return
self
.
_chain_ids
[
idx
]
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
name
=
self
.
mapping
[
str
(
idx
)
]
name
=
self
.
idx_to_chain_id
(
idx
)
alignment_dir
=
os
.
path
.
join
(
self
.
alignment_dir
,
name
)
alignment_dir
=
os
.
path
.
join
(
self
.
alignment_dir
,
name
)
_alignment_index
=
None
if
(
self
.
_alignment_index
is
not
None
):
_alignment_index
=
self
.
_alignment_index
[
name
]
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
spl
=
name
.
rsplit
(
'_'
,
1
)
spl
=
name
.
rsplit
(
'_'
,
1
)
if
(
len
(
spl
)
==
2
):
if
(
len
(
spl
)
==
2
):
...
@@ -164,11 +172,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -164,11 +172,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path
=
os
.
path
.
join
(
self
.
data_dir
,
file_id
)
path
=
os
.
path
.
join
(
self
.
data_dir
,
file_id
)
if
(
os
.
path
.
exists
(
path
+
".cif"
)):
if
(
os
.
path
.
exists
(
path
+
".cif"
)):
data
=
self
.
_parse_mmcif
(
data
=
self
.
_parse_mmcif
(
path
+
".cif"
,
file_id
,
chain_id
,
alignment_dir
path
+
".cif"
,
file_id
,
chain_id
,
alignment_dir
,
_alignment_index
,
)
)
elif
(
os
.
path
.
exists
(
path
+
".core"
)):
elif
(
os
.
path
.
exists
(
path
+
".core"
)):
data
=
self
.
data_pipeline
.
process_core
(
data
=
self
.
data_pipeline
.
process_core
(
path
+
".core"
,
alignment_dir
path
+
".core"
,
alignment_dir
,
_alignment_index
,
)
)
elif
(
os
.
path
.
exists
(
path
+
".pdb"
)):
elif
(
os
.
path
.
exists
(
path
+
".pdb"
)):
data
=
self
.
data_pipeline
.
process_pdb
(
data
=
self
.
data_pipeline
.
process_pdb
(
...
@@ -176,6 +184,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -176,6 +184,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir
=
alignment_dir
,
alignment_dir
=
alignment_dir
,
is_distillation
=
self
.
treat_pdb_as_distillation
,
is_distillation
=
self
.
treat_pdb_as_distillation
,
chain_id
=
chain_id
,
chain_id
=
chain_id
,
_alignment_index
=
_alignment_index
,
)
)
else
:
else
:
raise
ValueError
(
"Invalid file type"
)
raise
ValueError
(
"Invalid file type"
)
...
@@ -184,6 +193,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -184,6 +193,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data
=
self
.
data_pipeline
.
process_fasta
(
data
=
self
.
data_pipeline
.
process_fasta
(
fasta_path
=
path
,
fasta_path
=
path
,
alignment_dir
=
alignment_dir
,
alignment_dir
=
alignment_dir
,
_alignment_index
=
_alignment_index
,
)
)
if
(
self
.
_output_raw
):
if
(
self
.
_output_raw
):
...
@@ -196,56 +206,130 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -196,56 +206,130 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
return
feats
return
feats
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
mapping
.
keys
())
return
len
(
self
.
_chain_ids
)
def
train_filter
(
prot_data_cache_entry
:
Any
,
generator
:
torch
.
Generator
,
max_resolution
:
float
=
9.
,
max_single_aa_prop
:
float
=
0.8
,
)
->
bool
:
# Hard filters
resolution
=
prot_data_cache_entry
.
get
(
"resolution"
,
None
)
if
(
resolution
is
not
None
and
resolution
>
max_resolution
):
return
False
seq
=
prot_data_cache_entry
[
"seq"
]
counts
=
{}
for
aa
in
seq
:
counts
.
setdefault
(
aa
,
0
)
counts
[
aa
]
+=
1
largest_aa_count
=
max
(
counts
.
values
())
largest_single_aa_prop
=
largest_aa_count
/
len
(
seq
)
if
(
largest_single_aa_prop
>
max_single_aa_prop
):
return
False
# Stochastic filters
probabilities
=
[]
cluster_size
=
prot_data_cache_entry
.
get
(
"cluster_size"
,
None
)
if
(
cluster_size
is
not
None
and
cluster_size
>
0
):
probabilities
.
append
(
1
/
cluster_size
)
chain_length
=
len
(
prot_data_cache_entry
[
"seq"
])
probabilities
.
append
((
1
/
512
)
*
(
max
(
min
(
chain_length
,
512
),
256
)))
weights
=
[[
1
-
p
,
p
]
for
p
in
probabilities
]
results
=
torch
.
multinomial
(
torch
.
tensor
(
weights
),
num_samples
=
1
,
generator
=
generator
,
)
def
looped_sequence
(
sequence
):
return
torch
.
all
(
results
)
while
True
:
for
x
in
sequence
:
yield
x
class
OpenFoldDataset
(
torch
.
utils
.
data
.
Iterable
Dataset
):
class
OpenFoldDataset
(
torch
.
utils
.
data
.
Dataset
):
"""
"""
The Dataset is written to accommodate the requirement that proteins are
Implements the stochastic filters applied during AlphaFold's training.
sampled from the distillation set with some probability p
Because samples are selected from constituent datasets randomly, the
and from the PDB set with probability (1 - p). Proteins are sampled
length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
from both sets without replacement, and as soon as either set is
and filtered once at initialization.
emptied, it is refilled. The Dataset therefore has an arbitrary length.
Nevertheless, for compatibility with various PyTorch Lightning
functionalities, it is possible to specify an epoch length. This length
has no effect on the output of the Dataset.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
datasets
:
Sequence
[
OpenFoldSingleDataset
],
datasets
:
Sequence
[
OpenFoldSingleDataset
],
probabilities
:
Sequence
[
int
],
probabilities
:
Sequence
[
int
],
epoch_len
:
int
,
epoch_len
:
int
,
prot_data_cache_paths
:
List
[
str
],
filter_fn
:
Optional
[
Any
]
=
train_filter
,
generator
:
torch
.
Generator
=
None
,
_roll_at_init
:
bool
=
True
,
):
):
self
.
datasets
=
datasets
self
.
datasets
=
datasets
self
.
samplers
=
[
self
.
probabilities
=
probabilities
looped_sequence
(
RandomSampler
(
d
))
for
d
in
datasets
]
self
.
epoch_len
=
epoch_len
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
filter_fn
=
filter_fn
self
.
distr
=
torch
.
distributions
.
categorical
.
Categorical
(
def
looped_shuffled_dataset_idx
(
dataset_len
):
probs
=
torch
.
tensor
(
probabilities
),
while
True
:
# Uniformly shuffle each dataset's indices
weights
=
[
1.
for
_
in
range
(
dataset_len
)]
shuf
=
torch
.
multinomial
(
torch
.
tensor
(
weights
),
num_samples
=
dataset_len
,
replacement
=
False
,
generator
=
self
.
generator
,
)
)
for
idx
in
shuf
:
yield
idx
def
__iter__
(
self
):
self
.
shuffled_idx_iters
=
[]
return
self
for
d
in
datasets
:
self
.
shuffled_idx_iters
.
append
(
looped_shuffled_dataset_idx
(
len
(
d
))
)
self
.
prot_data_caches
=
[]
for
path
in
prot_data_cache_paths
:
with
open
(
path
,
"r"
)
as
fp
:
self
.
prot_data_caches
.
append
(
json
.
load
(
fp
))
def
__next__
(
self
):
if
(
_roll_at_init
):
dataset_idx
=
self
.
distr
.
sample
()
self
.
reroll
()
sampler
=
self
.
samplers
[
dataset_idx
]
element_idx
=
next
(
sampler
)
def
__getitem__
(
self
,
idx
):
return
self
.
datasets
[
dataset_idx
][
element_idx
]
dataset_idx
,
datapoint_idx
=
self
.
datapoints
[
idx
]
return
self
.
datasets
[
dataset_idx
][
datapoint_idx
]
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
epoch_len
return
self
.
epoch_len
def
reroll
(
self
):
dataset_choices
=
torch
.
multinomial
(
torch
.
tensor
(
self
.
probabilities
),
num_samples
=
self
.
epoch_len
,
replacement
=
True
,
generator
=
self
.
generator
,
)
self
.
datapoints
=
[]
for
dataset_idx
in
dataset_choices
:
dataset
=
self
.
datasets
[
dataset_idx
]
idx_iter
=
self
.
shuffled_idx_iters
[
dataset_idx
]
prot_data_cache
=
self
.
prot_data_caches
[
dataset_idx
]
datapoint_idx
=
None
while
datapoint_idx
is
None
:
candidate_idx
=
next
(
idx_iter
)
chain_id
=
dataset
.
idx_to_chain_id
(
candidate_idx
)
if
(
self
.
filter_fn
(
prot_data_cache
[
chain_id
],
self
.
generator
)):
datapoint_idx
=
candidate_idx
self
.
datapoints
.
append
((
dataset_idx
,
datapoint_idx
))
class
OpenFoldBatchCollator
:
class
OpenFoldBatchCollator
:
def
__init__
(
self
,
config
,
generator
,
stage
=
"train"
):
def
__init__
(
self
,
config
,
stage
=
"train"
):
self
.
stage
=
stage
self
.
stage
=
stage
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
...
@@ -283,18 +367,17 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
...
@@ -283,18 +367,17 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
keyed_probs
.
append
(
keyed_probs
.
append
(
(
"use_clamped_fape"
,
[
1
-
clamp_prob
,
clamp_prob
])
(
"use_clamped_fape"
,
[
1
-
clamp_prob
,
clamp_prob
])
)
)
if
(
self
.
config
.
supervised
.
uniform_recycling
):
if
(
stage_cfg
.
uniform_recycling
):
recycling_probs
=
[
recycling_probs
=
[
1.
/
(
max_iters
+
1
)
for
_
in
range
(
max_iters
+
1
)
1.
/
(
max_iters
+
1
)
for
_
in
range
(
max_iters
+
1
)
]
]
keyed_probs
.
append
(
(
"no_recycling_iters"
,
recycling_probs
)
)
else
:
else
:
recycling_probs
=
[
recycling_probs
=
[
0.
for
_
in
range
(
max_iters
+
1
)
0.
for
_
in
range
(
max_iters
+
1
)
]
]
recycling_probs
[
-
1
]
=
1.
recycling_probs
[
-
1
]
=
1.
keyed_probs
.
append
(
keyed_probs
.
append
(
(
"no_recycling_iters"
,
recycling_probs
)
(
"no_recycling_iters"
,
recycling_probs
)
)
)
...
@@ -362,8 +445,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -362,8 +445,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_date
:
str
,
max_template_date
:
str
,
train_data_dir
:
Optional
[
str
]
=
None
,
train_data_dir
:
Optional
[
str
]
=
None
,
train_alignment_dir
:
Optional
[
str
]
=
None
,
train_alignment_dir
:
Optional
[
str
]
=
None
,
train_filter_fn
:
Optional
[
Any
]
=
train_filter
,
train_prot_data_cache_path
:
Optional
[
str
]
=
None
,
distillation_data_dir
:
Optional
[
str
]
=
None
,
distillation_data_dir
:
Optional
[
str
]
=
None
,
distillation_alignment_dir
:
Optional
[
str
]
=
None
,
distillation_alignment_dir
:
Optional
[
str
]
=
None
,
distillation_prot_data_cache_path
:
Optional
[
str
]
=
None
,
val_data_dir
:
Optional
[
str
]
=
None
,
val_data_dir
:
Optional
[
str
]
=
None
,
val_alignment_dir
:
Optional
[
str
]
=
None
,
val_alignment_dir
:
Optional
[
str
]
=
None
,
predict_data_dir
:
Optional
[
str
]
=
None
,
predict_data_dir
:
Optional
[
str
]
=
None
,
...
@@ -374,6 +460,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -374,6 +460,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
obsolete_pdbs_file_path
:
Optional
[
str
]
=
None
,
obsolete_pdbs_file_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
batch_seed
:
Optional
[
int
]
=
None
,
batch_seed
:
Optional
[
int
]
=
None
,
train_epoch_len
:
int
=
50000
,
_alignment_index_path
:
Optional
[
str
]
=
None
,
**
kwargs
**
kwargs
):
):
super
(
OpenFoldDataModule
,
self
).
__init__
()
super
(
OpenFoldDataModule
,
self
).
__init__
()
...
@@ -383,8 +471,13 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -383,8 +471,13 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
max_template_date
=
max_template_date
self
.
max_template_date
=
max_template_date
self
.
train_data_dir
=
train_data_dir
self
.
train_data_dir
=
train_data_dir
self
.
train_alignment_dir
=
train_alignment_dir
self
.
train_alignment_dir
=
train_alignment_dir
self
.
train_filter_fn
=
train_filter_fn
self
.
train_prot_data_cache_path
=
train_prot_data_cache_path
self
.
distillation_data_dir
=
distillation_data_dir
self
.
distillation_data_dir
=
distillation_data_dir
self
.
distillation_alignment_dir
=
distillation_alignment_dir
self
.
distillation_alignment_dir
=
distillation_alignment_dir
self
.
distillation_prot_data_cache_path
=
(
distillation_prot_data_cache_path
)
self
.
val_data_dir
=
val_data_dir
self
.
val_data_dir
=
val_data_dir
self
.
val_alignment_dir
=
val_alignment_dir
self
.
val_alignment_dir
=
val_alignment_dir
self
.
predict_data_dir
=
predict_data_dir
self
.
predict_data_dir
=
predict_data_dir
...
@@ -397,6 +490,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -397,6 +490,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
)
self
.
obsolete_pdbs_file_path
=
obsolete_pdbs_file_path
self
.
obsolete_pdbs_file_path
=
obsolete_pdbs_file_path
self
.
batch_seed
=
batch_seed
self
.
batch_seed
=
batch_seed
self
.
train_epoch_len
=
train_epoch_len
if
(
self
.
train_data_dir
is
None
and
self
.
predict_data_dir
is
None
):
if
(
self
.
train_data_dir
is
None
and
self
.
predict_data_dir
is
None
):
raise
ValueError
(
raise
ValueError
(
...
@@ -406,11 +500,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -406,11 +500,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
training_mode
=
self
.
train_data_dir
is
not
None
self
.
training_mode
=
self
.
train_data_dir
is
not
None
if
(
self
.
training_mode
and
self
.
train_alignment_dir
is
None
):
if
(
self
.
training_mode
and
train_alignment_dir
is
None
):
raise
ValueError
(
raise
ValueError
(
'In training mode, train_alignment_dir must be specified'
'In training mode, train_alignment_dir must be specified'
)
)
elif
(
not
self
.
training_mode
and
self
.
predict_ali
n
gment_dir
is
None
):
elif
(
not
self
.
training_mode
and
predict_alig
n
ment_dir
is
None
):
raise
ValueError
(
raise
ValueError
(
'In inference mode, predict_alignment_dir must be specified'
'In inference mode, predict_alignment_dir must be specified'
)
)
...
@@ -420,10 +514,28 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -420,10 +514,28 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
'be specified as well'
)
)
def
setup
(
self
,
stage
:
Optional
[
str
]
=
None
):
cache_missing
=
(
if
(
stage
is
None
):
train_filter_fn
and
stage
=
"train"
(
train_prot_data_cache_path
is
None
or
(
distillation_data_dir
is
not
None
and
distillation_prot_data_cache_path
is
None
)
)
)
if
(
cache_missing
):
raise
ValueError
(
"If train_filter_fn is given, so must the protein data caches"
)
# An ad-hoc measure for our particular filesystem restrictions
self
.
_alignment_index
=
None
if
(
_alignment_index_path
is
not
None
):
with
open
(
_alignment_index_path
,
"r"
)
as
fp
:
self
.
_alignment_index
=
json
.
load
(
fp
)
def
setup
(
self
):
# Most of the arguments are the same for the three datasets
# Most of the arguments are the same for the three datasets
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
template_mmcif_dir
=
self
.
template_mmcif_dir
,
template_mmcif_dir
=
self
.
template_mmcif_dir
,
...
@@ -434,10 +546,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -434,10 +546,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
template_release_dates_cache_path
,
self
.
template_release_dates_cache_path
,
obsolete_pdbs_file_path
=
obsolete_pdbs_file_path
=
self
.
obsolete_pdbs_file_path
,
self
.
obsolete_pdbs_file_path
,
_alignment_index
=
self
.
_alignment_index
,
)
)
if
(
self
.
training_mode
):
if
(
self
.
training_mode
):
self
.
train_dataset
=
dataset_gen
(
train_dataset
=
dataset_gen
(
data_dir
=
self
.
train_data_dir
,
data_dir
=
self
.
train_data_dir
,
alignment_dir
=
self
.
train_alignment_dir
,
alignment_dir
=
self
.
train_alignment_dir
,
mapping_path
=
self
.
train_mapping_path
,
mapping_path
=
self
.
train_mapping_path
,
...
@@ -449,6 +562,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -449,6 +562,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
_output_raw
=
True
,
_output_raw
=
True
,
)
)
distillation_dataset
=
None
if
(
self
.
distillation_data_dir
is
not
None
):
if
(
self
.
distillation_data_dir
is
not
None
):
distillation_dataset
=
dataset_gen
(
distillation_dataset
=
dataset_gen
(
data_dir
=
self
.
distillation_data_dir
,
data_dir
=
self
.
distillation_data_dir
,
...
@@ -461,12 +575,29 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -461,12 +575,29 @@ class OpenFoldDataModule(pl.LightningDataModule):
)
)
d_prob
=
self
.
config
.
train
.
distillation_prob
d_prob
=
self
.
config
.
train
.
distillation_prob
if
(
distillation_dataset
is
not
None
):
datasets
=
[
train_dataset
,
distillation_dataset
]
d_prob
=
self
.
config
.
train
.
distillation_prob
probabilities
=
[
1
-
d_prob
,
d_prob
]
prot_data_cache_paths
=
[
self
.
train_prot_data_cache_path
,
self
.
distillation_prot_data_cache_path
,
]
else
:
datasets
=
[
train_dataset
]
probabilities
=
[
1.
]
prot_data_cache_paths
=
[
self
.
train_prot_data_cache_path
,
]
self
.
train_dataset
=
OpenFoldDataset
(
self
.
train_dataset
=
OpenFoldDataset
(
datasets
=
[
self
.
train_dataset
,
distillation_dataset
],
datasets
=
datasets
,
probabilities
=
[
1
-
d_prob
,
d_prob
],
probabilities
=
probabilities
,
epoch_len
=
(
epoch_len
=
self
.
train_epoch_len
,
self
.
train_dataset
.
len
()
+
distillation_dataset
.
len
()
prot_data_cache_paths
=
prot_data_cache_paths
,
),
filter_fn
=
self
.
train_filter_fn
,
_roll_at_init
=
False
,
)
)
if
(
self
.
val_data_dir
is
not
None
):
if
(
self
.
val_data_dir
is
not
None
):
...
@@ -497,6 +628,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -497,6 +628,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
dataset
=
None
dataset
=
None
if
(
stage
==
"train"
):
if
(
stage
==
"train"
):
dataset
=
self
.
train_dataset
dataset
=
self
.
train_dataset
# Filter the dataset, if necessary
dataset
.
reroll
()
elif
(
stage
==
"eval"
):
elif
(
stage
==
"eval"
):
dataset
=
self
.
eval_dataset
dataset
=
self
.
eval_dataset
elif
(
stage
==
"predict"
):
elif
(
stage
==
"predict"
):
...
...
openfold/data/data_pipeline.py
View file @
6ce8cfe3
...
@@ -422,8 +422,37 @@ class DataPipeline:
...
@@ -422,8 +422,37 @@ class DataPipeline:
def
_parse_msa_data
(
def
_parse_msa_data
(
self
,
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
Any
]
=
None
,
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
msa_data
=
{}
msa_data
=
{}
if
(
_alignment_index
is
not
None
):
fp
=
open
(
_alignment_index
[
"db"
],
"rb"
)
def
read_msa
(
start
,
size
):
fp
.
seek
(
start
)
msa
=
fp
.
read
(
size
).
encode
(
"utf-8"
)
return
msa
for
(
name
,
start
,
size
)
in
_alignment_index
[
"files"
]:
ext
=
os
.
path
.
splitext
(
name
)[
-
1
]
if
(
ext
==
".a3m"
):
msa
,
deletion_matrix
=
parsers
.
parse_a3m
(
read_msa
(
start
,
size
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
elif
(
ext
==
".sto"
):
msa
,
deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
read_msa
(
start
,
size
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
else
:
continue
msa_data
[
f
]
=
data
fp
.
close
()
else
:
for
f
in
os
.
listdir
(
alignment_dir
):
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
...
@@ -448,8 +477,23 @@ class DataPipeline:
...
@@ -448,8 +477,23 @@ class DataPipeline:
def
_parse_template_hits
(
def
_parse_template_hits
(
self
,
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
Any
]
=
None
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
all_hits
=
{}
all_hits
=
{}
if
(
_alignment_index
is
not
None
):
fp
=
open
(
_alignment_index
[
"db"
],
'rb'
)
def
read_template
(
start
,
size
):
fp
.
seek
(
start
)
return
fp
.
read
(
size
).
encode
(
"utf-8"
)
for
(
name
,
start
,
size
)
in
_alignment_index
[
"files"
]:
ext
=
os
.
path
.
splitext
(
name
)[
-
1
]
if
(
ext
==
".hhr"
):
hits
=
parsers
.
parse_hhr
(
read_template
(
start
,
size
))
all_hits
[
f
]
=
hits
else
:
for
f
in
os
.
listdir
(
alignment_dir
):
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
...
@@ -465,8 +509,9 @@ class DataPipeline:
...
@@ -465,8 +509,9 @@ class DataPipeline:
self
,
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
input_sequence
:
Optional
[
str
]
=
None
,
input_sequence
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
)
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
,
_alignment_index
)
if
(
len
(
msa_data
)
==
0
):
if
(
len
(
msa_data
)
==
0
):
if
(
input_sequence
is
None
):
if
(
input_sequence
is
None
):
...
@@ -496,6 +541,7 @@ class DataPipeline:
...
@@ -496,6 +541,7 @@ class DataPipeline:
self
,
self
,
fasta_path
:
str
,
fasta_path
:
str
,
alignment_dir
:
str
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
)
->
FeatureDict
:
"""Assembles features for a single sequence in a FASTA file"""
"""Assembles features for a single sequence in a FASTA file"""
with
open
(
fasta_path
)
as
f
:
with
open
(
fasta_path
)
as
f
:
...
@@ -509,7 +555,7 @@ class DataPipeline:
...
@@ -509,7 +555,7 @@ class DataPipeline:
input_description
=
input_descs
[
0
]
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
num_res
=
len
(
input_sequence
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
input_sequence
,
input_sequence
,
hits
,
hits
,
...
@@ -535,6 +581,7 @@ class DataPipeline:
...
@@ -535,6 +581,7 @@ class DataPipeline:
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
alignment_dir
:
str
,
alignment_dir
:
str
,
chain_id
:
Optional
[
str
]
=
None
,
chain_id
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
)
->
FeatureDict
:
"""
"""
Assembles features for a specific chain in an mmCIF object.
Assembles features for a specific chain in an mmCIF object.
...
@@ -552,7 +599,7 @@ class DataPipeline:
...
@@ -552,7 +599,7 @@ class DataPipeline:
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
input_sequence
,
input_sequence
,
hits
,
hits
,
...
@@ -570,6 +617,7 @@ class DataPipeline:
...
@@ -570,6 +617,7 @@ class DataPipeline:
alignment_dir
:
str
,
alignment_dir
:
str
,
is_distillation
:
bool
=
True
,
is_distillation
:
bool
=
True
,
chain_id
:
Optional
[
str
]
=
None
,
chain_id
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
)
->
FeatureDict
:
"""
"""
Assembles features for a protein in a PDB file.
Assembles features for a protein in a PDB file.
...
@@ -586,7 +634,7 @@ class DataPipeline:
...
@@ -586,7 +634,7 @@ class DataPipeline:
is_distillation
is_distillation
)
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
input_sequence
,
input_sequence
,
hits
,
hits
,
...
@@ -601,6 +649,7 @@ class DataPipeline:
...
@@ -601,6 +649,7 @@ class DataPipeline:
self
,
self
,
core_path
:
str
,
core_path
:
str
,
alignment_dir
:
str
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
)
->
FeatureDict
:
"""
"""
Assembles features for a protein in a ProteinNet .core file.
Assembles features for a protein in a ProteinNet .core file.
...
@@ -613,7 +662,7 @@ class DataPipeline:
...
@@ -613,7 +662,7 @@ class DataPipeline:
description
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
core_path
))[
0
].
upper
()
description
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
core_path
))[
0
].
upper
()
core_feats
=
make_protein_features
(
protein_object
,
description
)
core_feats
=
make_protein_features
(
protein_object
,
description
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
input_sequence
,
input_sequence
,
hits
,
hits
,
...
...
openfold/model/evoformer.py
View file @
6ce8cfe3
...
@@ -360,7 +360,8 @@ class ExtraMSABlock(nn.Module):
...
@@ -360,7 +360,8 @@ class ExtraMSABlock(nn.Module):
mask
=
msa_mask
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
_chunk_logits
=
_chunk_logits
,
_chunk_logits
=
_chunk_logits
,
_checkpoint_chunks
=
self
.
ckpt
,
_checkpoint_chunks
=
self
.
ckpt
if
torch
.
is_grad_enabled
()
else
False
,
)
)
)
)
...
...
openfold/model/primitives.py
View file @
6ce8cfe3
...
@@ -188,7 +188,7 @@ class LayerNorm(nn.Module):
...
@@ -188,7 +188,7 @@ class LayerNorm(nn.Module):
self
.
bias
.
to
(
dtype
=
d
),
self
.
bias
.
to
(
dtype
=
d
),
self
.
eps
self
.
eps
)
)
el
if
(
d
==
torch
.
bfloat16
)
:
el
se
:
out
=
nn
.
functional
.
layer_norm
(
out
=
nn
.
functional
.
layer_norm
(
x
,
x
,
self
.
c_in
,
self
.
c_in
,
...
@@ -209,7 +209,7 @@ def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
...
@@ -209,7 +209,7 @@ def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
if
(
d
is
torch
.
bfloat16
and
not
deepspeed
.
utils
.
is_initialized
()):
if
(
d
is
torch
.
bfloat16
and
not
deepspeed
.
utils
.
is_initialized
()):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
s
=
torch
.
nn
.
functional
.
softmax
(
t
,
dim
=
dim
)
s
=
torch
.
nn
.
functional
.
softmax
(
t
,
dim
=
dim
)
el
if
(
d
==
torch
.
bfloat16
)
:
el
se
:
s
=
torch
.
nn
.
functional
.
softmax
(
t
,
dim
=
dim
)
s
=
torch
.
nn
.
functional
.
softmax
(
t
,
dim
=
dim
)
return
s
return
s
...
...
openfold/model/triangular_attention.py
View file @
6ce8cfe3
...
@@ -65,8 +65,7 @@ class TriangleAttention(nn.Module):
...
@@ -65,8 +65,7 @@ class TriangleAttention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
mha_inputs
=
{
mha_inputs
=
{
"q_x"
:
x
,
"q_x"
:
x
,
"k_x"
:
x
,
"kv_x"
:
x
,
"v_x"
:
x
,
"biases"
:
biases
,
"biases"
:
biases
,
}
}
return
chunk_layer
(
return
chunk_layer
(
...
...
openfold/utils/import_weights.py
View file @
6ce8cfe3
...
@@ -282,13 +282,19 @@ def import_jax_weights_(model, npz_path, version="model_1"):
...
@@ -282,13 +282,19 @@ def import_jax_weights_(model, npz_path, version="model_1"):
b
.
msa_att_row
b
.
msa_att_row
),
),
col_att_name
:
msa_col_att_params
,
col_att_name
:
msa_col_att_params
,
"msa_transition"
:
MSATransitionParams
(
b
.
msa_transition
),
"msa_transition"
:
MSATransitionParams
(
b
.
core
.
msa_transition
),
"outer_product_mean"
:
OuterProductMeanParams
(
b
.
outer_product_mean
),
"outer_product_mean"
:
"triangle_multiplication_outgoing"
:
TriMulOutParams
(
b
.
tri_mul_out
),
OuterProductMeanParams
(
b
.
core
.
outer_product_mean
),
"triangle_multiplication_incoming"
:
TriMulInParams
(
b
.
tri_mul_in
),
"triangle_multiplication_outgoing"
:
"triangle_attention_starting_node"
:
TriAttParams
(
b
.
tri_att_start
),
TriMulOutParams
(
b
.
core
.
tri_mul_out
),
"triangle_attention_ending_node"
:
TriAttParams
(
b
.
tri_att_end
),
"triangle_multiplication_incoming"
:
"pair_transition"
:
PairTransitionParams
(
b
.
pair_transition
),
TriMulInParams
(
b
.
core
.
tri_mul_in
),
"triangle_attention_starting_node"
:
TriAttParams
(
b
.
core
.
tri_att_start
),
"triangle_attention_ending_node"
:
TriAttParams
(
b
.
core
.
tri_att_end
),
"pair_transition"
:
PairTransitionParams
(
b
.
core
.
pair_transition
),
}
}
return
d
return
d
...
@@ -323,7 +329,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
...
@@ -323,7 +329,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
[
TemplatePairBlockParams
(
b
)
for
b
in
tps_blocks
]
[
TemplatePairBlockParams
(
b
)
for
b
in
tps_blocks
]
)
)
ems_blocks
=
model
.
extra_msa_stack
.
stack
.
blocks
ems_blocks
=
model
.
extra_msa_stack
.
blocks
ems_blocks_params
=
stacked
([
ExtraMSABlockParams
(
b
)
for
b
in
ems_blocks
])
ems_blocks_params
=
stacked
([
ExtraMSABlockParams
(
b
)
for
b
in
ems_blocks
])
evo_blocks
=
model
.
evoformer
.
blocks
evo_blocks
=
model
.
evoformer
.
blocks
...
...
openfold/utils/loss.py
View file @
6ce8cfe3
...
@@ -1520,7 +1520,6 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1520,7 +1520,6 @@ class AlphaFoldLoss(nn.Module):
weight
=
self
.
config
[
loss_name
].
weight
weight
=
self
.
config
[
loss_name
].
weight
if
weight
:
if
weight
:
loss
=
loss_fn
()
loss
=
loss_fn
()
if
(
torch
.
isnan
(
loss
)
or
torch
.
isinf
(
loss
)):
if
(
torch
.
isnan
(
loss
)
or
torch
.
isinf
(
loss
)):
logging
.
warning
(
f
"
{
loss_name
}
loss is NaN. Skipping..."
)
logging
.
warning
(
f
"
{
loss_name
}
loss is NaN. Skipping..."
)
loss
=
loss
.
new_tensor
(
0.
,
requires_grad
=
True
)
loss
=
loss
.
new_tensor
(
0.
,
requires_grad
=
True
)
...
...
scripts/generate_mmcif_cache.py
deleted
100644 → 0
View file @
1df4991d
import
argparse
from
functools
import
partial
import
logging
from
multiprocessing
import
Pool
import
os
import
sys
import
json
sys
.
path
.
append
(
"."
)
# an innocent hack to get this to run from the top level
from
tqdm
import
tqdm
from
openfold.data.mmcif_parsing
import
parse
def parse_file(f, args):
    """Parse a single mmCIF file and summarize it for the cache.

    Args:
        f: File name of the mmCIF file, relative to ``args.mmcif_dir``.
        args: Namespace providing ``mmcif_dir``.

    Returns:
        A one-entry dict mapping the file name without its extension to a
        dict with the keys "release_date" and "no_chains", or an empty dict
        when ``parse`` produced no mmCIF object for the file.
    """
    file_id = os.path.splitext(f)[0]
    mmcif_path = os.path.join(args.mmcif_dir, f)
    with open(mmcif_path, "r") as fp:
        mmcif_string = fp.read()

    parsing_result = parse(file_id=file_id, mmcif_string=mmcif_string)
    if parsing_result.mmcif_object is None:
        # Unparseable files contribute nothing to the cache.
        logging.info(f"Could not parse {f}. Skipping...")
        return {}

    mmcif_object = parsing_result.mmcif_object
    return {
        file_id: {
            "release_date": mmcif_object.header["release_date"],
            "no_chains": len(list(mmcif_object.structure.get_chains())),
        }
    }
def main(args):
    """Build the mmCIF cache for a directory and write it out as JSON.

    Every file in ``args.mmcif_dir`` whose name ends in ".cif" is handed to
    ``parse_file`` on a pool of ``args.no_workers`` processes, in batches of
    ``args.chunksize``. The per-file results are merged into one dict, which
    is written to ``args.output_path`` as JSON indented by 4 spaces.

    Args:
        args: Namespace providing ``mmcif_dir``, ``output_path``,
            ``no_workers`` and ``chunksize``.
    """
    # Match on the extension, not on a substring: names such as
    # "1abc.cif.gz" or "1abc.cif.bak" contain ".cif" but are not plain-text
    # mmCIF files, and would also yield a wrong file id ("1abc.cif").
    files = [f for f in os.listdir(args.mmcif_dir) if f.endswith(".cif")]
    fn = partial(parse_file, args=args)
    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
            # Results arrive in completion order; each is a dict with at
            # most one entry, so merging order does not matter.
            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
                data.update(d)
                pbar.update()

    with open(args.output_path, "w") as fp:
        fp.write(json.dumps(data, indent=4))
if __name__ == "__main__":
    # Command-line entry point: two positional paths plus two optional
    # settings for the worker pool, then hand off to main().
    parser = argparse.ArgumentParser()

    # Positional arguments.
    parser.add_argument(
        "mmcif_dir",
        type=str,
        help="Directory containing mmCIF files",
    )
    parser.add_argument(
        "output_path",
        type=str,
        help="Path for .json output",
    )

    # Optional worker-pool settings.
    parser.add_argument(
        "--no_workers",
        type=int,
        default=4,
        help="Number of workers to use for parsing",
    )
    parser.add_argument(
        "--chunksize",
        type=int,
        default=10,
        help="How many files should be distributed to each worker at a time",
    )

    args = parser.parse_args()

    main(args)
scripts/precompute_alignments.py
View file @
6ce8cfe3
...
@@ -25,11 +25,12 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
...
@@ -25,11 +25,12 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
first_name
=
names
[
0
]
first_name
=
names
[
0
]
alignment_dir
=
os
.
path
.
join
(
args
.
output_dir
,
first_name
)
alignment_dir
=
os
.
path
.
join
(
args
.
output_dir
,
first_name
)
try
:
os
.
makedirs
(
alignment_dir
,
exist_ok
=
True
)
os
.
makedirs
(
alignment_dir
)
# try:
except
Exception
as
e
:
# os.makedirs(alignment_dir)
logging
.
warning
(
f
"Failed to create directory for
{
first_name
}
with exception
{
e
}
..."
)
# except Exception as e:
continue
# logging.warning(f"Failed to create directory for {first_name} with exception {e}...")
# continue
fd
,
fasta_path
=
tempfile
.
mkstemp
(
suffix
=
".fasta"
)
fd
,
fasta_path
=
tempfile
.
mkstemp
(
suffix
=
".fasta"
)
with
os
.
fdopen
(
fd
,
'w'
)
as
fp
:
with
os
.
fdopen
(
fd
,
'w'
)
as
fp
:
...
@@ -48,14 +49,14 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
...
@@ -48,14 +49,14 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
os
.
remove
(
fasta_path
)
os
.
remove
(
fasta_path
)
for
name
in
names
[
1
:]:
for
name
in
names
[
1
:]:
if
(
name
in
dirs
):
#
if(name in dirs):
logging
.
warning
(
#
logging.warning(
f
'
{
name
}
has already been processed. Skipping...'
#
f'{name} has already been processed. Skipping...'
)
#
)
continue
#
continue
cp_dir
=
os
.
path
.
join
(
args
.
output_dir
,
name
)
cp_dir
=
os
.
path
.
join
(
args
.
output_dir
,
name
)
os
.
makedirs
(
cp_dir
)
os
.
makedirs
(
cp_dir
,
exist_ok
=
True
)
for
f
in
os
.
listdir
(
alignment_dir
):
for
f
in
os
.
listdir
(
alignment_dir
):
copyfile
(
os
.
path
.
join
(
alignment_dir
,
f
),
os
.
path
.
join
(
cp_dir
,
f
))
copyfile
(
os
.
path
.
join
(
alignment_dir
,
f
),
os
.
path
.
join
(
cp_dir
,
f
))
...
@@ -136,23 +137,23 @@ def main(args):
...
@@ -136,23 +137,23 @@ def main(args):
else
:
else
:
cache
=
None
cache
=
None
if
(
cache
is
not
None
and
args
.
filter
):
dirs
=
[]
dirs
=
set
(
os
.
listdir
(
args
.
output_di
r
)
)
#if(cache is not None and args.filte
r)
:
#
dirs = set(os.listdir(args.output_dir))
def
prot_is_done
(
f
):
#
def prot_is_done(f):
prot_id
=
os
.
path
.
splitext
(
f
)[
0
]
#
prot_id = os.path.splitext(f)[0]
if
(
prot_id
in
cache
):
#
if(prot_id in cache):
chain_ids
=
cache
[
prot_id
][
"chain_ids"
]
#
chain_ids = cache[prot_id]["chain_ids"]
for
c
in
chain_ids
:
#
for c in chain_ids:
full_name
=
prot_id
+
"_"
+
c
#
full_name = prot_id + "_" + c
if
(
not
full_name
in
dirs
):
#
if(not full_name in dirs):
return
False
#
return False
else
:
#
else:
return
False
#
return False
return
True
#
return True
files
=
[
f
for
f
in
files
if
not
prot_is_done
(
f
)]
#
files = [f for f in files if not prot_is_done(f)]
def
split_up_arglist
(
arglist
):
def
split_up_arglist
(
arglist
):
# Split up the survivors
# Split up the survivors
...
...
scripts/utils.py
View file @
6ce8cfe3
...
@@ -4,19 +4,19 @@ from datetime import date
...
@@ -4,19 +4,19 @@ from datetime import date
def
add_data_args
(
parser
:
argparse
.
ArgumentParser
):
def
add_data_args
(
parser
:
argparse
.
ArgumentParser
):
parser
.
add_argument
(
parser
.
add_argument
(
'uniref90_database_path'
,
type
=
str
,
'
--
uniref90_database_path'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'mgnify_database_path'
,
type
=
str
,
'
--
mgnify_database_path'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'pdb70_database_path'
,
type
=
str
,
'
--
pdb70_database_path'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'template_mmcif_dir'
,
type
=
str
,
'
--
template_mmcif_dir'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'uniclust30_database_path'
,
type
=
str
,
'
--
uniclust30_database_path'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'--bfd_database_path'
,
type
=
str
,
default
=
None
,
'--bfd_database_path'
,
type
=
str
,
default
=
None
,
...
...
tests/test_evoformer.py
View file @
6ce8cfe3
...
@@ -135,8 +135,8 @@ class TestEvoformerStack(unittest.TestCase):
...
@@ -135,8 +135,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa
=
out_repro_msa
.
cpu
()
out_repro_msa
=
out_repro_msa
.
cpu
()
out_repro_pair
=
out_repro_pair
.
cpu
()
out_repro_pair
=
out_repro_pair
.
cpu
()
assert
torch
.
max
(
torch
.
abs
(
out_repro_msa
-
out_gt_msa
)
<
consts
.
eps
)
assert
(
torch
.
max
(
torch
.
abs
(
out_repro_msa
-
out_gt_msa
)
)
<
consts
.
eps
)
assert
torch
.
max
(
torch
.
abs
(
out_repro_pair
-
out_gt_pair
)
<
consts
.
eps
)
assert
(
torch
.
max
(
torch
.
abs
(
out_repro_pair
-
out_gt_pair
)
)
<
consts
.
eps
)
class
TestExtraMSAStack
(
unittest
.
TestCase
):
class
TestExtraMSAStack
(
unittest
.
TestCase
):
...
@@ -172,7 +172,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -172,7 +172,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
,
transition_n
,
msa_dropout
,
msa_dropout
,
pair_stack_dropout
,
pair_stack_dropout
,
blocks_per_
ckpt
=
Non
e
,
ckpt
=
Fals
e
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
).
eval
()
).
eval
()
...
@@ -257,16 +257,19 @@ class TestMSATransition(unittest.TestCase):
...
@@ -257,16 +257,19 @@ class TestMSATransition(unittest.TestCase):
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
model
.
evoformer
.
blocks
[
0
].
core
.
msa_transition
(
.
msa_transition
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
)
.
cpu
()
.
cpu
()
)
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
print
(
out_gt
)
print
(
out_repro
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/test_model.py
View file @
6ce8cfe3
...
@@ -130,5 +130,6 @@ class TestModel(unittest.TestCase):
...
@@ -130,5 +130,6 @@ class TestModel(unittest.TestCase):
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
]
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
]
out_repro
=
out_repro
.
squeeze
(
0
)
out_repro
=
out_repro
.
squeeze
(
0
)
print
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
)))
print
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)))
print
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)))
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
1e-3
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
1e-3
)
tests/test_msa.py
View file @
6ce8cfe3
...
@@ -88,15 +88,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
...
@@ -88,15 +88,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
model
.
evoformer
.
blocks
[
0
].
msa_att_row
(
.
msa_att_row
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
torch
.
as_tensor
(
msa_act
).
cuda
(),
z
=
torch
.
as_tensor
(
pair_act
).
cuda
(),
z
=
torch
.
as_tensor
(
pair_act
).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
)
.
cpu
()
).
cpu
()
)
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
...
@@ -153,14 +151,14 @@ class TestMSAColumnAttention(unittest.TestCase):
...
@@ -153,14 +151,14 @@ class TestMSAColumnAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
model
.
evoformer
.
blocks
[
0
].
msa_att_col
(
.
msa_att_col
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
torch
.
as_tensor
(
msa_act
).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
)
.
cpu
()
).
cpu
()
)
print
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
)))
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
...
@@ -218,8 +216,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
...
@@ -218,8 +216,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
extra_msa_stack
.
stack
.
blocks
[
0
]
model
.
extra_msa_stack
.
blocks
[
0
].
msa_att_col
(
.
msa_att_col
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
tests/test_triangular_attention.py
View file @
6ce8cfe3
...
@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
...
@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
module
=
(
model
.
evoformer
.
blocks
[
0
].
tri_att_start
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_start
if
starting
if
starting
else
model
.
evoformer
.
blocks
[
0
].
tri_att_end
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_end
)
)
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
@@ -95,7 +95,7 @@ class TestTriangularAttention(unittest.TestCase):
...
@@ -95,7 +95,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size
=
None
,
chunk_size
=
None
,
).
cpu
()
).
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
)
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_att_end_compare
(
self
):
def
test_tri_att_end_compare
(
self
):
...
...
tests/test_triangular_multiplicative_update.py
View file @
6ce8cfe3
...
@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
module
=
(
model
.
evoformer
.
blocks
[
0
].
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
if
incoming
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
)
)
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
train_openfold.py
View file @
6ce8cfe3
...
@@ -67,6 +67,8 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -67,6 +67,8 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
# Compute loss
loss
=
self
.
loss
(
outputs
,
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
self
.
log
(
"loss"
,
loss
)
return
{
"loss"
:
loss
}
return
{
"loss"
:
loss
}
def
validation_step
(
self
,
batch
,
batch_idx
):
def
validation_step
(
self
,
batch
,
batch_idx
):
...
@@ -79,6 +81,7 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -79,6 +81,7 @@ class OpenFoldWrapper(pl.LightningModule):
outputs
=
self
(
batch
)
outputs
=
self
(
batch
)
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
self
.
log
(
"val_loss"
,
loss
)
return
{
"val_loss"
:
loss
}
return
{
"val_loss"
:
loss
}
def
validation_epoch_end
(
self
,
_
):
def
validation_epoch_end
(
self
,
_
):
...
@@ -316,6 +319,15 @@ if __name__ == "__main__":
...
@@ -316,6 +319,15 @@ if __name__ == "__main__":
"--script_modules"
,
type
=
bool_type
,
default
=
False
,
"--script_modules"
,
type
=
bool_type
,
default
=
False
,
help
=
"Whether to TorchScript eligible components of them model"
help
=
"Whether to TorchScript eligible components of them model"
)
)
parser
.
add_argument
(
"--train_prot_data_cache_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--distillation_prot_data_cache_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--train_epoch_len"
,
type
=
int
,
default
=
10000
,
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
# Disable the initial validation pass
# Disable the initial validation pass
...
@@ -324,7 +336,14 @@ if __name__ == "__main__":
...
@@ -324,7 +336,14 @@ if __name__ == "__main__":
)
)
# Remove some buggy/redundant arguments introduced by the Trainer
# Remove some buggy/redundant arguments introduced by the Trainer
remove_arguments
(
parser
,
[
"--accelerator"
,
"--resume_from_checkpoint"
])
remove_arguments
(
parser
,
[
"--accelerator"
,
"--resume_from_checkpoint"
,
"--reload_dataloaders_every_epoch"
]
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -333,4 +352,7 @@ if __name__ == "__main__":
...
@@ -333,4 +352,7 @@ if __name__ == "__main__":
(
args
.
num_nodes
is
not
None
and
args
.
num_nodes
>
1
))):
(
args
.
num_nodes
is
not
None
and
args
.
num_nodes
>
1
))):
raise
ValueError
(
"For distributed training, --seed must be specified"
)
raise
ValueError
(
"For distributed training, --seed must be specified"
)
# This re-applies the training-time filters at the beginning of every epoch
args
.
reload_dataloaders_every_epoch
=
True
main
(
args
)
main
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment