Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
093603ee
"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "4dd50722cabdea1c00373d073b2a67590b1f2507"
Commit
093603ee
authored
Sep 23, 2023
by
Geoffrey Yu
Browse files
update multimer data input pipeline
parent
f3c1af45
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
34 deletions
+39
-34
openfold/data/data_modules.py
openfold/data/data_modules.py
+8
-7
openfold/data/feature_pipeline.py
openfold/data/feature_pipeline.py
+4
-3
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+27
-24
No files found.
openfold/data/data_modules.py
View file @
093603ee
...
...
@@ -443,7 +443,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
)
self
.
feature_pipeline
=
feature_pipeline
.
FeaturePipeline
(
config
)
def
_parse_mmcif
(
self
,
path
,
file_id
,
alignment_dir
,
alignment_index
):
def
_parse_mmcif
(
self
,
path
,
file_id
,
alignment_dir
,
alignment_index
):
with
open
(
path
,
'r'
)
as
f
:
mmcif_string
=
f
.
read
()
...
...
@@ -462,7 +462,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
mmcif
=
mmcif_object
,
alignment_dir
=
alignment_dir
,
alignment_index
=
alignment_index
)
)
return
data
...
...
@@ -471,10 +471,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def
idx_to_mmcif_id
(
self
,
idx
):
return
self
.
_mmcifs
[
idx
]
def
__getitem__
(
self
,
idx
):
mmcif_id
=
self
.
idx_to_mmcif_id
(
idx
)
alignment_index
=
None
if
(
self
.
mode
==
'train'
or
self
.
mode
==
'eval'
):
path
=
os
.
path
.
join
(
self
.
data_dir
,
f
"
{
mmcif_id
}
"
)
ext
=
None
...
...
@@ -503,19 +504,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
if
(
self
.
_output_raw
):
return
data
# process all_chain_features
data
=
self
.
feature_pipeline
.
process_features
(
data
,
data
,
ground_truth
=
self
.
feature_pipeline
.
process_features
(
data
,
mode
=
self
.
mode
,
is_multimer
=
True
)
# if it's inference mode, only need all_chain_features
data
[
"batch_idx"
]
=
torch
.
tensor
(
[
idx
for
_
in
range
(
data
[
"aatype"
].
shape
[
-
1
])],
dtype
=
torch
.
int64
,
device
=
data
[
"aatype"
].
device
)
return
data
return
data
,
ground_truth
def
__len__
(
self
):
return
len
(
self
.
_chain_ids
)
...
...
openfold/data/feature_pipeline.py
View file @
093603ee
...
...
@@ -81,7 +81,7 @@ def np_example_to_features(
seq_length
=
np_example
[
"seq_length"
]
num_res
=
int
(
seq_length
[
0
])
if
seq_length
.
ndim
!=
0
else
int
(
seq_length
)
cfg
,
feature_names
=
make_data_config
(
config
,
mode
=
mode
,
num_res
=
num_res
)
if
"deletion_matrix_int"
in
np_example
:
np_example
[
"deletion_matrix"
]
=
np_example
.
pop
(
"deletion_matrix_int"
...
...
@@ -90,6 +90,7 @@ def np_example_to_features(
tensor_dict
=
np_to_tensor_dict
(
np_example
=
np_example
,
features
=
feature_names
)
with
torch
.
no_grad
():
if
(
not
is_multimer
):
features
=
input_pipeline
.
process_tensors_from_config
(
...
...
@@ -98,7 +99,7 @@ def np_example_to_features(
cfg
[
mode
],
)
else
:
features
=
input_pipeline_multimer
.
process_tensors_from_config
(
features
,
gt_features
=
input_pipeline_multimer
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
...
...
@@ -119,7 +120,7 @@ def np_example_to_features(
dtype
=
torch
.
float32
,
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
,
gt_features
class
FeaturePipeline
:
...
...
openfold/data/input_pipeline_multimer.py
View file @
093603ee
...
...
@@ -21,19 +21,11 @@ from openfold.data import (
data_transforms_multimer
,
)
def
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
):
"""Input pipeline data transformers that are not ensembled."""
transforms
=
[
data_transforms
.
cast_to_64bit_ints
,
data_transforms_multimer
.
make_msa_profile
,
data_transforms_multimer
.
create_target_feat
,
data_transforms
.
make_atom14_masks
,
]
if
mode_cfg
.
supervised
:
def
grountruth_transforms_fns
():
transforms
=
[]
transforms
.
extend
(
[
[
data_transforms
.
make_atom14_masks
,
data_transforms
.
make_atom14_positions
,
data_transforms
.
atom37_to_frames
,
data_transforms
.
atom37_to_torsion_angles
(
""
),
...
...
@@ -42,6 +34,16 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms
.
get_chi_angles
,
]
)
return
transforms
def
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
):
"""Input pipeline data transformers that are not ensembled."""
transforms
=
[
data_transforms
.
cast_to_64bit_ints
,
data_transforms_multimer
.
make_msa_profile
,
data_transforms_multimer
.
create_target_feat
,
data_transforms
.
make_atom14_masks
]
return
transforms
...
...
@@ -118,6 +120,11 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
"""Based on the config, apply filters and transformations to the data."""
GROUNDTRUTH_FEATURES
=
[
'all_atom_mask'
,
'all_atom_positions'
]
input_tensors
=
{
k
:
v
for
k
,
v
in
tensors
.
items
()
if
k
not
in
GROUNDTRUTH_FEATURES
}
gt_tensors
=
{
k
:
v
for
k
,
v
in
tensors
.
items
()
if
k
in
GROUNDTRUTH_FEATURES
}
gt_tensors
[
'aatype'
]
=
tensors
[
'aatype'
].
to
(
torch
.
long
)
del
tensors
ensemble_seed
=
random
.
randint
(
0
,
torch
.
iinfo
(
torch
.
int32
).
max
)
def
wrap_ensemble_fn
(
data
,
i
):
...
...
@@ -132,27 +139,23 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
d
[
"ensemble_index"
]
=
i
return
fn
(
d
)
no_templates
=
True
if
(
"template_aatype"
in
tensors
):
no_templates
=
tensors
[
"template_aatype"
].
shape
[
0
]
==
0
nonensembled
=
nonensembled_transform_fns
(
common_cfg
,
mode_cfg
,
)
gt_tensors
=
compose
(
grountruth_transforms_fns
())(
gt_tensors
)
tensors
=
compose
(
nonensembled
)(
tensors
)
if
(
"no_recycling_iters"
in
tensors
):
num_recycling
=
int
(
tensors
[
"no_recycling_iters"
])
input_tensors
=
compose
(
nonensembled
)(
input_tensors
)
if
(
"no_recycling_iters"
in
input_tensors
):
num_recycling
=
int
(
input_tensors
[
"no_recycling_iters"
])
else
:
num_recycling
=
common_cfg
.
max_recycling_iters
tensors
=
map_fn
(
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
torch
.
arange
(
num_recycling
+
1
)
input_
tensors
=
map_fn
(
lambda
x
:
wrap_ensemble_fn
(
input_
tensors
,
x
),
torch
.
arange
(
num_recycling
+
1
)
)
return
tensors
return
input_tensors
,
gt_
tensors
@
data_transforms
.
curry1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment