Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9dd9cea4
Unverified
Commit
9dd9cea4
authored
Aug 30, 2022
by
Gustaf Ahdritz
Committed by
GitHub
Aug 30, 2022
Browse files
Merge pull request #210 from timodonnell/remove-chains-missing-data
Drop chains that are missing (structure) data in training
parents
12caaa89
f6d02cd9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
19 deletions
+46
-19
openfold/data/data_modules.py
openfold/data/data_modules.py
+46
-19
No files found.
openfold/data/data_modules.py
View file @
9dd9cea4
...
@@ -28,6 +28,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -28,6 +28,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_mmcif_dir
:
str
,
template_mmcif_dir
:
str
,
max_template_date
:
str
,
max_template_date
:
str
,
config
:
mlc
.
ConfigDict
,
config
:
mlc
.
ConfigDict
,
chain_data_cache_path
:
Optional
[
str
]
=
None
,
kalign_binary_path
:
str
=
'/usr/bin/kalign'
,
kalign_binary_path
:
str
=
'/usr/bin/kalign'
,
max_template_hits
:
int
=
4
,
max_template_hits
:
int
=
4
,
obsolete_pdbs_file_path
:
Optional
[
str
]
=
None
,
obsolete_pdbs_file_path
:
Optional
[
str
]
=
None
,
...
@@ -56,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -56,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
Path to a directory containing template mmCIF files.
Path to a directory containing template mmCIF files.
config:
config:
A dataset config object. See openfold.config
A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
kalign_binary_path:
kalign_binary_path:
Path to kalign binary.
Path to kalign binary.
max_template_hits:
max_template_hits:
...
@@ -80,6 +84,13 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -80,6 +84,13 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"""
"""
super
(
OpenFoldSingleDataset
,
self
).
__init__
()
super
(
OpenFoldSingleDataset
,
self
).
__init__
()
self
.
data_dir
=
data_dir
self
.
data_dir
=
data_dir
self
.
chain_data_cache
=
None
if
chain_data_cache_path
is
not
None
:
with
open
(
chain_data_cache_path
,
"r"
)
as
fp
:
self
.
chain_data_cache
=
json
.
load
(
fp
)
assert
isinstance
(
self
.
chain_data_cache
,
dict
)
self
.
alignment_dir
=
alignment_dir
self
.
alignment_dir
=
alignment_dir
self
.
config
=
config
self
.
config
=
config
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
...
@@ -109,7 +120,32 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -109,7 +120,32 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
with
open
(
filter_path
,
"r"
)
as
f
:
with
open
(
filter_path
,
"r"
)
as
f
:
chains_to_include
=
set
([
l
.
strip
()
for
l
in
f
.
readlines
()])
chains_to_include
=
set
([
l
.
strip
()
for
l
in
f
.
readlines
()])
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
if
self
.
chain_data_cache
is
not
None
:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids
=
self
.
_chain_ids
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
self
.
chain_data_cache
]
if
len
(
self
.
_chain_ids
)
<
len
(
original_chain_ids
):
missing
=
[
c
for
c
in
original_chain_ids
if
c
not
in
self
.
chain_data_cache
]
max_to_print
=
10
missing_examples
=
", "
.
join
(
missing
[:
max_to_print
])
if
len
(
missing
)
>
max_to_print
:
missing_examples
+=
", ..."
logging
.
warning
(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s)."
,
len
(
missing
),
missing_examples
,
chain_data_cache_path
)
self
.
_chain_id_to_idx_dict
=
{
self
.
_chain_id_to_idx_dict
=
{
chain
:
i
for
i
,
chain
in
enumerate
(
self
.
_chain_ids
)
chain
:
i
for
i
,
chain
in
enumerate
(
self
.
_chain_ids
)
...
@@ -234,7 +270,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
...
@@ -234,7 +270,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data
,
self
.
mode
data
,
self
.
mode
)
)
feats
[
"batch_idx"
]
=
torch
.
tensor
([
idx
for
_
in
range
(
feats
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
device
=
feats
[
"aatype"
].
device
)
feats
[
"batch_idx"
]
=
torch
.
tensor
(
[
idx
for
_
in
range
(
feats
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
device
=
feats
[
"aatype"
].
device
)
return
feats
return
feats
...
@@ -297,7 +336,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -297,7 +336,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets
:
Sequence
[
OpenFoldSingleDataset
],
datasets
:
Sequence
[
OpenFoldSingleDataset
],
probabilities
:
Sequence
[
int
],
probabilities
:
Sequence
[
int
],
epoch_len
:
int
,
epoch_len
:
int
,
chain_data_cache_paths
:
List
[
str
],
generator
:
torch
.
Generator
=
None
,
generator
:
torch
.
Generator
=
None
,
_roll_at_init
:
bool
=
True
,
_roll_at_init
:
bool
=
True
,
):
):
...
@@ -306,11 +344,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -306,11 +344,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self
.
epoch_len
=
epoch_len
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
generator
=
generator
self
.
chain_data_caches
=
[]
for
path
in
chain_data_cache_paths
:
with
open
(
path
,
"r"
)
as
fp
:
self
.
chain_data_caches
.
append
(
json
.
load
(
fp
))
def
looped_shuffled_dataset_idx
(
dataset_len
):
def
looped_shuffled_dataset_idx
(
dataset_len
):
while
True
:
while
True
:
# Uniformly shuffle each dataset's indices
# Uniformly shuffle each dataset's indices
...
@@ -328,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
...
@@ -328,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len
=
int
(
epoch_len
*
probabilities
[
dataset_idx
])
max_cache_len
=
int
(
epoch_len
*
probabilities
[
dataset_idx
])
dataset
=
self
.
datasets
[
dataset_idx
]
dataset
=
self
.
datasets
[
dataset_idx
]
idx_iter
=
looped_shuffled_dataset_idx
(
len
(
dataset
))
idx_iter
=
looped_shuffled_dataset_idx
(
len
(
dataset
))
chain_data_cache
=
self
.
chain_data_cache
s
[
dataset_idx
]
chain_data_cache
=
dataset
.
chain_data_cache
while
True
:
while
True
:
weights
=
[]
weights
=
[]
idx
=
[]
idx
=
[]
...
@@ -591,6 +624,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -591,6 +624,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
self
.
training_mode
):
if
(
self
.
training_mode
):
train_dataset
=
dataset_gen
(
train_dataset
=
dataset_gen
(
data_dir
=
self
.
train_data_dir
,
data_dir
=
self
.
train_data_dir
,
chain_data_cache_path
=
self
.
train_chain_data_cache_path
,
alignment_dir
=
self
.
train_alignment_dir
,
alignment_dir
=
self
.
train_alignment_dir
,
filter_path
=
self
.
train_filter_path
,
filter_path
=
self
.
train_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
@@ -605,6 +639,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -605,6 +639,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
self
.
distillation_data_dir
is
not
None
):
if
(
self
.
distillation_data_dir
is
not
None
):
distillation_dataset
=
dataset_gen
(
distillation_dataset
=
dataset_gen
(
data_dir
=
self
.
distillation_data_dir
,
data_dir
=
self
.
distillation_data_dir
,
chain_data_cache_path
=
self
.
distillation_chain_data_cache_path
,
alignment_dir
=
self
.
distillation_alignment_dir
,
alignment_dir
=
self
.
distillation_alignment_dir
,
filter_path
=
self
.
distillation_filter_path
,
filter_path
=
self
.
distillation_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
@@ -620,16 +655,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -620,16 +655,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets
=
[
train_dataset
,
distillation_dataset
]
datasets
=
[
train_dataset
,
distillation_dataset
]
d_prob
=
self
.
config
.
train
.
distillation_prob
d_prob
=
self
.
config
.
train
.
distillation_prob
probabilities
=
[
1.
-
d_prob
,
d_prob
]
probabilities
=
[
1.
-
d_prob
,
d_prob
]
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
self
.
distillation_chain_data_cache_path
,
]
else
:
else
:
datasets
=
[
train_dataset
]
datasets
=
[
train_dataset
]
probabilities
=
[
1.
]
probabilities
=
[
1.
]
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
]
generator
=
None
generator
=
None
if
(
self
.
batch_seed
is
not
None
):
if
(
self
.
batch_seed
is
not
None
):
...
@@ -640,7 +668,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -640,7 +668,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets
=
datasets
,
datasets
=
datasets
,
probabilities
=
probabilities
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
epoch_len
=
self
.
train_epoch_len
,
chain_data_cache_paths
=
chain_data_cache_paths
,
generator
=
generator
,
generator
=
generator
,
_roll_at_init
=
False
,
_roll_at_init
=
False
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment