Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9aebc203
Commit
9aebc203
authored
Sep 28, 2023
by
Geoffrey Yu
Browse files
update the returned features when it's not training mode in multimer input pipeline
parent
68389359
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
15 deletions
+17
-15
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+17
-15
No files found.
openfold/data/input_pipeline_multimer.py
View file @
9aebc203
...
...
@@ -113,14 +113,17 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
return
transforms
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
,
is_training
=
False
):
"""Based on the config, apply filters and transformations to the data."""
GROUNDTRUTH_FEATURES
=
[
'all_atom_mask'
,
'all_atom_positions'
,
'asym_id'
,
'sym_id'
,
'entity_id'
]
tensors
[
'aatype'
]
=
tensors
[
'aatype'
].
to
(
torch
.
long
)
gt_tensors
=
{
k
:
v
for
k
,
v
in
tensors
.
items
()
if
k
in
GROUNDTRUTH_FEATURES
}
gt_tensors
[
'aatype'
]
=
tensors
[
'aatype'
]
ensemble_seed
=
random
.
randint
(
0
,
torch
.
iinfo
(
torch
.
int32
).
max
)
tensors
[
'aatype'
]
=
tensors
[
'aatype'
].
to
(
torch
.
long
)
nonensembled
=
nonensembled_transform_fns
()
tensors
=
compose
(
nonensembled
)(
tensors
)
if
(
"no_recycling_iters"
in
tensors
):
num_recycling
=
int
(
tensors
[
"no_recycling_iters"
])
else
:
num_recycling
=
common_cfg
.
max_recycling_iters
def
wrap_ensemble_fn
(
data
,
i
):
"""Function to be mapped over the ensemble dimension."""
...
...
@@ -134,20 +137,19 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
d
[
"ensemble_index"
]
=
i
return
fn
(
d
)
nonensembled
=
nonensembled_transform_fns
()
gt_tensors
=
compose
(
grountruth_transforms_fns
())(
gt_tensors
)
tensors
=
compose
(
nonensembled
)(
tensors
)
if
(
"no_recycling_iters"
in
tensors
):
num_recycling
=
int
(
tensors
[
"no_recycling_iters"
])
else
:
num_recycling
=
common_cfg
.
max_recycling_iters
tensors
=
map_fn
(
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
torch
.
arange
(
num_recycling
+
1
)
)
return
tensors
,
gt_tensors
if
is_training
:
GROUNDTRUTH_FEATURES
=
[
'all_atom_mask'
,
'all_atom_positions'
,
'asym_id'
,
'sym_id'
,
'entity_id'
]
gt_tensors
=
{
k
:
v
for
k
,
v
in
tensors
.
items
()
if
k
in
GROUNDTRUTH_FEATURES
}
gt_tensors
[
'aatype'
]
=
tensors
[
'aatype'
]
gt_tensors
=
compose
(
grountruth_transforms_fns
())(
gt_tensors
)
return
tensors
,
gt_tensors
else
:
return
tensors
@
data_transforms
.
curry1
def
compose
(
x
,
fs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment