Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d35816e3
Commit
d35816e3
authored
Jul 09, 2023
by
Geoffrey Yu
Browse files
add OpenFoldMultimerDataModule
parent
4d9a4bc2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
140 additions
and
38 deletions
+140
-38
openfold/data/data_modules.py
openfold/data/data_modules.py
+140
-38
No files found.
openfold/data/data_modules.py
View file @
d35816e3
...
...
@@ -19,7 +19,16 @@ from openfold.data import (
templates
,
)
from
openfold.utils.tensor_utils
import
tensor_tree_map
,
dict_multimap
import
contextlib
import
tempfile
@
contextlib
.
contextmanager
def
temp_fasta_file
(
sequence_str
):
"""function that create temparory fasta file used in multimer datapipeline"""
with
tempfile
.
NamedTemporaryFile
(
"w"
,
suffix
=
".fasta"
)
as
fasta_file
:
fasta_file
.
write
(
sequence_str
)
fasta_file
.
seek
(
0
)
yield
fasta_file
.
name
class
OpenFoldSingleDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
...
...
@@ -278,7 +287,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
dtype
=
torch
.
int64
,
device
=
feats
[
"aatype"
].
device
)
return
feats
,
data
return
feats
def
__len__
(
self
):
return
len
(
self
.
_chain_ids
)
...
...
@@ -399,31 +408,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
c
for
c
in
self
.
_chain_ids
if
c
in
chains_to_include
]
if
self
.
chain_data_cache
is
not
None
:
# Filter to include only chains where we have structure data
# (entries in chain_data_cache)
original_chain_ids
=
self
.
_chain_ids
self
.
_chain_ids
=
[
c
for
c
in
self
.
_chain_ids
if
c
in
self
.
chain_data_cache
]
if
len
(
self
.
_chain_ids
)
<
len
(
original_chain_ids
):
missing
=
[
c
for
c
in
original_chain_ids
if
c
not
in
self
.
chain_data_cache
]
max_to_print
=
10
missing_examples
=
", "
.
join
(
missing
[:
max_to_print
])
if
len
(
missing
)
>
max_to_print
:
missing_examples
+=
", ..."
logging
.
warning
(
"Removing %d alignment entries (%s) with no corresponding "
"entries in chain_data_cache (%s)."
,
len
(
missing
),
missing_examples
,
chain_data_cache_path
)
if
self
.
mmcif_data_cache
is
not
None
:
self
.
_chain_id_to_idx_dict
=
{
chain
:
i
for
i
,
chain
in
enumerate
(
self
.
_chain_ids
)
self
.
_mmcifs
=
list
(
self
.
mmcif_data_cache
.
keys
())
self
.
_mmcif_id_to_idx_dict
=
{
mmcif
:
i
for
i
,
mmcif
in
enumerate
(
self
.
_mmcifs
)
}
# changed template_featurizer to hmmsearch for now just to run the test
...
...
@@ -440,6 +429,9 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
self
.
data_pipeline
=
data_pipeline
.
DataPipeline
(
template_featurizer
=
template_featurizer
,
)
self
.
multimer_data_pipeline
=
data_pipeline
.
DataPipelineMultimer
(
monomer_data_pipeline
=
self
.
data_pipeline
)
if
(
not
self
.
_output_raw
):
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
...
...
@@ -468,14 +460,23 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
return
data
def
chain
_id_to_idx
(
self
,
chain_id
):
return
self
.
_
chain
_id_to_idx_dict
[
chain_id
]
def
mmcif
_id_to_idx
(
self
,
chain_id
):
return
self
.
_
mmcif
_id_to_idx_dict
[
chain_id
]
def
idx_to_
chain
_id
(
self
,
idx
):
return
self
.
_
chain_id
s
[
idx
]
def
idx_to_
mmcif
_id
(
self
,
idx
):
return
self
.
_
mmcif
s
[
idx
]
def
__getitem__
(
self
,
idx
):
name
=
self
.
idx_to_chain_id
(
idx
)
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
chains
=
self
.
mmcif_data_cache
[
mmcif_id
][
'chain_ids'
]
seqs
=
self
.
mmcif_data_cache
[
mmcif_id
][
'seqs'
]
fasta_str
=
""
for
c
,
s
in
zip
(
chains
,
seqs
):
fasta_str
+
f
">
{
mmcif_id
}
_
{
c
}
\n
{
s
}
"
print
(
fasta_str
)
import
sys
sys
.
exit
()
alignment_dir
=
os
.
path
.
join
(
self
.
alignment_dir
,
name
)
alignment_index
=
None
...
...
@@ -642,6 +643,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
idx
=
[]
for
_
in
range
(
max_cache_len
):
candidate_idx
=
next
(
idx_iter
)
## TO DO: add filtering cretieria for multimer
chain_id
=
dataset
.
idx_to_chain_id
(
candidate_idx
)
chain_data_cache_entry
=
chain_data_cache
[
chain_id
]
if
(
not
deterministic_train_filter
(
chain_data_cache_entry
)):
...
...
@@ -690,12 +692,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
samples
=
self
.
_samples
[
dataset_idx
]
datapoint_idx
=
next
(
samples
)
self
.
datapoints
.
append
((
dataset_idx
,
datapoint_idx
))
class
OpenFoldMultimerDataset
(
torch
.
utils
.
data
.
Dataset
):
"""
Implement the filtering criteria used in AlphaFold Multimer training
"""
print
(
f
"datapoints is
{
self
.
datapoints
}
"
)
class
OpenFoldBatchCollator
:
def
__call__
(
self
,
prots
):
...
...
@@ -799,6 +796,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
train_data_dir
:
Optional
[
str
]
=
None
,
train_alignment_dir
:
Optional
[
str
]
=
None
,
train_chain_data_cache_path
:
Optional
[
str
]
=
None
,
train_mmcif_data_cache_path
:
Optional
[
str
]
=
None
,
distillation_data_dir
:
Optional
[
str
]
=
None
,
distillation_alignment_dir
:
Optional
[
str
]
=
None
,
distillation_chain_data_cache_path
:
Optional
[
str
]
=
None
,
...
...
@@ -826,6 +824,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
train_data_dir
=
train_data_dir
self
.
train_alignment_dir
=
train_alignment_dir
self
.
train_chain_data_cache_path
=
train_chain_data_cache_path
self
.
train_mmcif_data_cache_path
=
train_mmcif_data_cache_path
self
.
distillation_data_dir
=
distillation_data_dir
self
.
distillation_alignment_dir
=
distillation_alignment_dir
self
.
distillation_chain_data_cache_path
=
(
...
...
@@ -1008,6 +1007,109 @@ class OpenFoldDataModule(pl.LightningDataModule):
def
predict_dataloader
(
self
):
return
self
.
_gen_dataloader
(
"predict"
)
class
OpenFoldMultimerDataModule
(
OpenFoldDataModule
):
"""
Create a datamodule specifically for multimer training
Compared to OpenFoldDataModule, OpenFoldMultimerDataModule
requires mmcif_data_cache_path which is the product of
scripts/generate_mmcif_cache.py mmcif_data_cache_path should be
a file that record what chain(s) each mmcif file has
"""
def
__init__
(
self
,
config
:
mlc
.
ConfigDict
,
template_mmcif_dir
:
str
,
max_template_date
:
str
,
train_data_dir
:
Optional
[
str
]
=
None
,
train_mmcif_data_cache_path
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
(
OpenFoldMultimerDataModule
,
self
).
__init__
(
config
,
template_mmcif_dir
,
max_template_date
,
train_data_dir
,
**
kwargs
)
self
.
train_mmcif_data_cache_path
=
train_mmcif_data_cache_path
self
.
training_mode
=
self
.
train_data_dir
is
not
None
def
setup
(
self
):
# Most of the arguments are the same for the three datasets
dataset_gen
=
partial
(
OpenFoldSingleMultimerDataset
,
template_mmcif_dir
=
self
.
template_mmcif_dir
,
max_template_date
=
self
.
max_template_date
,
config
=
self
.
config
,
kalign_binary_path
=
self
.
kalign_binary_path
,
template_release_dates_cache_path
=
self
.
template_release_dates_cache_path
,
obsolete_pdbs_file_path
=
self
.
obsolete_pdbs_file_path
,
)
if
(
self
.
training_mode
):
train_dataset
=
dataset_gen
(
data_dir
=
self
.
train_data_dir
,
mmcif_data_cache_path
=
self
.
train_mmcif_data_cache_path
,
alignment_dir
=
self
.
train_alignment_dir
,
filter_path
=
self
.
train_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
shuffle_top_k_prefiltered
=
self
.
config
.
train
.
shuffle_top_k_prefiltered
,
treat_pdb_as_distillation
=
False
,
mode
=
"train"
,
alignment_index
=
self
.
alignment_index
,
)
distillation_dataset
=
None
if
(
self
.
distillation_data_dir
is
not
None
):
distillation_dataset
=
dataset_gen
(
data_dir
=
self
.
distillation_data_dir
,
chain_data_cache_path
=
self
.
distillation_chain_data_cache_path
,
alignment_dir
=
self
.
distillation_alignment_dir
,
filter_path
=
self
.
distillation_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
treat_pdb_as_distillation
=
True
,
mode
=
"train"
,
alignment_index
=
self
.
distillation_alignment_index
,
_structure_index
=
self
.
_distillation_structure_index
,
)
d_prob
=
self
.
config
.
train
.
distillation_prob
if
(
distillation_dataset
is
not
None
):
datasets
=
[
train_dataset
,
distillation_dataset
]
d_prob
=
self
.
config
.
train
.
distillation_prob
probabilities
=
[
1.
-
d_prob
,
d_prob
]
else
:
datasets
=
[
train_dataset
]
probabilities
=
[
1.
]
generator
=
None
if
(
self
.
batch_seed
is
not
None
):
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
self
.
batch_seed
+
1
)
self
.
train_dataset
=
OpenFoldDataset
(
datasets
=
datasets
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
generator
=
generator
,
_roll_at_init
=
False
,
)
if
(
self
.
val_data_dir
is
not
None
):
self
.
eval_dataset
=
dataset_gen
(
data_dir
=
self
.
val_data_dir
,
alignment_dir
=
self
.
val_alignment_dir
,
filter_path
=
None
,
max_template_hits
=
self
.
config
.
eval
.
max_template_hits
,
mode
=
"eval"
,
)
else
:
self
.
eval_dataset
=
None
else
:
self
.
predict_dataset
=
dataset_gen
(
data_dir
=
self
.
predict_data_dir
,
alignment_dir
=
self
.
predict_alignment_dir
,
filter_path
=
None
,
max_template_hits
=
self
.
config
.
predict
.
max_template_hits
,
mode
=
"predict"
,
)
class
DummyDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
batch_path
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment