Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
566ca1a3
"vscode:/vscode.git/clone" did not exist on "c3b847901099bf5c3dd174a3c8ec994b73426833"
Commit
566ca1a3
authored
Jul 18, 2023
by
Geoffrey Yu
Browse files
added openfold multimer dataloader class and overwrite batch processing
parent
dbc0b085
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
83 additions
and
3 deletions
+83
-3
openfold/data/data_modules.py
openfold/data/data_modules.py
+83
-3
No files found.
openfold/data/data_modules.py
View file @
566ca1a3
...
...
@@ -24,6 +24,9 @@ import tempfile
from
openfold.utils.tensor_utils
import
(
tensor_tree_map
,
)
import
logging
logger
=
logging
.
getLogger
(
__name__
)
@
contextlib
.
contextmanager
def
temp_fasta_file
(
sequence_str
):
"""Create a temporary FASTA file for use in the multimer data pipeline."""
...
...
@@ -468,6 +471,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def
__getitem__
(
self
,
idx
):
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
print
(
f
"mmcif_id is :
{
mmcif_id
}
"
)
chains
=
self
.
mmcif_data_cache
[
mmcif_id
][
'chain_ids'
]
seqs
=
self
.
mmcif_data_cache
[
mmcif_id
][
'seqs'
]
fasta_str
=
""
...
...
@@ -599,7 +603,10 @@ def deterministic_multimer_train_filter(
for
seq
in
seqs
:
for
aa
in
seq
:
counts
[
aa
]
+=
1
if
aa
not
in
restypes
:
return
False
else
:
counts
[
aa
]
+=
1
largest_aa_count
=
max
(
counts
.
values
())
largest_single_aa_prop
=
largest_aa_count
/
total_len
if
(
largest_single_aa_prop
>
max_single_aa_prop
):
...
...
@@ -867,6 +874,52 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
return
_batch_prop_gen
(
it
)
class OpenFoldMultimerDataLoader(OpenFoldDataLoader):
    """Variant of ``OpenFoldDataLoader`` for two-part multimer batches.

    Each batch is unpacked as ``(all_chain_features, ground_truth)``, where
    ``ground_truth`` is iterated as a sequence of feature dictionaries. The
    same sampled batch properties are written into the feature dictionary
    and into every ground-truth dictionary.
    """

    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        """Pass every argument through unchanged to ``OpenFoldDataLoader``."""
        super().__init__(
            *args,
            config=config,
            stage=stage,
            generator=generator,
            **kwargs
        )

    def _add_batch_properties(self, batch):
        """Sample the batch properties once and apply them to both halves.

        Args:
            batch: a pair ``(all_chain_features, ground_truth)``. Every
                dictionary involved must contain an ``"aatype"`` tensor.

        Returns:
            A tuple ``(all_chain_features, ground_truth)`` in which
            ``ground_truth`` is a list of the processed dictionaries.
        """
        # One draw per row of the probability tensor; the row order is
        # matched against self.prop_keys below.
        draws = torch.multinomial(
            self.prop_probs_tensor,
            num_samples=1,
            replacement=True,
            generator=self.generator,
        )

        def _apply_draws(feature_dict, drawn):
            """Write the drawn values into ``feature_dict`` and trim tensors."""
            reference = feature_dict["aatype"]
            leading_dims = reference.shape[:-2]
            last_dim = reference.shape[-1]

            # Unless "no_recycling_iters" is one of the property keys, the
            # trim below keeps the whole last dimension.
            keep_until = last_dim
            for row, prop_name in enumerate(self.prop_keys):
                value = int(drawn[row][0])
                scalar = torch.tensor(
                    value,
                    device=reference.device,
                    requires_grad=False,
                )
                scalar_shape = scalar.shape
                # Give the scalar one singleton axis per leading dim plus a
                # trailing one, then broadcast it along those axes.
                view_shape = (1,) * len(leading_dims) + scalar.shape + (1,)
                full_shape = leading_dims + scalar_shape + (last_dim,)
                feature_dict[prop_name] = scalar.view(view_shape).expand(
                    full_shape
                )

                if prop_name == "no_recycling_iters":
                    keep_until = value

            def _trim_last_dim(t):
                # Keep the first keep_until + 1 entries of the last axis.
                return t[..., :keep_until + 1]

            return tensor_tree_map(_trim_last_dim, feature_dict)

        all_chain_features, ground_truth = batch
        all_chain_features = _apply_draws(all_chain_features, draws)
        ground_truth = [_apply_draws(entry, draws) for entry in ground_truth]
        return (all_chain_features, ground_truth)
class
OpenFoldDataModule
(
pl
.
LightningDataModule
):
def
__init__
(
self
,
...
...
@@ -1123,7 +1176,6 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
train_dataset
=
dataset_gen
(
data_dir
=
self
.
train_data_dir
,
mmcif_data_cache_path
=
self
.
train_mmcif_data_cache_path
,
chain_data_cache_path
=
self
.
train_chain_data_cache_path
,
alignment_dir
=
self
.
train_alignment_dir
,
filter_path
=
self
.
train_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
...
@@ -1138,7 +1190,6 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
if
(
self
.
distillation_data_dir
is
not
None
):
distillation_dataset
=
dataset_gen
(
data_dir
=
self
.
distillation_data_dir
,
chain_data_cache_path
=
self
.
distillation_chain_data_cache_path
,
alignment_dir
=
self
.
distillation_alignment_dir
,
filter_path
=
self
.
distillation_filter_path
,
max_template_hits
=
self
.
config
.
train
.
max_template_hits
,
...
...
@@ -1189,6 +1240,35 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
max_template_hits
=
self
.
config
.
predict
.
max_template_hits
,
mode
=
"predict"
,
)
def _gen_dataloader(self, stage):
    """Build an ``OpenFoldMultimerDataLoader`` for the requested stage.

    Args:
        stage: one of ``"train"``, ``"eval"`` or ``"predict"``.

    Returns:
        A loader over the dataset belonging to ``stage``, configured with
        the batch size and worker count from
        ``self.config.data_module.data_loaders``.

    Raises:
        ValueError: if ``stage`` is not one of the three known names.
    """
    rng = torch.Generator()
    if self.batch_seed is not None:
        # Seed the generator only when a batch seed was configured.
        rng = rng.manual_seed(self.batch_seed)

    # Look the dataset up by attribute name so that only the requested
    # stage's dataset attribute is ever read.
    dataset_attr_by_stage = {
        "train": "train_dataset",
        "eval": "eval_dataset",
        "predict": "predict_dataset",
    }
    attr_name = dataset_attr_by_stage.get(stage)
    if attr_name is None:
        raise ValueError("Invalid stage")

    selected = getattr(self, attr_name)
    if stage == "train":
        # Filter the dataset, if necessary
        selected.reroll()

    loader_cfg = self.config.data_module.data_loaders
    return OpenFoldMultimerDataLoader(
        selected,
        config=self.config,
        stage=stage,
        generator=rng,
        batch_size=loader_cfg.batch_size,
        num_workers=loader_cfg.num_workers,
    )
class
DummyDataset
(
torch
.
utils
.
data
.
Dataset
):
def
__init__
(
self
,
batch_path
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment