Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
60e9bd54
Commit
60e9bd54
authored
Aug 29, 2022
by
Tim O'Donnell
Browse files
Drop alignments that are missing structure data in training
parent
12caaa89
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
19 deletions
+39
-19
openfold/data/data_modules.py
openfold/data/data_modules.py
+39
-19
No files found.
openfold/data/data_modules.py
View file @
60e9bd54
...
@@ -24,6 +24,7 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
...
@@ -24,6 +24,7 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
def
__init__
(
self
,
data_dir
:
str
,
data_dir
:
str
,
chain_data_cache_path
:
str
,
alignment_dir
:
str
,
alignment_dir
:
str
,
template_mmcif_dir
:
str
,
template_mmcif_dir
:
str
,
max_template_date
:
str
,
max_template_date
:
str
,
...
@@ -80,6 +81,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -80,6 +81,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"""
"""
super
(
OpenFoldSingleDataset
,
self
).
__init__
()
super
(
OpenFoldSingleDataset
,
self
).
__init__
()
self
.
data_dir
=
data_dir
self
.
data_dir
=
data_dir
with
open
(
chain_data_cache_path
,
"r"
)
as
fp
:
self
.
chain_data_cache
=
json
.
load
(
fp
)
assert
isinstance
(
self
.
chain_data_cache
,
dict
)
self
.
alignment_dir
=
alignment_dir
self
.
alignment_dir
=
alignment_dir
self
.
config
=
config
self
.
config
=
config
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
...
@@ -109,7 +115,30 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -109,7 +115,30 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
with
open
(
filter_path
,
"r"
)
as
f
:
with
open
(
filter_path
,
"r"
)
as
f
:
chains_to_include
=
set
([
l
.
strip
()
for
l
in
f
.
readlines
()])
chains_to_include
=
set
([
l
.
strip
()
for
l
in
f
.
readlines
()])
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
# Filter to include only chains where we have structure data
# (i.e. entries in chain_data_cache)
original_chain_ids
=
self
.
_chain_ids
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
self
.
chain_data_cache
]
if
len
(
self
.
_chain_ids
)
<
len
(
original_chain_ids
):
missing
=
[
c
for
c
in
original_chain_ids
if
c
not
in
self
.
chain_data_cache
]
max_to_print
=
10
missing_examples
=
", "
.
join
(
missing
[:
max_to_print
])
if
len
(
missing
)
>
max_to_print
:
missing_examples
+=
", ..."
logging
.
warning
(
"Ignoring %d alignment entries (%s) that have no corresponding "
"entries in chain_data_cache (%s)."
,
len
(
missing
),
missing_examples
,
chain_data_cache_path
)
self
.
_chain_id_to_idx_dict
=
{
self
.
_chain_id_to_idx_dict
=
{
chain
:
i
for
i
,
chain
in
enumerate
(
self
.
_chain_ids
)
chain
:
i
for
i
,
chain
in
enumerate
(
self
.
_chain_ids
)
...
@@ -234,7 +263,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -234,7 +263,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data
,
self
.
mode
data
,
self
.
mode
)
)
feats
[
"batch_idx"
]
=
torch
.
tensor
([
idx
for
_
in
range
(
feats
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
device
=
feats
[
"aatype"
].
device
)
feats
[
"batch_idx"
]
=
torch
.
tensor
(
[
idx
for
_
in
range
(
feats
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
device
=
feats
[
"aatype"
].
device
)
return
feats
return
feats
...
@@ -297,7 +329,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -297,7 +329,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets
:
Sequence
[
OpenFoldSingleDataset
],
datasets
:
Sequence
[
OpenFoldSingleDataset
],
probabilities
:
Sequence
[
int
],
probabilities
:
Sequence
[
int
],
epoch_len
:
int
,
epoch_len
:
int
,
chain_data_cache_paths
:
List
[
str
],
generator
:
torch
.
Generator
=
None
,
generator
:
torch
.
Generator
=
None
,
_roll_at_init
:
bool
=
True
,
_roll_at_init
:
bool
=
True
,
):
):
...
@@ -306,11 +337,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -306,11 +337,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self
.
epoch_len
=
epoch_len
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
generator
=
generator
self
.
chain_data_caches
=
[]
for
path
in
chain_data_cache_paths
:
with
open
(
path
,
"r"
)
as
fp
:
self
.
chain_data_caches
.
append
(
json
.
load
(
fp
))
def
looped_shuffled_dataset_idx
(
dataset_len
):
def
looped_shuffled_dataset_idx
(
dataset_len
):
while
True
:
while
True
:
# Uniformly shuffle each dataset's indices
# Uniformly shuffle each dataset's indices
...
@@ -328,7 +354,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -328,7 +354,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len
=
int
(
epoch_len
*
probabilities
[
dataset_idx
])
max_cache_len
=
int
(
epoch_len
*
probabilities
[
dataset_idx
])
dataset
=
self
.
datasets
[
dataset_idx
]
dataset
=
self
.
datasets
[
dataset_idx
]
idx_iter
=
looped_shuffled_dataset_idx
(
len
(
dataset
))
idx_iter
=
looped_shuffled_dataset_idx
(
len
(
dataset
))
chain_data_cache
=
self
.
chain_data_cache
s
[
dataset_idx
]
chain_data_cache
=
dataset
.
chain_data_cache
while
True
:
while
True
:
weights
=
[]
weights
=
[]
idx
=
[]
idx
=
[]
...
@@ -591,6 +617,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -591,6 +617,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
self
.
training_mode
):
if
(
self
.
training_mode
):
train_dataset
=
dataset_gen
(
train_dataset
=
dataset_gen
(
data_dir
=
self
.
train_data_dir
,
data_dir
=
self
.
train_data_dir
,
chain_data_cache_path
=
self
.
train_data_cache_path
,
alignment_dir
=
self
.
train_alignment_dir
,
alignment_dir
=
self
.
train_alignment_dir
,
filter_path
=
self
.
train_filter_path
,
filter_path
=
self
.
train_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
@@ -605,6 +632,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -605,6 +632,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
self
.
distillation_data_dir
is
not
None
):
if
(
self
.
distillation_data_dir
is
not
None
):
distillation_dataset
=
dataset_gen
(
distillation_dataset
=
dataset_gen
(
data_dir
=
self
.
distillation_data_dir
,
data_dir
=
self
.
distillation_data_dir
,
chain_data_cache_path
=
self
.
distillation_chain_data_cache_path
,
alignment_dir
=
self
.
distillation_alignment_dir
,
alignment_dir
=
self
.
distillation_alignment_dir
,
filter_path
=
self
.
distillation_filter_path
,
filter_path
=
self
.
distillation_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
@@ -620,16 +648,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -620,16 +648,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets
=
[
train_dataset
,
distillation_dataset
]
datasets
=
[
train_dataset
,
distillation_dataset
]
d_prob
=
self
.
config
.
train
.
distillation_prob
d_prob
=
self
.
config
.
train
.
distillation_prob
probabilities
=
[
1.
-
d_prob
,
d_prob
]
probabilities
=
[
1.
-
d_prob
,
d_prob
]
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
self
.
distillation_chain_data_cache_path
,
]
else
:
else
:
datasets
=
[
train_dataset
]
datasets
=
[
train_dataset
]
probabilities
=
[
1.
]
probabilities
=
[
1.
]
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
]
generator
=
None
generator
=
None
if
(
self
.
batch_seed
is
not
None
):
if
(
self
.
batch_seed
is
not
None
):
...
@@ -640,7 +661,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -640,7 +661,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets
=
datasets
,
datasets
=
datasets
,
probabilities
=
probabilities
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
epoch_len
=
self
.
train_epoch_len
,
chain_data_cache_paths
=
chain_data_cache_paths
,
generator
=
generator
,
generator
=
generator
,
_roll_at_init
=
False
,
_roll_at_init
=
False
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment