OpenDAS / FastFold · Commits · 164f6777

Unverified commit 164f6777, authored Dec 02, 2022 by LuGY, committed by GitHub on Dec 02, 2022

modify data module for train (#116)

parent 6fbc402e

Showing 2 changed files with 152 additions and 226 deletions:

  fastfold/data/data_modules.py    +149  -225
  fastfold/utils/tensor_utils.py     +3    -1
fastfold/data/data_modules.py

+# Copyright 2022 HPC-AI Tech Inc.
 # Copyright 2021 AlQuraishi Laboratory
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -12,19 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 from functools import partial
 import json
 import logging
 import os
-import pickle
 from typing import Optional, Sequence, List, Any

 import ml_collections as mlc
-import numpy as np
-import pytorch_lightning as pl
 import torch
-from torch.utils.data import RandomSampler

 from fastfold.data import (
     data_pipeline,
 ...
@@ -217,7 +213,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             return data

         feats = self.feature_pipeline.process_features(
             data, self.mode
         )

         return feats
 ...
@@ -380,8 +376,10 @@ class OpenFoldBatchCollator:
                 prot, self.stage
             )
             processed_prots.append(features)

+        # By this stack, the batch dimension is processed and added.
         stack_fn = partial(torch.stack, dim=0)
+        # I have modified some code. Now if bs=1, the shape will be [...] rather than [1, ...]
+        # If bs>1 (not allowed), the shape would still be [2, ...]
         return dict_multimap(stack_fn, processed_prots)
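The three added comments describe a behavioral change that actually lives in dict_multimap (see the fastfold/utils/tensor_utils.py hunk at the end of this diff). A minimal, self-contained sketch of what the collator's stacking now produces: the dict_multimap body is reconstructed from that hunk plus the upstream OpenFold implementation, and the feature name and shapes are invented for illustration.

    from functools import partial

    import torch

    # dict_multimap as patched in this commit (fastfold/utils/tensor_utils.py):
    # maps fn over parallel feature dicts, recursing into nested dicts, but
    # skips the stack when only a single example is being collated.
    def dict_multimap(fn, dicts):
        first = dicts[0]
        new_dict = {}
        for k, v in first.items():
            all_v = [d[k] for d in dicts]
            if type(v) is dict:
                new_dict[k] = dict_multimap(fn, all_v)
            else:
                # when bs = 1, returns [...] rather than [1, ...]
                new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0]
        return new_dict

    stack_fn = partial(torch.stack, dim=0)
    feats = {"aatype": torch.zeros(256, 21)}  # hypothetical feature dict

    print(dict_multimap(stack_fn, [feats])["aatype"].shape)         # torch.Size([256, 21])
    print(dict_multimap(stack_fn, [feats, feats])["aatype"].shape)  # torch.Size([2, 256, 21])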
 ...
@@ -478,230 +476,156 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         return _batch_prop_gen(it)


-class OpenFoldDataModule(pl.LightningDataModule):
-    def __init__(self,
-        config: mlc.ConfigDict,
-        template_mmcif_dir: str,
-        max_template_date: str,
-        train_data_dir: Optional[str] = None,
-        train_alignment_dir: Optional[str] = None,
-        train_chain_data_cache_path: Optional[str] = None,
-        distillation_data_dir: Optional[str] = None,
-        distillation_alignment_dir: Optional[str] = None,
-        distillation_chain_data_cache_path: Optional[str] = None,
-        val_data_dir: Optional[str] = None,
-        val_alignment_dir: Optional[str] = None,
-        predict_data_dir: Optional[str] = None,
-        predict_alignment_dir: Optional[str] = None,
-        kalign_binary_path: str = '/usr/bin/kalign',
-        train_mapping_path: Optional[str] = None,
-        distillation_mapping_path: Optional[str] = None,
-        obsolete_pdbs_file_path: Optional[str] = None,
-        template_release_dates_cache_path: Optional[str] = None,
-        batch_seed: Optional[int] = None,
-        train_epoch_len: int = 50000,
-        _alignment_index_path: Optional[str] = None,
-        **kwargs
-    ):
-        super(OpenFoldDataModule, self).__init__()
-
-        self.config = config
-        self.template_mmcif_dir = template_mmcif_dir
-        self.max_template_date = max_template_date
-        self.train_data_dir = train_data_dir
-        self.train_alignment_dir = train_alignment_dir
-        self.train_chain_data_cache_path = train_chain_data_cache_path
-        self.distillation_data_dir = distillation_data_dir
-        self.distillation_alignment_dir = distillation_alignment_dir
-        self.distillation_chain_data_cache_path = (
-            distillation_chain_data_cache_path
-        )
-        self.val_data_dir = val_data_dir
-        self.val_alignment_dir = val_alignment_dir
-        self.predict_data_dir = predict_data_dir
-        self.predict_alignment_dir = predict_alignment_dir
-        self.kalign_binary_path = kalign_binary_path
-        self.train_mapping_path = train_mapping_path
-        self.distillation_mapping_path = distillation_mapping_path
-        self.template_release_dates_cache_path = (
-            template_release_dates_cache_path
-        )
-        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
-        self.batch_seed = batch_seed
-        self.train_epoch_len = train_epoch_len
-
-        if (self.train_data_dir is None and self.predict_data_dir is None):
-            raise ValueError(
-                'At least one of train_data_dir or predict_data_dir must be '
-                'specified'
-            )
-
-        self.training_mode = self.train_data_dir is not None
-
-        if (self.training_mode and train_alignment_dir is None):
-            raise ValueError(
-                'In training mode, train_alignment_dir must be specified'
-            )
-        elif (not self.training_mode and predict_alignment_dir is None):
-            raise ValueError(
-                'In inference mode, predict_alignment_dir must be specified'
-            )
-        elif (val_data_dir is not None and val_alignment_dir is None):
-            raise ValueError(
-                'If val_data_dir is specified, val_alignment_dir must '
-                'be specified as well'
-            )
-
-        # An ad-hoc measure for our particular filesystem restrictions
-        self._alignment_index = None
-        if (_alignment_index_path is not None):
-            with open(_alignment_index_path, "r") as fp:
-                self._alignment_index = json.load(fp)
-
-    def setup(self):
-        # Most of the arguments are the same for the three datasets
-        dataset_gen = partial(OpenFoldSingleDataset,
-            template_mmcif_dir=self.template_mmcif_dir,
-            max_template_date=self.max_template_date,
-            config=self.config,
-            kalign_binary_path=self.kalign_binary_path,
-            template_release_dates_cache_path=
-                self.template_release_dates_cache_path,
-            obsolete_pdbs_file_path=
-                self.obsolete_pdbs_file_path,
-        )
-
-        if (self.training_mode):
-            train_dataset = dataset_gen(
-                data_dir=self.train_data_dir,
-                alignment_dir=self.train_alignment_dir,
-                mapping_path=self.train_mapping_path,
-                max_template_hits=self.config.train.max_template_hits,
-                shuffle_top_k_prefiltered=
-                    self.config.train.shuffle_top_k_prefiltered,
-                treat_pdb_as_distillation=False,
-                mode="train",
-                _output_raw=True,
-                _alignment_index=self._alignment_index,
-            )
-
-            distillation_dataset = None
-            if (self.distillation_data_dir is not None):
-                distillation_dataset = dataset_gen(
-                    data_dir=self.distillation_data_dir,
-                    alignment_dir=self.distillation_alignment_dir,
-                    mapping_path=self.distillation_mapping_path,
-                    max_template_hits=self.config.train.max_template_hits,
-                    treat_pdb_as_distillation=True,
-                    mode="train",
-                    _output_raw=True,
-                )
-
-                d_prob = self.config.train.distillation_prob
-
-            if (distillation_dataset is not None):
-                datasets = [train_dataset, distillation_dataset]
-                d_prob = self.config.train.distillation_prob
-                probabilities = [1 - d_prob, d_prob]
-                chain_data_cache_paths = [
-                    self.train_chain_data_cache_path,
-                    self.distillation_chain_data_cache_path,
-                ]
-            else:
-                datasets = [train_dataset]
-                probabilities = [1.]
-                chain_data_cache_paths = [
-                    self.train_chain_data_cache_path,
-                ]
-
-            self.train_dataset = OpenFoldDataset(
-                datasets=datasets,
-                probabilities=probabilities,
-                epoch_len=self.train_epoch_len,
-                chain_data_cache_paths=chain_data_cache_paths,
-                _roll_at_init=False,
-            )
-
-            if (self.val_data_dir is not None):
-                self.eval_dataset = dataset_gen(
-                    data_dir=self.val_data_dir,
-                    alignment_dir=self.val_alignment_dir,
-                    mapping_path=None,
-                    max_template_hits=self.config.eval.max_template_hits,
-                    mode="eval",
-                    _output_raw=True,
-                )
-            else:
-                self.eval_dataset = None
-        else:
-            self.predict_dataset = dataset_gen(
-                data_dir=self.predict_data_dir,
-                alignment_dir=self.predict_alignment_dir,
-                mapping_path=None,
-                max_template_hits=self.config.predict.max_template_hits,
-                mode="predict",
-            )
-
-    def _gen_dataloader(self, stage):
-        generator = torch.Generator()
-        if (self.batch_seed is not None):
-            generator = generator.manual_seed(self.batch_seed)
-
-        dataset = None
-        if (stage == "train"):
-            dataset = self.train_dataset
-            # Filter the dataset, if necessary
-            dataset.reroll()
-        elif (stage == "eval"):
-            dataset = self.eval_dataset
-        elif (stage == "predict"):
-            dataset = self.predict_dataset
-        else:
-            raise ValueError("Invalid stage")
-
-        batch_collator = OpenFoldBatchCollator(self.config, stage)
-
-        dl = OpenFoldDataLoader(
-            dataset,
-            config=self.config,
-            stage=stage,
-            generator=generator,
-            batch_size=self.config.data_module.data_loaders.batch_size,
-            num_workers=self.config.data_module.data_loaders.num_workers,
-            collate_fn=batch_collator,
-        )
-
-        return dl
-
-    def train_dataloader(self):
-        return self._gen_dataloader("train")
-
-    def val_dataloader(self):
-        if (self.eval_dataset is not None):
-            return self._gen_dataloader("eval")
-        return None
-
-    def predict_dataloader(self):
-        return self._gen_dataloader("predict")
-
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __init__(self, batch_path):
-        with open(batch_path, "rb") as f:
-            self.batch = pickle.load(f)
-
-    def __getitem__(self, idx):
-        return copy.deepcopy(self.batch)
-
-    def __len__(self):
-        return 1000
-
-
-class DummyDataLoader(pl.LightningDataModule):
-    def __init__(self, batch_path):
-        super().__init__()
-        self.dataset = DummyDataset(batch_path)
-
-    def train_dataloader(self):
-        return torch.utils.data.DataLoader(self.dataset)
+def SetupTrainDataset(
+    config: mlc.ConfigDict,
+    template_mmcif_dir: str,
+    max_template_date: str,
+    train_data_dir: Optional[str] = None,
+    train_alignment_dir: Optional[str] = None,
+    train_chain_data_cache_path: Optional[str] = None,
+    distillation_data_dir: Optional[str] = None,
+    distillation_alignment_dir: Optional[str] = None,
+    distillation_chain_data_cache_path: Optional[str] = None,
+    val_data_dir: Optional[str] = None,
+    val_alignment_dir: Optional[str] = None,
+    kalign_binary_path: str = '/usr/bin/kalign',
+    train_mapping_path: Optional[str] = None,
+    distillation_mapping_path: Optional[str] = None,
+    obsolete_pdbs_file_path: Optional[str] = None,
+    template_release_dates_cache_path: Optional[str] = None,
+    train_epoch_len: int = 50000,
+    _alignment_index_path: Optional[str] = None,
+    **kwargs,
+):
+    if (train_data_dir is None or train_alignment_dir is None):
+        raise ValueError(
+            'train_data_dir and train_alignment_dir must be specified'
+        )
+    elif (val_data_dir is not None and val_alignment_dir is None):
+        raise ValueError(
+            'If val_data_dir is specified, val_alignment_dir must '
+            'be specified as well'
+        )
+
+    _alignment_index = None
+    if (_alignment_index_path is not None):
+        with open(_alignment_index_path, "r") as fp:
+            _alignment_index = json.load(fp)
+
+    dataset_gen = partial(OpenFoldSingleDataset,
+                          template_mmcif_dir=template_mmcif_dir,
+                          max_template_date=max_template_date,
+                          config=config,
+                          kalign_binary_path=kalign_binary_path,
+                          template_release_dates_cache_path=
+                              template_release_dates_cache_path,
+                          obsolete_pdbs_file_path=
+                              obsolete_pdbs_file_path,
+                          )
+
+    train_dataset = dataset_gen(data_dir=train_data_dir,
+                                alignment_dir=train_alignment_dir,
+                                mapping_path=train_mapping_path,
+                                max_template_hits=config.train.max_template_hits,
+                                shuffle_top_k_prefiltered=
+                                    config.train.shuffle_top_k_prefiltered,
+                                treat_pdb_as_distillation=False,
+                                mode="train",
+                                _output_raw=True,
+                                _alignment_index=_alignment_index,
+                                )
+
+    distillation_dataset = None
+    if (distillation_data_dir is not None):
+        distillation_dataset = dataset_gen(
+            data_dir=distillation_data_dir,
+            alignment_dir=distillation_alignment_dir,
+            mapping_path=distillation_mapping_path,
+            max_template_hits=config.train.max_template_hits,
+            treat_pdb_as_distillation=True,
+            mode="train",
+            _output_raw=True,
+        )
+
+    if (distillation_dataset is not None):
+        datasets = [train_dataset, distillation_dataset]
+        d_prob = config.train.distillation_prob
+        probabilities = [1 - d_prob, d_prob]
+        chain_data_cache_paths = [
+            train_chain_data_cache_path,
+            distillation_chain_data_cache_path,
+        ]
+    else:
+        datasets = [train_dataset]
+        probabilities = [1.]
+        chain_data_cache_paths = [
+            train_chain_data_cache_path,
+        ]
+
+    train_dataset = OpenFoldDataset(
+        datasets=datasets,
+        probabilities=probabilities,
+        epoch_len=train_epoch_len,
+        chain_data_cache_paths=chain_data_cache_paths,
+        _roll_at_init=False,
+    )
+
+    if (val_data_dir is not None):
+        eval_dataset = dataset_gen(
+            data_dir=val_data_dir,
+            alignment_dir=val_alignment_dir,
+            mapping_path=None,
+            max_template_hits=config.eval.max_template_hits,
+            mode="eval",
+            _output_raw=True,
+        )
+    else:
+        eval_dataset = None
+
+    return train_dataset, eval_dataset
+
+
+def TrainDataLoader(
+    config: mlc.ConfigDict,
+    train_dataset: torch.utils.data.Dataset,
+    test_dataset: Optional[torch.utils.data.Dataset] = None,
+    batch_seed: Optional[int] = None,
+):
+    if not config.data_module.data_loaders.batch_size == 1:
+        raise ValueError("Only support batch size equals to 1")
+
+    generator = torch.Generator()
+    if (batch_seed is not None):
+        generator = generator.manual_seed(batch_seed)
+
+    train_batch_collator = OpenFoldBatchCollator(config, "train")
+    train_dataset.reroll()
+    train_dataloader = OpenFoldDataLoader(
+        train_dataset,
+        config=config,
+        stage="train",
+        generator=generator,
+        batch_size=config.data_module.data_loaders.batch_size,
+        num_workers=config.data_module.data_loaders.num_workers,
+        collate_fn=train_batch_collator,
+    )
+
+    test_dataloader = None
+    if test_dataset is not None:
+        test_batch_collator = OpenFoldBatchCollator(config, "test")
+        test_dataloader = OpenFoldDataLoader(
+            train_dataset,
+            config=config,
+            stage="test",
+            generator=generator,
+            batch_size=config.data_module.data_loaders.batch_size,
+            num_workers=config.data_module.data_loaders.num_workers,
+            collate_fn=test_batch_collator,
+        )
+
+    return train_dataloader, test_dataloader
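Taken together, this file replaces the pl.LightningDataModule wrapper (and the Dummy* helpers) with two plain functions, and batch_seed moves from the old module's __init__ onto TrainDataLoader. A usage sketch under stated assumptions: the model_config helper, the "model_1" preset name, the config.data sub-config layout, and every path below follow upstream OpenFold conventions and are placeholders, not something this commit defines.

    # Hypothetical wiring of the new functional API; all paths are placeholders.
    from fastfold.config import model_config  # assumes FastFold keeps OpenFold's helper
    from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader

    config = model_config("model_1", train=True)  # hypothetical preset/flags

    # Build the weighted, re-rollable training dataset plus an optional eval set.
    train_dataset, eval_dataset = SetupTrainDataset(
        config=config.data,
        template_mmcif_dir="/data/pdb_mmcif/mmcif_files",
        max_template_date="2021-10-10",
        train_data_dir="/data/train/mmcif",
        train_alignment_dir="/data/train/alignments",
        val_data_dir="/data/val/mmcif",
        val_alignment_dir="/data/val/alignments",
    )

    # config.data.data_module.data_loaders.batch_size must be 1,
    # otherwise TrainDataLoader raises ValueError.
    train_loader, test_loader = TrainDataLoader(
        config=config.data,
        train_dataset=train_dataset,
        test_dataset=eval_dataset,
        batch_seed=42,
    )

    for batch in train_loader:
        ...  # with the dict_multimap change, tensors arrive without a leading [1, ...] dim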
fastfold/utils/tensor_utils.py

+# Copyright 2022 HPC-AI Tech Inc
 # Copyright 2021 AlQuraishi Laboratory
 # Copyright 2021 DeepMind Technologies Limited
 #
 ...
@@ -52,7 +53,8 @@ def dict_multimap(fn, dicts):
         if type(v) is dict:
             new_dict[k] = dict_multimap(fn, all_v)
         else:
-            new_dict[k] = fn(all_v)
+            # when bs = 1, returns [...] rather than [1, ...]
+            new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0]

     return new_dict
 ...
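A quick check of the patched behavior, assuming dict_multimap is imported from this module; the key names and shapes are invented:

    from functools import partial

    import torch
    from fastfold.utils.tensor_utils import dict_multimap

    stack_fn = partial(torch.stack, dim=0)
    feats = {"msa": torch.zeros(508, 256), "extra": {"deletion": torch.zeros(508)}}

    single = dict_multimap(stack_fn, [feats])
    assert single["msa"].shape == (508, 256)            # no leading batch dim
    assert single["extra"]["deletion"].shape == (508,)  # nested leaves follow the same rule

    pair = dict_multimap(stack_fn, [feats, feats])
    assert pair["msa"].shape == (2, 508, 256)           # stacked along dim 0, as before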