Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
cec5a426
Commit
cec5a426
authored
Sep 03, 2023
by
Geoffrey Yu
Browse files
added test3 to test if permutated tensors end up as expected
parent
c5f16efc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
17 deletions
+27
-17
openfold/utils/loss.py
openfold/utils/loss.py
+13
-15
tests/test_permutation.py
tests/test_permutation.py
+14
-2
No files found.
openfold/utils/loss.py
View file @
cec5a426
...
@@ -2095,6 +2095,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2095,6 +2095,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
if
permutate_chains
:
if
permutate_chains
:
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
...
@@ -2154,22 +2155,19 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2154,22 +2155,19 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward()
out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure
batch: a pair of input features and its corresponding ground truth structure
"""
"""
_is_monomer
=
len
(
torch
.
unique
(
features
[
'asym_id'
]))
==
1
or
torch
.
unique
(
features
[
'asym_id'
]).
tolist
()
==
[
0
,
1
]
# first determin which dimension in the tensor to split into individual ground truth labels
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
# Then permutate ground truth chains before calculating the loss
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
permutate_chains
=
permutate_chains
)
if
not
_is_monomer
:
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
# first determin which dimension in the tensor to split into individual ground truth labels
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
# Then permutate ground truth chains before calculating the loss
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
features
.
update
(
labels
)
permutate_chains
=
True
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
features
.
update
(
labels
)
if
(
not
_return_breakdown
):
if
(
not
_return_breakdown
):
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
...
...
tests/test_permutation.py
View file @
cec5a426
...
@@ -151,7 +151,19 @@ class TestPermutation(unittest.TestCase):
...
@@ -151,7 +151,19 @@ class TestPermutation(unittest.TestCase):
tensor_to_cuda
=
lambda
t
:
t
.
to
(
'cuda'
)
tensor_to_cuda
=
lambda
t
:
t
.
to
(
'cuda'
)
batch
=
tensor_tree_map
(
tensor_to_cuda
,
batch
)
batch
=
tensor_tree_map
(
tensor_to_cuda
,
batch
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
aligns
,
_
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
aligns
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
batch
,
batch
,
dim_dict
,
dim_dict
,
permutate_chains
=
True
)
permutate_chains
=
True
)
\ No newline at end of file
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
batch
.
keys
()
if
i
in
dim_dict
])
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
aligns
,
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
expected_permutated_gt_pos
=
pad_features
(
expected_permutated_gt_pos
,
nres_pad
,
pad_dim
=
1
)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment