Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
cec5a426
"container/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "3ba2b7e94b5ed7e92d971958c15806656559d4b1"
Commit
cec5a426
authored
Sep 03, 2023
by
Geoffrey Yu
Browse files
added test3 to test if permutated tensors end up as expected
parent
c5f16efc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
17 deletions
+27
-17
openfold/utils/loss.py
openfold/utils/loss.py
+13
-15
tests/test_permutation.py
tests/test_permutation.py
+14
-2
No files found.
openfold/utils/loss.py
View file @
cec5a426
...
@@ -2095,6 +2095,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2095,6 +2095,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
if
permutate_chains
:
if
permutate_chains
:
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
...
@@ -2154,22 +2155,19 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2154,22 +2155,19 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward()
out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure
batch: a pair of input features and its corresponding ground truth structure
"""
"""
_is_monomer
=
len
(
torch
.
unique
(
features
[
'asym_id'
]))
==
1
or
torch
.
unique
(
features
[
'asym_id'
]).
tolist
()
==
[
0
,
1
]
# first determin which dimension in the tensor to split into individual ground truth labels
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
# Then permutate ground truth chains before calculating the loss
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
permutate_chains
=
permutate_chains
)
if
not
_is_monomer
:
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
# first determin which dimension in the tensor to split into individual ground truth labels
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
# Then permutate ground truth chains before calculating the loss
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
features
.
update
(
labels
)
permutate_chains
=
True
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
features
.
update
(
labels
)
if
(
not
_return_breakdown
):
if
(
not
_return_breakdown
):
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
...
...
tests/test_permutation.py
View file @
cec5a426
...
@@ -151,7 +151,19 @@ class TestPermutation(unittest.TestCase):
...
@@ -151,7 +151,19 @@ class TestPermutation(unittest.TestCase):
tensor_to_cuda
=
lambda
t
:
t
.
to
(
'cuda'
)
tensor_to_cuda
=
lambda
t
:
t
.
to
(
'cuda'
)
batch
=
tensor_tree_map
(
tensor_to_cuda
,
batch
)
batch
=
tensor_tree_map
(
tensor_to_cuda
,
batch
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
batch
)
aligns
,
_
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
aligns
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
batch
,
batch
,
dim_dict
,
dim_dict
,
permutate_chains
=
True
)
permutate_chains
=
True
)
\ No newline at end of file
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
batch
.
keys
()
if
i
in
dim_dict
])
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
aligns
,
original_nres
=
batch
[
'aatype'
].
shape
[
-
1
])
self
.
assertTrue
(
torch
.
equal
(
labels
[
'residue_index'
],
batch
[
'residue_index'
]))
expected_permutated_gt_pos
=
torch
.
cat
((
chain_a2_pos
,
chain_a1_pos
,
chain_b2_pos
,
chain_b3_pos
,
chain_b1_pos
),
dim
=
1
)
expected_permutated_gt_pos
=
pad_features
(
expected_permutated_gt_pos
,
nres_pad
,
pad_dim
=
1
)
self
.
assertTrue
(
torch
.
equal
(
labels
[
'all_atom_positions'
],
expected_permutated_gt_pos
))
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment