Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
a3e8ebbc
Commit
a3e8ebbc
authored
Feb 10, 2022
by
Gustaf Ahdritz
Browse files
Add missing function
parent
b9faee76
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
139 additions
and
60 deletions
+139
-60
openfold/data/data_modules.py
openfold/data/data_modules.py
+49
-29
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+83
-31
openfold/np/residue_constants.py
openfold/np/residue_constants.py
+7
-0
No files found.
openfold/data/data_modules.py
View file @
a3e8ebbc
...
...
@@ -37,6 +37,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mapping_path
:
Optional
[
str
]
=
None
,
mode
:
str
=
"train"
,
_output_raw
:
bool
=
False
,
_alignment_index
:
Optional
[
Any
]
=
None
):
"""
Args:
...
...
@@ -83,6 +84,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
self
.
mode
=
mode
self
.
_output_raw
=
_output_raw
self
.
_alignment_index
=
_alignment_index
valid_modes
=
[
"train"
,
"eval"
,
"predict"
]
if
(
mode
not
in
valid_modes
):
...
...
@@ -94,7 +96,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if
(
mapping_path
is
None
):
if
(
_alignment_index
is
not
None
):
self
.
_chain_ids
=
list
(
_alignment_index
.
keys
())
elif
(
mapping_path
is
None
):
self
.
_chain_ids
=
list
(
os
.
listdir
(
alignment_dir
))
else
:
with
open
(
mapping_path
,
"r"
)
as
f
:
...
...
@@ -121,7 +125,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if
(
not
self
.
_output_raw
):
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
def
_parse_mmcif
(
self
,
path
,
file_id
,
chain_id
,
alignment_dir
):
def
_parse_mmcif
(
self
,
path
,
file_id
,
chain_id
,
alignment_dir
,
_alignment_index
):
with
open
(
path
,
'r'
)
as
f
:
mmcif_string
=
f
.
read
()
...
...
@@ -140,6 +144,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif
=
mmcif_object
,
alignment_dir
=
alignment_dir
,
chain_id
=
chain_id
,
_alignment_index
=
_alignment_index
)
return
data
...
...
@@ -154,6 +159,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
name
=
self
.
idx_to_chain_id
(
idx
)
alignment_dir
=
os
.
path
.
join
(
self
.
alignment_dir
,
name
)
_alignment_index
=
None
if
(
self
.
_alignment_index
is
not
None
):
alignment_dir
=
self
.
alignment_dir
_alignment_index
=
self
.
_alignment_index
[
name
]
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
spl
=
name
.
rsplit
(
'_'
,
1
)
if
(
len
(
spl
)
==
2
):
...
...
@@ -165,11 +175,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path
=
os
.
path
.
join
(
self
.
data_dir
,
file_id
)
if
(
os
.
path
.
exists
(
path
+
".cif"
)):
data
=
self
.
_parse_mmcif
(
path
+
".cif"
,
file_id
,
chain_id
,
alignment_dir
,
path
+
".cif"
,
file_id
,
chain_id
,
alignment_dir
,
_alignment_index
,
)
elif
(
os
.
path
.
exists
(
path
+
".core"
)):
data
=
self
.
data_pipeline
.
process_core
(
path
+
".core"
,
alignment_dir
,
path
+
".core"
,
alignment_dir
,
_alignment_index
,
)
elif
(
os
.
path
.
exists
(
path
+
".pdb"
)):
data
=
self
.
data_pipeline
.
process_pdb
(
...
...
@@ -177,6 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir
=
alignment_dir
,
is_distillation
=
self
.
treat_pdb_as_distillation
,
chain_id
=
chain_id
,
_alignment_index
=
_alignment_index
,
)
else
:
raise
ValueError
(
"Invalid file type"
)
...
...
@@ -185,6 +196,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data
=
self
.
data_pipeline
.
process_fasta
(
fasta_path
=
path
,
alignment_dir
=
alignment_dir
,
_alignment_index
=
_alignment_index
,
)
if
(
self
.
_output_raw
):
...
...
@@ -201,16 +213,16 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def
deterministic_train_filter
(
chain
_data_cache_entry
:
Any
,
prot
_data_cache_entry
:
Any
,
max_resolution
:
float
=
9.
,
max_single_aa_prop
:
float
=
0.8
,
)
->
bool
:
# Hard filters
resolution
=
chain
_data_cache_entry
.
get
(
"resolution"
,
None
)
resolution
=
prot
_data_cache_entry
.
get
(
"resolution"
,
None
)
if
(
resolution
is
not
None
and
resolution
>
max_resolution
):
return
False
seq
=
chain
_data_cache_entry
[
"seq"
]
seq
=
prot
_data_cache_entry
[
"seq"
]
counts
=
{}
for
aa
in
seq
:
counts
.
setdefault
(
aa
,
0
)
...
...
@@ -224,16 +236,16 @@ def deterministic_train_filter(
def
get_stochastic_train_filter_prob
(
chain
_data_cache_entry
:
Any
,
prot
_data_cache_entry
:
Any
,
)
->
List
[
float
]:
# Stochastic filters
probabilities
=
[]
cluster_size
=
chain
_data_cache_entry
.
get
(
"cluster_size"
,
None
)
cluster_size
=
prot
_data_cache_entry
.
get
(
"cluster_size"
,
None
)
if
(
cluster_size
is
not
None
and
cluster_size
>
0
):
probabilities
.
append
(
1
/
cluster_size
)
chain_length
=
len
(
chain
_data_cache_entry
[
"seq"
])
chain_length
=
len
(
prot
_data_cache_entry
[
"seq"
])
probabilities
.
append
((
1
/
512
)
*
(
max
(
min
(
chain_length
,
512
),
256
)))
# Risk of underflow here?
...
...
@@ -255,7 +267,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets
:
Sequence
[
OpenFoldSingleDataset
],
probabilities
:
Sequence
[
int
],
epoch_len
:
int
,
chain
_data_cache_paths
:
List
[
str
],
prot
_data_cache_paths
:
List
[
str
],
generator
:
torch
.
Generator
=
None
,
_roll_at_init
:
bool
=
True
,
):
...
...
@@ -264,10 +276,10 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
chain
_data_caches
=
[]
for
path
in
chain
_data_cache_paths
:
self
.
prot
_data_caches
=
[]
for
path
in
prot
_data_cache_paths
:
with
open
(
path
,
"r"
)
as
fp
:
self
.
chain
_data_caches
.
append
(
json
.
load
(
fp
))
self
.
prot
_data_caches
.
append
(
json
.
load
(
fp
))
def
looped_shuffled_dataset_idx
(
dataset_len
):
while
True
:
...
...
@@ -286,19 +298,19 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len
=
int
(
epoch_len
*
probabilities
[
dataset_idx
])
dataset
=
self
.
datasets
[
dataset_idx
]
idx_iter
=
looped_shuffled_dataset_idx
(
len
(
dataset
))
chain
_data_cache
=
self
.
chain
_data_caches
[
dataset_idx
]
prot
_data_cache
=
self
.
prot
_data_caches
[
dataset_idx
]
while
True
:
weights
=
[]
idx
=
[]
for
_
in
range
(
max_cache_len
):
candidate_idx
=
next
(
idx_iter
)
chain_id
=
dataset
.
idx_to_chain_id
(
candidate_idx
)
chain
_data_cache_entry
=
chain
_data_cache
[
chain_id
]
if
(
not
deterministic_train_filter
(
chain
_data_cache_entry
)):
prot
_data_cache_entry
=
prot
_data_cache
[
chain_id
]
if
(
not
deterministic_train_filter
(
prot
_data_cache_entry
)):
continue
p
=
get_stochastic_train_filter_prob
(
chain
_data_cache_entry
,
prot
_data_cache_entry
,
)
weights
.
append
([
1.
-
p
,
p
])
idx
.
append
(
candidate_idx
)
...
...
@@ -459,10 +471,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_date
:
str
,
train_data_dir
:
Optional
[
str
]
=
None
,
train_alignment_dir
:
Optional
[
str
]
=
None
,
train_
chain
_data_cache_path
:
Optional
[
str
]
=
None
,
train_
prot
_data_cache_path
:
Optional
[
str
]
=
None
,
distillation_data_dir
:
Optional
[
str
]
=
None
,
distillation_alignment_dir
:
Optional
[
str
]
=
None
,
distillation_
chain
_data_cache_path
:
Optional
[
str
]
=
None
,
distillation_
prot
_data_cache_path
:
Optional
[
str
]
=
None
,
val_data_dir
:
Optional
[
str
]
=
None
,
val_alignment_dir
:
Optional
[
str
]
=
None
,
predict_data_dir
:
Optional
[
str
]
=
None
,
...
...
@@ -474,6 +486,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
batch_seed
:
Optional
[
int
]
=
None
,
train_epoch_len
:
int
=
50000
,
_alignment_index_path
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
(
OpenFoldDataModule
,
self
).
__init__
()
...
...
@@ -483,11 +496,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
max_template_date
=
max_template_date
self
.
train_data_dir
=
train_data_dir
self
.
train_alignment_dir
=
train_alignment_dir
self
.
train_
chain
_data_cache_path
=
train_
chain
_data_cache_path
self
.
train_
prot
_data_cache_path
=
train_
prot
_data_cache_path
self
.
distillation_data_dir
=
distillation_data_dir
self
.
distillation_alignment_dir
=
distillation_alignment_dir
self
.
distillation_
chain
_data_cache_path
=
(
distillation_
chain
_data_cache_path
self
.
distillation_
prot
_data_cache_path
=
(
distillation_
prot
_data_cache_path
)
self
.
val_data_dir
=
val_data_dir
self
.
val_alignment_dir
=
val_alignment_dir
...
...
@@ -525,6 +538,12 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
)
# An ad-hoc measure for our particular filesystem restrictions
self
.
_alignment_index
=
None
if
(
_alignment_index_path
is
not
None
):
with
open
(
_alignment_index_path
,
"r"
)
as
fp
:
self
.
_alignment_index
=
json
.
load
(
fp
)
def
setup
(
self
):
# Most of the arguments are the same for the three datasets
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
...
...
@@ -549,6 +568,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
treat_pdb_as_distillation
=
False
,
mode
=
"train"
,
_output_raw
=
True
,
_alignment_index
=
self
.
_alignment_index
,
)
distillation_dataset
=
None
...
...
@@ -569,22 +589,22 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets
=
[
train_dataset
,
distillation_dataset
]
d_prob
=
self
.
config
.
train
.
distillation_prob
probabilities
=
[
1
-
d_prob
,
d_prob
]
chain
_data_cache_paths
=
[
self
.
train_
chain
_data_cache_path
,
self
.
distillation_
chain
_data_cache_path
,
prot
_data_cache_paths
=
[
self
.
train_
prot
_data_cache_path
,
self
.
distillation_
prot
_data_cache_path
,
]
else
:
datasets
=
[
train_dataset
]
probabilities
=
[
1.
]
chain
_data_cache_paths
=
[
self
.
train_
chain
_data_cache_path
,
prot
_data_cache_paths
=
[
self
.
train_
prot
_data_cache_path
,
]
self
.
train_dataset
=
OpenFoldDataset
(
datasets
=
datasets
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
chain
_data_cache_paths
=
chain
_data_cache_paths
,
prot
_data_cache_paths
=
prot
_data_cache_paths
,
_roll_at_init
=
False
,
)
...
...
openfold/data/data_pipeline.py
View file @
a3e8ebbc
...
...
@@ -422,8 +422,38 @@ class DataPipeline:
def
_parse_msa_data
(
self
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
Any
]
=
None
,
)
->
Mapping
[
str
,
Any
]:
msa_data
=
{}
if
(
_alignment_index
is
not
None
):
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
_alignment_index
[
"db"
]),
"rb"
)
def
read_msa
(
start
,
size
):
fp
.
seek
(
start
)
msa
=
fp
.
read
(
size
).
decode
(
"utf-8"
)
return
msa
for
(
name
,
start
,
size
)
in
_alignment_index
[
"files"
]:
ext
=
os
.
path
.
splitext
(
name
)[
-
1
]
if
(
ext
==
".a3m"
):
msa
,
deletion_matrix
=
parsers
.
parse_a3m
(
read_msa
(
start
,
size
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
elif
(
ext
==
".sto"
):
msa
,
deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
read_msa
(
start
,
size
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
else
:
continue
msa_data
[
name
]
=
data
fp
.
close
()
else
:
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
...
...
@@ -448,8 +478,25 @@ class DataPipeline:
def
_parse_template_hits
(
self
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
Any
]
=
None
)
->
Mapping
[
str
,
Any
]:
all_hits
=
{}
if
(
_alignment_index
is
not
None
):
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
_alignment_index
[
"db"
]),
'rb'
)
def
read_template
(
start
,
size
):
fp
.
seek
(
start
)
return
fp
.
read
(
size
).
decode
(
"utf-8"
)
for
(
name
,
start
,
size
)
in
_alignment_index
[
"files"
]:
ext
=
os
.
path
.
splitext
(
name
)[
-
1
]
if
(
ext
==
".hhr"
):
hits
=
parsers
.
parse_hhr
(
read_template
(
start
,
size
))
all_hits
[
name
]
=
hits
fp
.
close
()
else
:
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
...
...
@@ -465,8 +512,9 @@ class DataPipeline:
self
,
alignment_dir
:
str
,
input_sequence
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
)
->
Mapping
[
str
,
Any
]:
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
)
msa_data
=
self
.
_parse_msa_data
(
alignment_dir
,
_alignment_index
)
if
(
len
(
msa_data
)
==
0
):
if
(
input_sequence
is
None
):
...
...
@@ -496,6 +544,7 @@ class DataPipeline:
self
,
fasta_path
:
str
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""Assembles features for a single sequence in a FASTA file"""
with
open
(
fasta_path
)
as
f
:
...
...
@@ -509,7 +558,7 @@ class DataPipeline:
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
...
...
@@ -522,7 +571,7 @@ class DataPipeline:
num_res
=
num_res
,
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
,
_alignment_index
)
return
{
**
sequence_features
,
...
...
@@ -535,6 +584,7 @@ class DataPipeline:
mmcif
:
mmcif_parsing
.
MmcifObject
,
# parsing is expensive, so no path
alignment_dir
:
str
,
chain_id
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""
Assembles features for a specific chain in an mmCIF object.
...
...
@@ -552,7 +602,7 @@ class DataPipeline:
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
...
...
@@ -560,7 +610,7 @@ class DataPipeline:
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
])
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
,
_alignment_index
)
return
{
**
mmcif_feats
,
**
template_features
,
**
msa_features
}
...
...
@@ -570,6 +620,7 @@ class DataPipeline:
alignment_dir
:
str
,
is_distillation
:
bool
=
True
,
chain_id
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""
Assembles features for a protein in a PDB file.
...
...
@@ -586,14 +637,14 @@ class DataPipeline:
is_distillation
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
self
.
template_featurizer
,
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
,
_alignment_index
)
return
{
**
pdb_feats
,
**
template_features
,
**
msa_features
}
...
...
@@ -601,6 +652,7 @@ class DataPipeline:
self
,
core_path
:
str
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""
Assembles features for a protein in a ProteinNet .core file.
...
...
@@ -613,7 +665,7 @@ class DataPipeline:
description
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
core_path
))[
0
].
upper
()
core_feats
=
make_protein_features
(
protein_object
,
description
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
...
...
openfold/np/residue_constants.py
View file @
a3e8ebbc
...
...
@@ -1301,3 +1301,10 @@ def _make_atom14_ambiguity_feats():
_make_atom14_ambiguity_feats
()
def
aatype_to_str_sequence
(
aatype
):
return
''
.
join
([
residue_constants
.
restypes_with_x
[
aatype
[
i
]]
for
i
in
range
(
len
(
aatype
))
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment