Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
e482a269
Commit
e482a269
authored
Jul 20, 2023
by
Geoffrey Yu
Browse files
updated data_module
parent
7f2a3267
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
51 deletions
+21
-51
openfold/data/data_modules.py
openfold/data/data_modules.py
+21
-51
No files found.
openfold/data/data_modules.py
View file @
e482a269
...
...
@@ -24,8 +24,7 @@ import tempfile
from
openfold.utils.tensor_utils
import
(
tensor_tree_map
,
)
import
logging
logger
=
logging
.
getLogger
(
__name__
)
@
contextlib
.
contextmanager
def
temp_fasta_file
(
sequence_str
):
...
...
@@ -471,8 +470,8 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def
__getitem__
(
self
,
idx
):
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
print
(
f
"mmcif_id is :
{
mmcif_id
}
"
)
chains
=
self
.
mmcif_data_cache
[
mmcif_id
][
'chain_ids'
]
print
(
f
"mmcif_id is :
{
mmcif_id
}
idx:
{
idx
}
and has
{
len
(
chains
)
}
chains"
)
seqs
=
self
.
mmcif_data_cache
[
mmcif_id
][
'seqs'
]
fasta_str
=
""
for
c
,
s
in
zip
(
chains
,
seqs
):
...
...
@@ -779,7 +778,7 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
selected_idx
=
self
.
filter_samples
(
dataset_idx
)
if
len
(
selected_idx
)
<
self
.
epoch_len
:
self
.
epoch_len
=
len
(
selected_idx
)
print
(
f
"self.epoch_len is
{
self
.
epoch_len
}
"
)
logging
.
info
(
f
"self.epoch_len is
{
self
.
epoch_len
}
"
)
self
.
datapoints
+=
[(
dataset_idx
,
selected_idx
[
i
])
for
i
in
range
(
self
.
epoch_len
)
]
class
OpenFoldBatchCollator
:
...
...
@@ -874,51 +873,25 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
return
_batch_prop_gen
(
it
)
class
OpenFoldMultimerDataLoader
(
OpenFold
DataLoader
):
class
OpenFoldMultimerDataLoader
(
torch
.
utils
.
data
.
DataLoader
):
def
__init__
(
self
,
*
args
,
config
,
stage
=
"train"
,
generator
=
None
,
**
kwargs
):
super
(
OpenFoldMultimerDataLoader
,
self
).
__init__
(
*
args
,
config
=
config
,
stage
=
stage
,
generator
=
generator
,
**
kwargs
)
super
(
OpenFoldMultimerDataLoader
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
config
=
config
self
.
stage
=
stage
def
_add_batch_properties
(
self
,
batch
):
samples
=
torch
.
multinomial
(
self
.
prop_probs_tensor
,
num_samples
=
1
,
# 1 per row
replacement
=
True
,
generator
=
self
.
generator
)
def
process_samples
(
batch
,
samples
):
aatype
=
batch
[
"aatype"
]
batch_dims
=
aatype
.
shape
[:
-
2
]
recycling_dim
=
aatype
.
shape
[
-
1
]
no_recycling
=
recycling_dim
for
i
,
key
in
enumerate
(
self
.
prop_keys
):
sample
=
int
(
samples
[
i
][
0
])
sample_tensor
=
torch
.
tensor
(
sample
,
device
=
aatype
.
device
,
requires_grad
=
False
)
orig_shape
=
sample_tensor
.
shape
sample_tensor
=
sample_tensor
.
view
(
(
1
,)
*
len
(
batch_dims
)
+
sample_tensor
.
shape
+
(
1
,)
)
sample_tensor
=
sample_tensor
.
expand
(
batch_dims
+
orig_shape
+
(
recycling_dim
,)
)
batch
[
key
]
=
sample_tensor
if
(
generator
is
None
):
generator
=
torch
.
Generator
()
self
.
generator
=
generator
print
(
'initialised a multimer dataloader'
)
def
__iter__
(
self
):
it
=
super
().
__iter__
()
if
(
key
==
"no_recycling_iters"
):
no_recycling
=
sample
resample_recycling
=
lambda
t
:
t
[...,
:
no_recycling
+
1
]
batch
=
tensor_tree_map
(
resample_recycling
,
batch
)
def
_batch_prop_gen
(
iterator
):
for
batch
in
iterator
:
yield
batch
return
batch
all_chain_features
,
ground_truth
=
batch
all_chain_features
=
process_samples
(
all_chain_features
,
samples
)
ground_truth
=
[
process_samples
(
i
,
samples
)
for
i
in
ground_truth
]
return
(
all_chain_features
,
ground_truth
)
return
_batch_prop_gen
(
it
)
class
OpenFoldDataModule
(
pl
.
LightningDataModule
):
...
...
@@ -1259,15 +1232,12 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
raise
ValueError
(
"Invalid stage"
)
dl
=
OpenFoldMultimer
DataLoader
(
dl
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
config
=
self
.
config
,
stage
=
stage
,
generator
=
generator
,
batch_size
=
self
.
config
.
data_module
.
data_loaders
.
batch_size
,
batch_size
=
1
,
num_workers
=
self
.
config
.
data_module
.
data_loaders
.
num_workers
,
)
print
(
f
"generated training dataloader"
)
return
dl
class
DummyDataset
(
torch
.
utils
.
data
.
Dataset
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment