Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastFold
Commits
164f6777
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "54c4e0761a5b9e5d102c1c0dface786cf489fdd7"
Unverified
Commit
164f6777
authored
Dec 02, 2022
by
LuGY
Committed by
GitHub
Dec 02, 2022
Browse files
modify data module for train (#116)
parent
6fbc402e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
152 additions
and
226 deletions
+152
-226
fastfold/data/data_modules.py
fastfold/data/data_modules.py
+149
-225
fastfold/utils/tensor_utils.py
fastfold/utils/tensor_utils.py
+3
-1
No files found.
fastfold/data/data_modules.py
View file @
164f6777
# Copyright 2022 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -12,19 +13,14 @@
...
@@ -12,19 +13,14 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
copy
from
functools
import
partial
from
functools
import
partial
import
json
import
json
import
logging
import
logging
import
os
import
os
import
pickle
from
typing
import
Optional
,
Sequence
,
List
,
Any
from
typing
import
Optional
,
Sequence
,
List
,
Any
import
ml_collections
as
mlc
import
ml_collections
as
mlc
import
numpy
as
np
import
pytorch_lightning
as
pl
import
torch
import
torch
from
torch.utils.data
import
RandomSampler
from
fastfold.data
import
(
from
fastfold.data
import
(
data_pipeline
,
data_pipeline
,
...
@@ -380,8 +376,10 @@ class OpenFoldBatchCollator:
...
@@ -380,8 +376,10 @@ class OpenFoldBatchCollator:
prot
,
self
.
stage
prot
,
self
.
stage
)
)
processed_prots
.
append
(
features
)
processed_prots
.
append
(
features
)
# By this stack, the batch dimension is processed and added.
stack_fn
=
partial
(
torch
.
stack
,
dim
=
0
)
stack_fn
=
partial
(
torch
.
stack
,
dim
=
0
)
# I have modified some codes. Now if the bs=1, the shape will be [...] rather than [1, ...]
# If bs>1(not allowed), the shape would be still [2, ...]
return
dict_multimap
(
stack_fn
,
processed_prots
)
return
dict_multimap
(
stack_fn
,
processed_prots
)
...
@@ -478,8 +476,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
...
@@ -478,8 +476,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
return
_batch_prop_gen
(
it
)
return
_batch_prop_gen
(
it
)
class
OpenFoldDataModule
(
pl
.
LightningDataModule
):
def
SetupTrainDataset
(
def
__init__
(
self
,
config
:
mlc
.
ConfigDict
,
config
:
mlc
.
ConfigDict
,
template_mmcif_dir
:
str
,
template_mmcif_dir
:
str
,
max_template_date
:
str
,
max_template_date
:
str
,
...
@@ -491,60 +488,19 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -491,60 +488,19 @@ class OpenFoldDataModule(pl.LightningDataModule):
distillation_chain_data_cache_path
:
Optional
[
str
]
=
None
,
distillation_chain_data_cache_path
:
Optional
[
str
]
=
None
,
val_data_dir
:
Optional
[
str
]
=
None
,
val_data_dir
:
Optional
[
str
]
=
None
,
val_alignment_dir
:
Optional
[
str
]
=
None
,
val_alignment_dir
:
Optional
[
str
]
=
None
,
predict_data_dir
:
Optional
[
str
]
=
None
,
predict_alignment_dir
:
Optional
[
str
]
=
None
,
kalign_binary_path
:
str
=
'/usr/bin/kalign'
,
kalign_binary_path
:
str
=
'/usr/bin/kalign'
,
train_mapping_path
:
Optional
[
str
]
=
None
,
train_mapping_path
:
Optional
[
str
]
=
None
,
distillation_mapping_path
:
Optional
[
str
]
=
None
,
distillation_mapping_path
:
Optional
[
str
]
=
None
,
obsolete_pdbs_file_path
:
Optional
[
str
]
=
None
,
obsolete_pdbs_file_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
template_release_dates_cache_path
:
Optional
[
str
]
=
None
,
batch_seed
:
Optional
[
int
]
=
None
,
train_epoch_len
:
int
=
50000
,
train_epoch_len
:
int
=
50000
,
_alignment_index_path
:
Optional
[
str
]
=
None
,
_alignment_index_path
:
Optional
[
str
]
=
None
,
**
kwargs
**
kwargs
,
):
):
super
(
OpenFoldDataModule
,
self
).
__init__
()
self
.
config
=
config
if
(
train_data_dir
is
None
or
train_alignment_dir
is
None
):
self
.
template_mmcif_dir
=
template_mmcif_dir
self
.
max_template_date
=
max_template_date
self
.
train_data_dir
=
train_data_dir
self
.
train_alignment_dir
=
train_alignment_dir
self
.
train_chain_data_cache_path
=
train_chain_data_cache_path
self
.
distillation_data_dir
=
distillation_data_dir
self
.
distillation_alignment_dir
=
distillation_alignment_dir
self
.
distillation_chain_data_cache_path
=
(
distillation_chain_data_cache_path
)
self
.
val_data_dir
=
val_data_dir
self
.
val_alignment_dir
=
val_alignment_dir
self
.
predict_data_dir
=
predict_data_dir
self
.
predict_alignment_dir
=
predict_alignment_dir
self
.
kalign_binary_path
=
kalign_binary_path
self
.
train_mapping_path
=
train_mapping_path
self
.
distillation_mapping_path
=
distillation_mapping_path
self
.
template_release_dates_cache_path
=
(
template_release_dates_cache_path
)
self
.
obsolete_pdbs_file_path
=
obsolete_pdbs_file_path
self
.
batch_seed
=
batch_seed
self
.
train_epoch_len
=
train_epoch_len
if
(
self
.
train_data_dir
is
None
and
self
.
predict_data_dir
is
None
):
raise
ValueError
(
raise
ValueError
(
'At least one of train_data_dir or predict_data_dir must be '
'train_data_dir and train_alignment_dir must be specified'
'specified'
)
self
.
training_mode
=
self
.
train_data_dir
is
not
None
if
(
self
.
training_mode
and
train_alignment_dir
is
None
):
raise
ValueError
(
'In training mode, train_alignment_dir must be specified'
)
elif
(
not
self
.
training_mode
and
predict_alignment_dir
is
None
):
raise
ValueError
(
'In inference mode, predict_alignment_dir must be specified'
)
)
elif
(
val_data_dir
is
not
None
and
val_alignment_dir
is
None
):
elif
(
val_data_dir
is
not
None
and
val_alignment_dir
is
None
):
raise
ValueError
(
raise
ValueError
(
...
@@ -552,156 +508,124 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -552,156 +508,124 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
'be specified as well'
)
)
# An ad-hoc measure for our particular filesystem restrictions
_alignment_index
=
None
self
.
_alignment_index
=
None
if
(
_alignment_index_path
is
not
None
):
if
(
_alignment_index_path
is
not
None
):
with
open
(
_alignment_index_path
,
"r"
)
as
fp
:
with
open
(
_alignment_index_path
,
"r"
)
as
fp
:
self
.
_alignment_index
=
json
.
load
(
fp
)
_alignment_index
=
json
.
load
(
fp
)
def
setup
(
self
):
# Most of the arguments are the same for the three datasets
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
dataset_gen
=
partial
(
OpenFoldSingleDataset
,
template_mmcif_dir
=
self
.
template_mmcif_dir
,
template_mmcif_dir
=
template_mmcif_dir
,
max_template_date
=
self
.
max_template_date
,
max_template_date
=
max_template_date
,
config
=
self
.
config
,
config
=
config
,
kalign_binary_path
=
self
.
kalign_binary_path
,
kalign_binary_path
=
kalign_binary_path
,
template_release_dates_cache_path
=
template_release_dates_cache_path
=
self
.
template_release_dates_cache_path
,
template_release_dates_cache_path
,
obsolete_pdbs_file_path
=
obsolete_pdbs_file_path
=
self
.
obsolete_pdbs_file_path
,
obsolete_pdbs_file_path
,
)
)
if
(
self
.
training_mode
):
train_dataset
=
dataset_gen
(
train_dataset
=
dataset_gen
(
data_dir
=
self
.
train_data_dir
,
data_dir
=
train_data_dir
,
alignment_dir
=
self
.
train_alignment_dir
,
alignment_dir
=
train_alignment_dir
,
mapping_path
=
self
.
train_mapping_path
,
mapping_path
=
train_mapping_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
max_template_hits
=
config
.
train
.
max_template_hits
,
shuffle_top_k_prefiltered
=
shuffle_top_k_prefiltered
=
self
.
config
.
train
.
shuffle_top_k_prefiltered
,
config
.
train
.
shuffle_top_k_prefiltered
,
treat_pdb_as_distillation
=
False
,
treat_pdb_as_distillation
=
False
,
mode
=
"train"
,
mode
=
"train"
,
_output_raw
=
True
,
_output_raw
=
True
,
_alignment_index
=
self
.
_alignment_index
,
_alignment_index
=
_alignment_index
,
)
)
distillation_dataset
=
None
distillation_dataset
=
None
if
(
self
.
distillation_data_dir
is
not
None
):
if
(
distillation_data_dir
is
not
None
):
distillation_dataset
=
dataset_gen
(
distillation_dataset
=
dataset_gen
(
data_dir
=
self
.
distillation_data_dir
,
data_dir
=
distillation_data_dir
,
alignment_dir
=
self
.
distillation_alignment_dir
,
alignment_dir
=
distillation_alignment_dir
,
mapping_path
=
self
.
distillation_mapping_path
,
mapping_path
=
distillation_mapping_path
,
max_template_hits
=
self
.
train
.
max_template_hits
,
max_template_hits
=
config
.
train
.
max_template_hits
,
treat_pdb_as_distillation
=
True
,
treat_pdb_as_distillation
=
True
,
mode
=
"train"
,
mode
=
"train"
,
_output_raw
=
True
,
_output_raw
=
True
,
)
)
d_prob
=
self
.
config
.
train
.
distillation_prob
d_prob
=
config
.
train
.
distillation_prob
if
(
distillation_dataset
is
not
None
):
if
(
distillation_dataset
is
not
None
):
datasets
=
[
train_dataset
,
distillation_dataset
]
datasets
=
[
train_dataset
,
distillation_dataset
]
d_prob
=
self
.
config
.
train
.
distillation_prob
d_prob
=
config
.
train
.
distillation_prob
probabilities
=
[
1
-
d_prob
,
d_prob
]
probabilities
=
[
1
-
d_prob
,
d_prob
]
chain_data_cache_paths
=
[
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
train_chain_data_cache_path
,
self
.
distillation_chain_data_cache_path
,
distillation_chain_data_cache_path
,
]
]
else
:
else
:
datasets
=
[
train_dataset
]
datasets
=
[
train_dataset
]
probabilities
=
[
1.
]
probabilities
=
[
1.
]
chain_data_cache_paths
=
[
chain_data_cache_paths
=
[
self
.
train_chain_data_cache_path
,
train_chain_data_cache_path
,
]
]
self
.
train_dataset
=
OpenFoldDataset
(
train_dataset
=
OpenFoldDataset
(
datasets
=
datasets
,
datasets
=
datasets
,
probabilities
=
probabilities
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
epoch_len
=
train_epoch_len
,
chain_data_cache_paths
=
chain_data_cache_paths
,
chain_data_cache_paths
=
chain_data_cache_paths
,
_roll_at_init
=
False
,
_roll_at_init
=
False
,
)
)
if
(
self
.
val_data_dir
is
not
None
):
if
(
val_data_dir
is
not
None
):
self
.
eval_dataset
=
dataset_gen
(
eval_dataset
=
dataset_gen
(
data_dir
=
self
.
val_data_dir
,
data_dir
=
val_data_dir
,
alignment_dir
=
self
.
val_alignment_dir
,
alignment_dir
=
val_alignment_dir
,
mapping_path
=
None
,
mapping_path
=
None
,
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
max_template_hits
=
config
.
eval
.
max_template_hits
,
mode
=
"eval"
,
mode
=
"eval"
,
_output_raw
=
True
,
_output_raw
=
True
,
)
)
else
:
else
:
self
.
eval_dataset
=
None
eval_dataset
=
None
else
:
self
.
predict_dataset
=
dataset_gen
(
data_dir
=
self
.
predict_data_dir
,
alignment_dir
=
self
.
predict_alignment_dir
,
mapping_path
=
None
,
max_template_hits
=
self
.
config
.
predict
.
max_template_hits
,
mode
=
"predict"
,
)
def
_gen_dataloader
(
self
,
stage
):
return
train_dataset
,
eval_dataset
generator
=
torch
.
Generator
()
if
(
self
.
batch_seed
is
not
None
):
generator
=
generator
.
manual_seed
(
self
.
batch_seed
)
dataset
=
None
if
(
stage
==
"train"
):
dataset
=
self
.
train_dataset
# Filter the dataset, if necessary
dataset
.
reroll
()
elif
(
stage
==
"eval"
):
dataset
=
self
.
eval_dataset
elif
(
stage
==
"predict"
):
dataset
=
self
.
predict_dataset
else
:
raise
ValueError
(
"Invalid stage"
)
batch_collator
=
OpenFoldBatchCollator
(
self
.
config
,
stage
)
dl
=
OpenFoldDataLoader
(
dataset
,
config
=
self
.
config
,
stage
=
stage
,
generator
=
generator
,
batch_size
=
self
.
config
.
data_module
.
data_loaders
.
batch_size
,
num_workers
=
self
.
config
.
data_module
.
data_loaders
.
num_workers
,
collate_fn
=
batch_collator
,
)
return
dl
def
train_dataloader
(
self
):
return
self
.
_gen_dataloader
(
"train"
)
def
val_dataloader
(
self
):
if
(
self
.
eval_dataset
is
not
None
):
return
self
.
_gen_dataloader
(
"eval"
)
return
None
def
predict_dataloader
(
self
):
return
self
.
_gen_dataloader
(
"predict"
)
class
DummyDataset
(
torch
.
utils
.
data
.
Dataset
):
def
TrainDataLoader
(
def
__init__
(
self
,
batch_path
):
config
:
mlc
.
ConfigDict
,
with
open
(
batch_path
,
"rb"
)
as
f
:
train_dataset
:
torch
.
utils
.
data
.
Dataset
,
self
.
batch
=
pickle
.
load
(
f
)
test_dataset
:
Optional
[
torch
.
utils
.
data
.
Dataset
]
=
None
,
batch_seed
:
Optional
[
int
]
=
None
,
def
__getitem__
(
self
,
idx
):
):
return
copy
.
deepcopy
(
self
.
batch
)
def
__len__
(
self
):
return
1000
if
not
config
.
data_module
.
data_loaders
.
batch_size
==
1
:
raise
ValueError
(
"Only support batch size equals to 1"
)
class
DummyDataLoader
(
pl
.
LightningDataModule
):
generator
=
torch
.
Generator
()
def
__init__
(
self
,
batch_path
):
if
(
batch_seed
is
not
None
):
super
().
__init__
()
generator
=
generator
.
manual_seed
(
batch_seed
)
self
.
dataset
=
DummyDataset
(
batch_path
)
train_batch_collator
=
OpenFoldBatchCollator
(
config
,
"train"
)
train_dataset
.
reroll
()
train_dataloader
=
OpenFoldDataLoader
(
train_dataset
,
config
=
config
,
stage
=
"train"
,
generator
=
generator
,
batch_size
=
config
.
data_module
.
data_loaders
.
batch_size
,
num_workers
=
config
.
data_module
.
data_loaders
.
num_workers
,
collate_fn
=
train_batch_collator
,
)
test_dataloader
=
None
if
test_dataset
is
not
None
:
test_batch_collator
=
OpenFoldBatchCollator
(
config
,
"test"
)
test_dataloader
=
OpenFoldDataLoader
(
train_dataset
,
config
=
config
,
stage
=
"test"
,
generator
=
generator
,
batch_size
=
config
.
data_module
.
data_loaders
.
batch_size
,
num_workers
=
config
.
data_module
.
data_loaders
.
num_workers
,
collate_fn
=
test_batch_collator
,
)
def
train_dataloader
(
self
):
return
train_dataloader
,
test_dataloader
return
torch
.
utils
.
data
.
DataLoader
(
self
.
dataset
)
fastfold/utils/tensor_utils.py
View file @
164f6777
# Copyright 2022 HPC-AI Tech Inc
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2021 DeepMind Technologies Limited
#
#
...
@@ -52,7 +53,8 @@ def dict_multimap(fn, dicts):
...
@@ -52,7 +53,8 @@ def dict_multimap(fn, dicts):
if
type
(
v
)
is
dict
:
if
type
(
v
)
is
dict
:
new_dict
[
k
]
=
dict_multimap
(
fn
,
all_v
)
new_dict
[
k
]
=
dict_multimap
(
fn
,
all_v
)
else
:
else
:
new_dict
[
k
]
=
fn
(
all_v
)
# when bs = 1, returns [...] rather than [1, ...]
new_dict
[
k
]
=
fn
(
all_v
)
if
len
(
all_v
)
>
1
else
all_v
[
0
]
return
new_dict
return
new_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment