Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
1abe6160
Commit
1abe6160
authored
Aug 29, 2022
by
Tim O'Donnell
Browse files
fix
parent
5e341f60
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
4 deletions
+8
-4
openfold/data/data_modules.py
openfold/data/data_modules.py
+8
-4
No files found.
openfold/data/data_modules.py
View file @
1abe6160
...
...
@@ -24,11 +24,11 @@ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
data_dir
:
str
,
chain_data_cache_path
:
str
,
alignment_dir
:
str
,
template_mmcif_dir
:
str
,
max_template_date
:
str
,
config
:
mlc
.
ConfigDict
,
chain_data_cache_path
:
Optional
[
str
]
=
None
,
kalign_binary_path
:
str
=
'/usr/bin/kalign'
,
max_template_hits
:
int
=
4
,
obsolete_pdbs_file_path
:
Optional
[
str
]
=
None
,
...
...
@@ -82,6 +82,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
super
(
OpenFoldSingleDataset
,
self
).
__init__
()
self
.
data_dir
=
data_dir
self
.
chain_data_cache
=
None
if
chain_data_cache_path
is
not
None
:
with
open
(
chain_data_cache_path
,
"r"
)
as
fp
:
self
.
chain_data_cache
=
json
.
load
(
fp
)
assert
isinstance
(
self
.
chain_data_cache
,
dict
)
...
...
@@ -617,6 +619,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
self
.
training_mode
):
train_dataset
=
dataset_gen
(
data_dir
=
self
.
train_data_dir
,
chain_data_cache_path
=
self
.
train_chain_data_cache_path
,
alignment_dir
=
self
.
train_alignment_dir
,
filter_path
=
self
.
train_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
...
@@ -631,6 +634,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
if
(
self
.
distillation_data_dir
is
not
None
):
distillation_dataset
=
dataset_gen
(
data_dir
=
self
.
distillation_data_dir
,
chain_data_cache_path
=
self
.
distillation_chain_data_cache_path
,
alignment_dir
=
self
.
distillation_alignment_dir
,
filter_path
=
self
.
distillation_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment