Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9dd9cea4
Unverified
Commit
9dd9cea4
authored
Aug 30, 2022
by
Gustaf Ahdritz
Committed by
GitHub
Aug 30, 2022
Browse files
Merge pull request #210 from timodonnell/remove-chains-missing-data
Drop chains that are missing (structure) data in training
parents
12caaa89
f6d02cd9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
19 deletions
+46
-19
openfold/data/data_modules.py
openfold/data/data_modules.py
+46
-19
No files found.
openfold/data/data_modules.py
View file @
9dd9cea4
...
...
@@ -28,6 +28,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
template_mmcif_dir
:
str
,
max_template_date
:
str
,
config
:
mlc
.
ConfigDict
,
chain_data_cache_path
:
Optional
[
str
]
=
None
,
kalign_binary_path
:
str
=
'/usr/bin/kalign'
,
max_template_hits
:
int
=
4
,
obsolete_pdbs_file_path
:
Optional
[
str
]
=
None
,
...
...
@@ -56,6 +57,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
chain_data_cache_path:
Path to cache of data_dir generated by
scripts/generate_chain_data_cache.py
kalign_binary_path:
Path to kalign binary.
max_template_hits:
...
...
@@ -80,6 +84,13 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"""
super
(
OpenFoldSingleDataset
,
self
).
__init__
()
self
.
data_dir
=
data_dir
self
.
chain_data_cache
=
None
if
chain_data_cache_path
is
not
None
:
with
open
(
chain_data_cache_path
,
"r"
)
as
fp
:
self
.
chain_data_cache
=
json
.
load
(
fp
)
assert
isinstance
(
self
.
chain_data_cache
,
dict
)
self
.
alignment_dir
=
alignment_dir
self
.
config
=
config
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
...
...
@@ -109,7 +120,32 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
with
open
(
filter_path
,
"r"
)
as
f
:
chains_to_include
=
set
([
l
.
strip
()
for
l
in
f
.
readlines
()])
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
if
self
.
chain_data_cache
is
not
None
:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids
=
self
.
_chain_ids
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
self
.
chain_data_cache
]
if
len
(
self
.
_chain_ids
)
<
len
(
original_chain_ids
):
missing
=
[
c
for
c
in
original_chain_ids
if
c
not
in
self
.
chain_data_cache
]
max_to_print
=
10
missing_examples
=
", "
.
join
(
missing
[:
max_to_print
])
if
len
(
missing
)
>
max_to_print
:
missing_examples
+=
", ..."
logging
.
warning
(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s)."
,
len
(
missing
),
missing_examples
,
chain_data_cache_path
)
self
.
_chain_id_to_idx_dict
=
{
chain
:
i
for
i
,
chain
in
enumerate
(
self
.
_chain_ids
)
...
...
@@ -234,7 +270,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data
,
self
.
mode
)
feats
[
"batch_idx"
]
=
torch
.
tensor
([
idx
for
_
in
range
(
feats
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
device
=
feats
[
"aatype"
].
device
)
feats
[
"batch_idx"
]
=
torch
.
tensor
(
[
idx
for
_
in
range
(
feats
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
device
=
feats
[
"aatype"
].
device
)
return
feats
...
...
@@ -297,7 +336,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets
:
Sequence
[
OpenFoldSingleDataset
],
probabilities
:
Sequence
[
int
],
epoch_len
:
int
,
chain_data_cache_paths
:
List
[
str
],
generator
:
torch
.
Generator
=
None
,
_roll_at_init
:
bool
=
True
,
):
...
...
@@ -306,11 +344,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self
.
epoch_len
=
epoch_len
self
.
generator
=
generator
self
.
chain_data_caches
=
[]
for
path
in
chain_data_cache_paths
:
with
open
(
path
,
"r"
)
as
fp
:
self
.
chain_data_caches
.
append
(
json
.
load
(
fp
))
def
looped_shuffled_dataset_idx
(
dataset_len
):
while
True
:
# Uniformly shuffle each dataset's indices
...
...
@@ -328,7 +361,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len
=
int
(
epoch_len
*
probabilities
[
dataset_idx
])
dataset
=
self
.
datasets
[
dataset_idx
]
idx_iter
=
looped_shuffled_dataset_idx
(
len
(
dataset
))
chain_data_cache
=
self
.
chain_data_cache
s
[
dataset_idx
]
chain_data_cache
=
dataset
.
chain_data_cache
while
True
:
weights
=
[]
idx
=
[]
...
...
@@ -591,6 +624,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
self
.
training_mode
):
train_dataset
=
dataset_gen
(
data_dir
=
self
.
train_data_dir
,
chain_data_cache_path
=
self
.
train_chain_data_cache_path
,
alignment_dir
=
self
.
train_alignment_dir
,
filter_path
=
self
.
train_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
...
@@ -605,6 +639,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
self
.
distillation_data_dir
is
not
None
):
distillation_dataset
=
dataset_gen
(
data_dir
=
self
.
distillation_data_dir
,
chain_data_cache_path
=
self
.
distillation_chain_data_cache_path
,
alignment_dir
=
self
.
distillation_alignment_dir
,
filter_path
=
self
.
distillation_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
...
@@ -620,16 +655,9 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets
=
[
train_dataset
,
distillation_dataset
]
d_prob
=
self
.
config
.
train
.
distillation_prob
probabilities
=
[
1.
-
d_prob
,
d_prob
]
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
self
.
distillation_chain_data_cache_path
,
]
else
:
datasets
=
[
train_dataset
]
probabilities
=
[
1.
]
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
]
generator
=
None
if
(
self
.
batch_seed
is
not
None
):
...
...
@@ -640,7 +668,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets
=
datasets
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
chain_data_cache_paths
=
chain_data_cache_paths
,
generator
=
generator
,
_roll_at_init
=
False
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment