Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
60e9bd54
Commit
60e9bd54
authored
Aug 29, 2022
by
Tim O'Donnell
Browse files
Drop alignments that are missing structure data in training
parent
12caaa89
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
19 deletions
+39
-19
openfold/data/data_modules.py
openfold/data/data_modules.py
+39
-19
No files found.
openfold/data/data_modules.py
View file @
60e9bd54
...
...
@@ -24,6 +24,7 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
data_dir
:
str
,
chain_data_cache_path
:
str
,
alignment_dir
:
str
,
template_mmcif_dir
:
str
,
max_template_date
:
str
,
...
...
@@ -80,6 +81,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"""
super
(
OpenFoldSingleDataset
,
self
).
__init__
()
self
.
data_dir
=
data_dir
with
open
(
chain_data_cache_path
,
"r"
)
as
fp
:
self
.
chain_data_cache
=
json
.
load
(
fp
)
assert
isinstance
(
self
.
chain_data_cache
,
dict
)
self
.
alignment_dir
=
alignment_dir
self
.
config
=
config
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
...
...
@@ -104,12 +110,35 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self
.
_chain_ids
=
list
(
alignment_index
.
keys
())
else
:
self
.
_chain_ids
=
list
(
os
.
listdir
(
alignment_dir
))
if
(
filter_path
is
not
None
):
with
open
(
filter_path
,
"r"
)
as
f
:
chains_to_include
=
set
([
l
.
strip
()
for
l
in
f
.
readlines
()])
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
# Filter to include only chains where we have structure data
# (i.e. entries in chain_data_cache)
original_chain_ids
=
self
.
_chain_ids
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
self
.
chain_data_cache
]
if
len
(
self
.
_chain_ids
)
<
len
(
original_chain_ids
):
missing
=
[
c
for
c
in
original_chain_ids
if
c
not
in
self
.
chain_data_cache
]
max_to_print
=
10
missing_examples
=
", "
.
join
(
missing
[:
max_to_print
])
if
len
(
missing
)
>
max_to_print
:
missing_examples
+=
", ..."
logging
.
warning
(
"Ignoring %d alignment entries (%s) that have no corresponding "
"entries in chain_data_cache (%s)."
,
len
(
missing
),
missing_examples
,
chain_data_cache_path
)
self
.
_chain_id_to_idx_dict
=
{
chain
:
i
for
i
,
chain
in
enumerate
(
self
.
_chain_ids
)
...
...
@@ -234,7 +263,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data
,
self
.
mode
)
feats
[
"batch_idx"
]
=
torch
.
tensor
([
idx
for
_
in
range
(
feats
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
device
=
feats
[
"aatype"
].
device
)
feats
[
"batch_idx"
]
=
torch
.
tensor
(
[
idx
for
_
in
range
(
feats
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
device
=
feats
[
"aatype"
].
device
)
return
feats
...
...
@@ -297,7 +329,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets
:
Sequence
[
OpenFoldSingleDataset
],
probabilities
:
Sequence
[
int
],
epoch_len
:
int
,
chain_data_cache_paths
:
List
[
str
],
generator
:
torch
.
Generator
=
None
,
_roll_at_init
:
bool
=
True
,
):
...
...
@@ -305,11 +336,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self
.
probabilities
=
probabilities
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
chain_data_caches
=
[]
for
path
in
chain_data_cache_paths
:
with
open
(
path
,
"r"
)
as
fp
:
self
.
chain_data_caches
.
append
(
json
.
load
(
fp
))
def
looped_shuffled_dataset_idx
(
dataset_len
):
while
True
:
...
...
@@ -328,7 +354,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len
=
int
(
epoch_len
*
probabilities
[
dataset_idx
])
dataset
=
self
.
datasets
[
dataset_idx
]
idx_iter
=
looped_shuffled_dataset_idx
(
len
(
dataset
))
chain_data_cache
=
self
.
chain_data_cache
s
[
dataset_idx
]
chain_data_cache
=
dataset
.
chain_data_cache
while
True
:
weights
=
[]
idx
=
[]
...
...
@@ -591,6 +617,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
self
.
training_mode
):
train_dataset
=
dataset_gen
(
data_dir
=
self
.
train_data_dir
,
chain_data_cache_path
=
self
.
train_data_cache_path
,
alignment_dir
=
self
.
train_alignment_dir
,
filter_path
=
self
.
train_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
...
@@ -605,6 +632,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
self
.
distillation_data_dir
is
not
None
):
distillation_dataset
=
dataset_gen
(
data_dir
=
self
.
distillation_data_dir
,
chain_data_cache_path
=
self
.
distillation_chain_data_cache_path
,
alignment_dir
=
self
.
distillation_alignment_dir
,
filter_path
=
self
.
distillation_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
...
@@ -620,16 +648,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets
=
[
train_dataset
,
distillation_dataset
]
d_prob
=
self
.
config
.
train
.
distillation_prob
probabilities
=
[
1.
-
d_prob
,
d_prob
]
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
self
.
distillation_chain_data_cache_path
,
]
else
:
datasets
=
[
train_dataset
]
probabilities
=
[
1.
]
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
]
probabilities
=
[
1.
]
generator
=
None
if
(
self
.
batch_seed
is
not
None
):
...
...
@@ -640,7 +661,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets
=
datasets
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
chain_data_cache_paths
=
chain_data_cache_paths
,
generator
=
generator
,
_roll_at_init
=
False
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment