OpenDAS / FastFold · Commit 164f6777

Unverified commit, authored Dec 02, 2022 by LuGY; committed by GitHub on Dec 02, 2022.
Parent: 6fbc402e

modify data module for train (#116)
Showing 2 changed files, with 152 additions and 226 deletions:

- fastfold/data/data_modules.py: +149 −225
- fastfold/utils/tensor_utils.py: +3 −1
fastfold/data/data_modules.py (view file @ 164f6777)
```diff
 # Copyright 2022 HPC-AI Tech Inc.
 # Copyright 2021 AlQuraishi Laboratory
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...
@@ -12,19 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
 from functools import partial
 import json
 import logging
 import os
 import pickle
 from typing import Optional, Sequence, List, Any

 import ml_collections as mlc
 import numpy as np
 import pytorch_lightning as pl
 import torch
 from torch.utils.data import RandomSampler

 from fastfold.data import (
     data_pipeline,
 ...
```
```diff
@@ -217,7 +213,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             return data
         feats = self.feature_pipeline.process_features(
             data, self.mode
         )
         return feats
 ...
```
```diff
@@ -380,8 +376,10 @@ class OpenFoldBatchCollator:
             prot, self.stage
         )
         processed_prots.append(features)

+        # By this stack, the batch dimension is processed and added.
         stack_fn = partial(torch.stack, dim=0)
+        # I have modified some codes. Now if the bs=1, the shape will be [...] rather than [1, ...]
+        # If bs>1 (not allowed), the shape would still be [2, ...]
         return dict_multimap(stack_fn, processed_prots)
 ...
```
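The comments above describe the collator's new contract. A minimal sketch of the shape difference, using only `torch.stack` (the `aatype` key and the length 256 are illustrative, not FastFold's real feature shapes):

```python
from functools import partial

import torch

stack_fn = partial(torch.stack, dim=0)

# One processed protein, i.e. bs=1 (the only batch size now supported).
feats = [{"aatype": torch.zeros(256)}]

# Before this commit the collator always stacked, adding a batch dim:
print(stack_fn([f["aatype"] for f in feats]).shape)  # torch.Size([1, 256])

# After the dict_multimap change (see fastfold/utils/tensor_utils.py
# below), a single sample now passes through with its shape intact:
print(feats[0]["aatype"].shape)                      # torch.Size([256])
```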
```diff
@@ -478,230 +476,156 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         return _batch_prop_gen(it)


-class OpenFoldDataModule(pl.LightningDataModule):
-    def __init__(self,
-        config: mlc.ConfigDict,
-        template_mmcif_dir: str,
-        max_template_date: str,
-        train_data_dir: Optional[str] = None,
-        train_alignment_dir: Optional[str] = None,
-        train_chain_data_cache_path: Optional[str] = None,
-        distillation_data_dir: Optional[str] = None,
-        distillation_alignment_dir: Optional[str] = None,
-        distillation_chain_data_cache_path: Optional[str] = None,
-        val_data_dir: Optional[str] = None,
-        val_alignment_dir: Optional[str] = None,
-        predict_data_dir: Optional[str] = None,
-        predict_alignment_dir: Optional[str] = None,
-        kalign_binary_path: str = '/usr/bin/kalign',
-        train_mapping_path: Optional[str] = None,
-        distillation_mapping_path: Optional[str] = None,
-        obsolete_pdbs_file_path: Optional[str] = None,
-        template_release_dates_cache_path: Optional[str] = None,
-        batch_seed: Optional[int] = None,
-        train_epoch_len: int = 50000,
-        _alignment_index_path: Optional[str] = None,
-        **kwargs
-    ):
-        super(OpenFoldDataModule, self).__init__()
-
-        self.config = config
-        self.template_mmcif_dir = template_mmcif_dir
-        self.max_template_date = max_template_date
-        self.train_data_dir = train_data_dir
-        self.train_alignment_dir = train_alignment_dir
-        self.train_chain_data_cache_path = train_chain_data_cache_path
-        self.distillation_data_dir = distillation_data_dir
-        self.distillation_alignment_dir = distillation_alignment_dir
-        self.distillation_chain_data_cache_path = (
-            distillation_chain_data_cache_path
-        )
-        self.val_data_dir = val_data_dir
-        self.val_alignment_dir = val_alignment_dir
-        self.predict_data_dir = predict_data_dir
-        self.predict_alignment_dir = predict_alignment_dir
-        self.kalign_binary_path = kalign_binary_path
-        self.train_mapping_path = train_mapping_path
-        self.distillation_mapping_path = distillation_mapping_path
-        self.template_release_dates_cache_path = (
-            template_release_dates_cache_path
-        )
-        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
-        self.batch_seed = batch_seed
-        self.train_epoch_len = train_epoch_len
-
-        if(self.train_data_dir is None and self.predict_data_dir is None):
-            raise ValueError(
-                'At least one of train_data_dir or predict_data_dir must be '
-                'specified'
-            )
-
-        self.training_mode = self.train_data_dir is not None
-
-        if(self.training_mode and train_alignment_dir is None):
-            raise ValueError(
-                'In training mode, train_alignment_dir must be specified'
-            )
-        elif(not self.training_mode and predict_alignment_dir is None):
-            raise ValueError(
-                'In inference mode, predict_alignment_dir must be specified'
-            )
-        elif(val_data_dir is not None and val_alignment_dir is None):
-            raise ValueError(
-                'If val_data_dir is specified, val_alignment_dir must '
-                'be specified as well'
-            )
-
-        # An ad-hoc measure for our particular filesystem restrictions
-        self._alignment_index = None
-        if(_alignment_index_path is not None):
-            with open(_alignment_index_path, "r") as fp:
-                self._alignment_index = json.load(fp)
-
-    def setup(self):
-        # Most of the arguments are the same for the three datasets
-        dataset_gen = partial(OpenFoldSingleDataset,
-            template_mmcif_dir=self.template_mmcif_dir,
-            max_template_date=self.max_template_date,
-            config=self.config,
-            kalign_binary_path=self.kalign_binary_path,
-            template_release_dates_cache_path=
-                self.template_release_dates_cache_path,
-            obsolete_pdbs_file_path=
-                self.obsolete_pdbs_file_path,
-        )
-
-        if(self.training_mode):
-            train_dataset = dataset_gen(
-                data_dir=self.train_data_dir,
-                alignment_dir=self.train_alignment_dir,
-                mapping_path=self.train_mapping_path,
-                max_template_hits=self.config.train.max_template_hits,
-                shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
-                treat_pdb_as_distillation=False,
-                mode="train",
-                _output_raw=True,
-                _alignment_index=self._alignment_index,
-            )
-
-            distillation_dataset = None
-            if(self.distillation_data_dir is not None):
-                distillation_dataset = dataset_gen(
-                    data_dir=self.distillation_data_dir,
-                    alignment_dir=self.distillation_alignment_dir,
-                    mapping_path=self.distillation_mapping_path,
-                    max_template_hits=self.train.max_template_hits,
-                    treat_pdb_as_distillation=True,
-                    mode="train",
-                    _output_raw=True,
-                )
-
-                d_prob = self.config.train.distillation_prob
-
-            if(distillation_dataset is not None):
-                datasets = [train_dataset, distillation_dataset]
-                d_prob = self.config.train.distillation_prob
-                probabilities = [1 - d_prob, d_prob]
-                chain_data_cache_paths = [
-                    self.train_chain_data_cache_path,
-                    self.distillation_chain_data_cache_path,
-                ]
-            else:
-                datasets = [train_dataset]
-                probabilities = [1.]
-                chain_data_cache_paths = [
-                    self.train_chain_data_cache_path,
-                ]
-
-            self.train_dataset = OpenFoldDataset(
-                datasets=datasets,
-                probabilities=probabilities,
-                epoch_len=self.train_epoch_len,
-                chain_data_cache_paths=chain_data_cache_paths,
-                _roll_at_init=False,
-            )
-
-            if(self.val_data_dir is not None):
-                self.eval_dataset = dataset_gen(
-                    data_dir=self.val_data_dir,
-                    alignment_dir=self.val_alignment_dir,
-                    mapping_path=None,
-                    max_template_hits=self.config.eval.max_template_hits,
-                    mode="eval",
-                    _output_raw=True,
-                )
-            else:
-                self.eval_dataset = None
-        else:
-            self.predict_dataset = dataset_gen(
-                data_dir=self.predict_data_dir,
-                alignment_dir=self.predict_alignment_dir,
-                mapping_path=None,
-                max_template_hits=self.config.predict.max_template_hits,
-                mode="predict",
-            )
-
-    def _gen_dataloader(self, stage):
-        generator = torch.Generator()
-        if(self.batch_seed is not None):
-            generator = generator.manual_seed(self.batch_seed)
-
-        dataset = None
-        if(stage == "train"):
-            dataset = self.train_dataset
-            # Filter the dataset, if necessary
-            dataset.reroll()
-        elif(stage == "eval"):
-            dataset = self.eval_dataset
-        elif(stage == "predict"):
-            dataset = self.predict_dataset
-        else:
-            raise ValueError("Invalid stage")
-
-        batch_collator = OpenFoldBatchCollator(self.config, stage)
-
-        dl = OpenFoldDataLoader(
-            dataset,
-            config=self.config,
-            stage=stage,
-            generator=generator,
-            batch_size=self.config.data_module.data_loaders.batch_size,
-            num_workers=self.config.data_module.data_loaders.num_workers,
-            collate_fn=batch_collator,
-        )
-
-        return dl
-
-    def train_dataloader(self):
-        return self._gen_dataloader("train")
-
-    def val_dataloader(self):
-        if(self.eval_dataset is not None):
-            return self._gen_dataloader("eval")
-        return None
-
-    def predict_dataloader(self):
-        return self._gen_dataloader("predict")
+def SetupTrainDataset(
+    config: mlc.ConfigDict,
+    template_mmcif_dir: str,
+    max_template_date: str,
+    train_data_dir: Optional[str] = None,
+    train_alignment_dir: Optional[str] = None,
+    train_chain_data_cache_path: Optional[str] = None,
+    distillation_data_dir: Optional[str] = None,
+    distillation_alignment_dir: Optional[str] = None,
+    distillation_chain_data_cache_path: Optional[str] = None,
+    val_data_dir: Optional[str] = None,
+    val_alignment_dir: Optional[str] = None,
+    kalign_binary_path: str = '/usr/bin/kalign',
+    train_mapping_path: Optional[str] = None,
+    distillation_mapping_path: Optional[str] = None,
+    obsolete_pdbs_file_path: Optional[str] = None,
+    template_release_dates_cache_path: Optional[str] = None,
+    train_epoch_len: int = 50000,
+    _alignment_index_path: Optional[str] = None,
+    **kwargs,
+):
+    if(train_data_dir is None or train_alignment_dir is None):
+        raise ValueError(
+            'train_data_dir and train_alignment_dir must be specified'
+        )
+    elif(val_data_dir is not None and val_alignment_dir is None):
+        raise ValueError(
+            'If val_data_dir is specified, val_alignment_dir must '
+            'be specified as well'
+        )
+
+    _alignment_index = None
+    if(_alignment_index_path is not None):
+        with open(_alignment_index_path, "r") as fp:
+            _alignment_index = json.load(fp)
+
+    dataset_gen = partial(OpenFoldSingleDataset,
+        template_mmcif_dir=template_mmcif_dir,
+        max_template_date=max_template_date,
+        config=config,
+        kalign_binary_path=kalign_binary_path,
+        template_release_dates_cache_path=template_release_dates_cache_path,
+        obsolete_pdbs_file_path=obsolete_pdbs_file_path,
+    )
+
+    train_dataset = dataset_gen(
+        data_dir=train_data_dir,
+        alignment_dir=train_alignment_dir,
+        mapping_path=train_mapping_path,
+        max_template_hits=config.train.max_template_hits,
+        shuffle_top_k_prefiltered=config.train.shuffle_top_k_prefiltered,
+        treat_pdb_as_distillation=False,
+        mode="train",
+        _output_raw=True,
+        _alignment_index=_alignment_index,
+    )
+
+    distillation_dataset = None
+    if(distillation_data_dir is not None):
+        distillation_dataset = dataset_gen(
+            data_dir=distillation_data_dir,
+            alignment_dir=distillation_alignment_dir,
+            mapping_path=distillation_mapping_path,
+            max_template_hits=config.train.max_template_hits,
+            treat_pdb_as_distillation=True,
+            mode="train",
+            _output_raw=True,
+        )
+
+        d_prob = config.train.distillation_prob
+
+    if(distillation_dataset is not None):
+        datasets = [train_dataset, distillation_dataset]
+        d_prob = config.train.distillation_prob
+        probabilities = [1 - d_prob, d_prob]
+        chain_data_cache_paths = [
+            train_chain_data_cache_path,
+            distillation_chain_data_cache_path,
+        ]
+    else:
+        datasets = [train_dataset]
+        probabilities = [1.]
+        chain_data_cache_paths = [
+            train_chain_data_cache_path,
+        ]
+
+    train_dataset = OpenFoldDataset(
+        datasets=datasets,
+        probabilities=probabilities,
+        epoch_len=train_epoch_len,
+        chain_data_cache_paths=chain_data_cache_paths,
+        _roll_at_init=False,
+    )
+
+    if(val_data_dir is not None):
+        eval_dataset = dataset_gen(
+            data_dir=val_data_dir,
+            alignment_dir=val_alignment_dir,
+            mapping_path=None,
+            max_template_hits=config.eval.max_template_hits,
+            mode="eval",
+            _output_raw=True,
+        )
+    else:
+        eval_dataset = None
+
+    return train_dataset, eval_dataset
+
+
+def TrainDataLoader(
+    config: mlc.ConfigDict,
+    train_dataset: torch.utils.data.Dataset,
+    test_dataset: Optional[torch.utils.data.Dataset] = None,
+    batch_seed: Optional[int] = None,
+):
+    if not config.data_module.data_loaders.batch_size == 1:
+        raise ValueError("Only support batch size equals to 1")
+
+    generator = torch.Generator()
+    if(batch_seed is not None):
+        generator = generator.manual_seed(batch_seed)
+
+    train_batch_collator = OpenFoldBatchCollator(config, "train")
+    train_dataset.reroll()
+    train_dataloader = OpenFoldDataLoader(
+        train_dataset,
+        config=config,
+        stage="train",
+        generator=generator,
+        batch_size=config.data_module.data_loaders.batch_size,
+        num_workers=config.data_module.data_loaders.num_workers,
+        collate_fn=train_batch_collator,
+    )
+
+    test_dataloader = None
+    if test_dataset is not None:
+        test_batch_collator = OpenFoldBatchCollator(config, "test")
+        test_dataloader = OpenFoldDataLoader(
+            train_dataset,
+            config=config,
+            stage="test",
+            generator=generator,
+            batch_size=config.data_module.data_loaders.batch_size,
+            num_workers=config.data_module.data_loaders.num_workers,
+            collate_fn=test_batch_collator,
+        )
+
+    return train_dataloader, test_dataloader


 class DummyDataset(torch.utils.data.Dataset):
     def __init__(self, batch_path):
         with open(batch_path, "rb") as f:
             self.batch = pickle.load(f)

     def __getitem__(self, idx):
         return copy.deepcopy(self.batch)

     def __len__(self):
         return 1000


 class DummyDataLoader(pl.LightningDataModule):
     def __init__(self, batch_path):
         super().__init__()
         self.dataset = DummyDataset(batch_path)

     def train_dataloader(self):
         return torch.utils.data.DataLoader(self.dataset)
```
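Net effect: the Lightning `OpenFoldDataModule` is replaced by the two free functions above. A hedged sketch of how a training script might wire them together; the `model_config` import path, the preset name, and all directory paths are placeholders, not values confirmed by this commit:

```python
from fastfold.config import model_config  # assumed config entry point
from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader

config = model_config("model_1")  # hypothetical preset name

train_dataset, eval_dataset = SetupTrainDataset(
    config=config.data,
    template_mmcif_dir="/path/to/pdb_mmcif/mmcif_files",  # placeholder
    max_template_date="2021-10-10",                       # placeholder
    train_data_dir="/path/to/train/mmcif",                # placeholder
    train_alignment_dir="/path/to/train/alignments",      # placeholder
)

# Returns (train_dataloader, test_dataloader); the latter is None when no
# test_dataset is passed. batch_size must be 1 or TrainDataLoader raises.
train_loader, test_loader = TrainDataLoader(
    config=config.data,
    train_dataset=train_dataset,
    test_dataset=eval_dataset,
    batch_seed=42,
)

for batch in train_loader:
    ...  # each batch is a feature dict with no leading batch dimension
```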
fastfold/utils/tensor_utils.py (view file @ 164f6777)
```diff
 # Copyright 2022 HPC-AI Tech Inc
 # Copyright 2021 AlQuraishi Laboratory
 # Copyright 2021 DeepMind Technologies Limited
 #
 ...
@@ -52,7 +53,8 @@ def dict_multimap(fn, dicts):
         if type(v) is dict:
             new_dict[k] = dict_multimap(fn, all_v)
         else:
-            new_dict[k] = fn(all_v)
+            # when bs = 1, returns [...] rather than [1, ...]
+            new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0]
     return new_dict
 ...
```
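A self-contained reproduction of the patched helper for experimentation; the function body restates the diff above, and only the example tensors are invented:

```python
import torch

def dict_multimap(fn, dicts):
    # Apply fn to the values gathered under each key, recursing into
    # nested dicts; a single gathered value is returned unstacked.
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            # when bs = 1, returns [...] rather than [1, ...]
            new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0]
    return new_dict

stack = lambda ts: torch.stack(ts, dim=0)

one = dict_multimap(stack, [{"x": torch.ones(3)}])
two = dict_multimap(stack, [{"x": torch.ones(3)}, {"x": torch.ones(3)}])
print(one["x"].shape)  # torch.Size([3])  -- batch dim dropped
print(two["x"].shape)  # torch.Size([2, 3])
```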