Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
56632d44
Commit
56632d44
authored
Dec 02, 2021
by
Gustaf Ahdritz
Browse files
Add non-distillation PDB option to data module
parent
f707a9ea
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
2 deletions
+11
-2
openfold/data/data_modules.py
openfold/data/data_modules.py
+11
-2
No files found.
openfold/data/data_modules.py
View file @
56632d44
...
...
@@ -33,6 +33,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
max_template_hits
:
int
=
4
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
shuffle_top_k_prefiltered
:
Optional
[
int
]
=
None
,
treat_pdb_as_distillation
:
bool
=
True
,
mode
:
str
=
"train"
,
_output_raw
:
bool
=
False
,
):
...
...
@@ -71,6 +72,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
parsing max_template_hits of them. Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
treat_pdb_as_distillation:
Whether to assume that .pdb files in the data_dir are from
the self-distillation set (and should be subjected to
special distillation set preprocessing steps).
mode:
"train", "val", or "predict"
"""
...
...
@@ -78,6 +83,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self
.
data_dir
=
data_dir
self
.
alignment_dir
=
alignment_dir
self
.
config
=
config
self
.
treat_pdb_as_distillation
=
treat_pdb_as_distillation
self
.
mode
=
mode
self
.
_output_raw
=
_output_raw
...
...
@@ -162,10 +168,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path
+
".core"
,
alignment_dir
)
else
:
# Try to search for a distillation PDB file instead
data
=
self
.
data_pipeline
.
process_pdb
(
pdb_path
=
path
+
".pdb"
,
alignment_dir
=
alignment_dir
alignment_dir
=
alignment_dir
,
is_distillation
=
self
.
treat_pdb_as_distillation
,
chain_id
=
chain_id
,
)
else
:
path
=
os
.
path
.
join
(
name
,
name
+
".fasta"
)
...
...
@@ -387,6 +394,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
shuffle_top_k_prefiltered
=
self
.
config
.
train
.
shuffle_top_k_prefiltered
,
treat_pdb_as_distillation
=
False
,
mode
=
"train"
,
_output_raw
=
True
,
)
...
...
@@ -397,6 +405,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
alignment_dir
=
self
.
distillation_alignment_dir
,
mapping_path
=
self
.
distillation_mapping_path
,
max_template_hits
=
self
.
train
.
max_template_hits
,
treat_pdb_as_distillation
=
True
,
mode
=
"train"
,
_output_raw
=
True
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment