Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
fe01bb0c
"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "c46e97ca756ee4e549ee72c6aab84451b073eb62"
Commit
fe01bb0c
authored
Sep 21, 2023
by
Geoffrey Yu
Browse files
fixed the index error. Now working on updating greedy_align
parent
2184eff0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
16 deletions
+28
-16
openfold/utils/loss.py
openfold/utils/loss.py
+28
-16
No files found.
openfold/utils/loss.py
View file @
fe01bb0c
...
@@ -1700,9 +1700,6 @@ def compute_rmsd(
...
@@ -1700,9 +1700,6 @@ def compute_rmsd(
atom_mask
:
torch
.
Tensor
=
None
,
atom_mask
:
torch
.
Tensor
=
None
,
eps
:
float
=
1e-6
,
eps
:
float
=
1e-6
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# shape check
true_atom_pos
=
true_atom_pos
pred_atom_pos
=
pred_atom_pos
sq_diff
=
torch
.
square
(
true_atom_pos
-
pred_atom_pos
).
sum
(
dim
=-
1
,
keepdim
=
False
)
sq_diff
=
torch
.
square
(
true_atom_pos
-
pred_atom_pos
).
sum
(
dim
=-
1
,
keepdim
=
False
)
del
true_atom_pos
del
true_atom_pos
del
pred_atom_pos
del
pred_atom_pos
...
@@ -1860,19 +1857,28 @@ def greedy_align(
...
@@ -1860,19 +1857,28 @@ def greedy_align(
for
next_asym_id
in
cur_asym_list
:
for
next_asym_id
in
cur_asym_list
:
j
=
int
(
next_asym_id
-
1
)
j
=
int
(
next_asym_id
-
1
)
if
not
used
[
j
]:
# possible candidate
if
not
used
[
j
]:
# possible candidate
cropped_pos
=
torch
.
index_select
(
true_ca_poses
[
j
],
1
,
cur_residue_index
)
cropped_pos
=
true_ca_poses
[
j
]
mask
=
torch
.
index_select
(
true_ca_masks
[
j
],
1
,
cur_residue_index
)
cropped_pos
=
torch
.
squeeze
(
cropped_pos
,
0
)
rmsd
=
compute_rmsd
(
if
not
cropped_pos
.
shape
==
cur_pred_pos
.
shape
:
torch
.
squeeze
(
cropped_pos
,
0
),
torch
.
squeeze
(
cur_pred_pos
,
0
),
# this means selected candidte is not the correct one. Skip
(
cur_pred_mask
*
mask
).
bool
()
used
[
j
]
=
True
)
else
:
if
(
rmsd
is
not
None
)
and
(
rmsd
<
best_rmsd
):
mask
=
true_ca_masks
[
j
]
best_rmsd
=
rmsd
mask
=
torch
.
squeeze
(
mask
,
0
)
best_idx
=
j
print
(
f
"cropped_pos shape:
{
cropped_pos
.
shape
}
cur_pred_pos shape:
{
cur_pred_pos
.
shape
}
"
)
print
(
f
"mask shape:
{
mask
.
shape
}
and cur_pred_mask shape:
{
cur_pred_mask
.
shape
}
"
)
rmsd
=
compute_rmsd
(
torch
.
squeeze
(
cropped_pos
,
0
),
torch
.
squeeze
(
cur_pred_pos
,
0
),
(
cur_pred_mask
*
mask
).
bool
()
)
if
(
rmsd
is
not
None
)
and
(
rmsd
<
best_rmsd
):
best_rmsd
=
rmsd
best_idx
=
j
assert
best_idx
is
not
None
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
used
[
best_idx
]
=
True
align
.
append
((
i
,
best_idx
))
align
.
append
((
i
,
best_idx
))
return
align
return
align
...
@@ -2065,7 +2071,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2065,7 +2071,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return
labels
return
labels
@
staticmethod
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
Tru
e
):
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
Fals
e
):
"""
"""
A class method that first permutate chains in ground truth first
A class method that first permutate chains in ground truth first
before calculating the loss.
before calculating the loss.
...
@@ -2084,15 +2090,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2084,15 +2090,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
]
# list([nres, 3])
]
# list([nres, 3])
true_ca_masks
=
[
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
]
# list([nres,])
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
per_asym_residue_index
=
{}
per_asym_residue_index
=
{}
for
cur_asym_id
in
unique_asym_ids
:
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
if
permutate_chains
:
if
permutate_chains
:
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
...
@@ -2105,11 +2116,12 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2105,11 +2116,12 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
anchor_residue_idx
=
per_asym_residue_index
[
int
(
anchor_pred_asym
)]
anchor_true_pos
=
torch
.
index_select
(
true_ca_poses
[
anchor_gt_idx
],
1
,
anchor_residue_idx
)
anchor_true_pos
=
true_ca_poses
[
anchor_gt_idx
]
anchor_pred_pos
=
pred_ca_pos
[
0
][
asym_mask
[
0
]]
anchor_pred_pos
=
pred_ca_pos
[
0
][
asym_mask
[
0
]]
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
]
,
1
,
anchor_residue_idx
)
anchor_true_mask
=
true_ca_masks
[
anchor_gt_idx
]
anchor_pred_mask
=
pred_ca_mask
[
0
][
asym_mask
[
0
]]
anchor_pred_mask
=
pred_ca_mask
[
0
][
asym_mask
[
0
]]
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment