OpenDAS / OpenFold · Commits

Commit 66a60d58
Authored Jul 13, 2023 by Geoffrey Yu
Parent e9794a62

    start modifying merge_labels function to make it compatible with dataloader inputs
Showing 1 changed file with 22 additions and 23 deletions.

openfold/utils/loss.py  (+22 −23)
@@ -1610,6 +1610,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
     Returns:
         Masked MSA loss
     """
+    print(f"logits shape: {logits.shape} true_msa shape: {true_msa.shape}")
     errors = softmax_cross_entropy(
         logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
     )
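For context on what the added print reports, here is a minimal, self-contained sketch (not OpenFold's code; the softmax_cross_entropy below is a stand-in with the same call shape, and the toy sizes are purely illustrative) of how the logits and true_msa shapes relate once a dataloader adds a leading batch dimension:

import torch

# Stand-in for a softmax cross-entropy helper: per-element cross entropy over
# the last (class) axis, given one-hot labels of the same shape as logits.
def softmax_cross_entropy(logits, labels):
    return -torch.sum(labels * torch.nn.functional.log_softmax(logits, dim=-1), dim=-1)

# Toy shapes: with a dataloader batch dimension, logits is
# [batch, N_seq, N_res, num_classes] and true_msa is [batch, N_seq, N_res].
num_classes = 23
logits = torch.randn(2, 4, 8, num_classes)
true_msa = torch.randint(0, num_classes, (2, 4, 8))
print(f"logits shape: {logits.shape} true_msa shape: {true_msa.shape}")
errors = softmax_cross_entropy(
    logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
)
print(errors.shape)  # torch.Size([2, 4, 8]); one loss term per MSA position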
@@ -1749,7 +1750,7 @@ def get_optimal_transform(
         src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
         tgt_atoms = src_atoms
     else:
-        src_atoms = src_atoms[mask, :]
+        src_atoms = src_atoms.to('cuda:0')[mask, :]
         tgt_atoms = tgt_atoms.to('cuda:0')[mask, :]
     src_center = src_atoms.mean(-2, keepdim=True)
     tgt_center = tgt_atoms.mean(-2, keepdim=True)
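A note on the .to('cuda:0') change above: hard-coding cuda:0 assumes a single-GPU run. A minimal sketch of a device-agnostic alternative (the helper select_masked_atoms is hypothetical, not part of this commit), which moves both point sets onto a common device before masking:

import torch

def select_masked_atoms(src_atoms: torch.Tensor, tgt_atoms: torch.Tensor, mask: torch.Tensor):
    # Move both point clouds onto one device (here, wherever src_atoms lives)
    # before boolean indexing, so the masking and the later centering
    # (mean over the point axis) never mix devices.
    device = src_atoms.device
    return src_atoms.to(device)[mask, :], tgt_atoms.to(device)[mask, :]

src_atoms = torch.randn(10, 3)
tgt_atoms = torch.randn(10, 3)
mask = torch.rand(10) > 0.5
src_sel, tgt_sel = select_masked_atoms(src_atoms, tgt_atoms, mask)
src_center = src_sel.mean(-2, keepdim=True)  # same centering step as get_optimal_transform
print(src_sel.shape, src_center.shape)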
@@ -1857,7 +1858,6 @@ def greedy_align(
         best_idx = None
         cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
         cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
         cur_pred_pos = pred_ca_pos[asym_mask]
         cur_pred_mask = pred_ca_mask[asym_mask]
         for next_asym_id in cur_asym_list:
@@ -1890,27 +1890,30 @@ def merge_labels(batch, per_asym_residue_index, labels, align):
     num_res = batch["msa_mask"].shape[-1]
     outs = {}
     for k, v in labels[0].items():
-        if k in [
-            "resolution",
-        ]:
-            continue
         cur_out = {}
         for i, j in align:
             label = labels[j][k]
             # to 1-based
             cur_residue_index = per_asym_residue_index[i + 1]
-            cur_out[i] = label[cur_residue_index]
+            if len(v.shape) == 0 or "template" in k:
+                continue
+            else:
+                cur_out[i] = label[cur_residue_index]
         cur_out = [x[1] for x in sorted(cur_out.items())]
-        new_v = torch.concat(cur_out, dim=0)
-        merged_nres = new_v.shape[0]
-        assert (
-            merged_nres <= num_res
-        ), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
-        if merged_nres < num_res:  # must pad
-            pad_dim = new_v.shape[1:]
-            pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
-            new_v = torch.concat((new_v, pad_v), dim=0)
-        outs[k] = new_v
+        if len(cur_out) > 0:
+            new_v = torch.concat(cur_out, dim=0)
+            merged_nres = new_v.shape[0]
+            assert (
+                merged_nres <= num_res
+            ), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
+            if merged_nres < num_res:  # must pad
+                pad_dim = new_v.shape[1:]
+                pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
+                new_v = torch.concat((new_v, pad_v), dim=0)
+            outs[k] = new_v
+    print(f"finished merging")
+    for k, v in outs.items():
+        print(f"{k}: {v.shape}")
     return outs
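A toy walk-through of the merging behaviour this commit targets. This is a simplified sketch, not the modified function itself (the skip check is hoisted out of the inner loop and the assert is dropped): per-chain ground-truth features are gathered in the permuted order given by align, concatenated along the residue axis, and zero-padded up to the batch's padded length, while 0-dim features such as resolution and template features are skipped.

import torch

def merge_labels_sketch(num_res, per_asym_residue_index, labels, align):
    outs = {}
    for k, v in labels[0].items():
        if len(v.shape) == 0 or "template" in k:
            continue  # skip scalar features (e.g. resolution) and template features
        cur_out = {}
        for i, j in align:
            cur_residue_index = per_asym_residue_index[i + 1]  # asym ids are 1-based
            cur_out[i] = labels[j][k][cur_residue_index]
        pieces = [x[1] for x in sorted(cur_out.items())]
        new_v = torch.concat(pieces, dim=0)
        if new_v.shape[0] < num_res:  # pad up to the batch's padded length
            pad = new_v.new_zeros((num_res - new_v.shape[0], *new_v.shape[1:]))
            new_v = torch.concat((new_v, pad), dim=0)
        outs[k] = new_v
    return outs

# Two chains of 3 residues each, padded batch length of 8.
labels = [
    {"all_atom_positions": torch.randn(3, 37, 3), "resolution": torch.tensor(2.5)},
    {"all_atom_positions": torch.randn(3, 37, 3), "resolution": torch.tensor(2.5)},
]
per_asym_residue_index = {1: torch.arange(3), 2: torch.arange(3)}
align = [(0, 1), (1, 0)]  # chain 0 takes ground-truth label 1, chain 1 takes label 0
merged = merge_labels_sketch(8, per_asym_residue_index, labels, align)
print(merged["all_atom_positions"].shape)  # torch.Size([8, 37, 3]); "resolution" was skipped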
@@ -2050,7 +2053,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
         true_ca_masks = [
             l["all_atom_mask"][..., ca_idx].float() for l in labels
         ]  # list([nres,])
         unique_asym_ids = torch.unique(batch["asym_id"])
         per_asym_residue_index = {}
@@ -2059,7 +2061,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
             per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
         anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
-        print(f"anchor_gt_asym is : {anchor_gt_asym} and anchor_pred_asym is {anchor_pred_asym}")
         anchor_gt_idx = int(anchor_gt_asym) - 1
         unique_entity_ids = torch.unique(batch["entity_id"])
@@ -2100,15 +2101,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
         del aligned_true_ca_poses
         del r, x
         gc.collect()
+        print(f"finished multi-chain permutation and final align is {align}")
         merged_labels = merge_labels(
             batch,
             per_asym_residue_index,
             labels,
             align,
         )
-        print(f"finished multi-chain permutation and final align is {align}")
         return merged_labels
@@ -2122,9 +2121,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
             batch: a pair of input features and its corresponding ground truth structure
         """
         features, labels = batch
-        features['resolution'] = labels[2]['resolution'] # firstly update the resolution feature
         # first remove the recycling dimention of input features
         features = tensor_tree_map(lambda t: t[..., -1], features)
+        features['resolution'] = labels[0]['resolution']
         # then permutate ground truth chains before calculating the loss
         permutated_labels = self.multi_chain_perm_align(out, features, labels)
         permutated_labels.pop('aatype')
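For reference on the recycling-dimension step retained above: a minimal illustration (a toy re-implementation, not OpenFold's tensor_tree_map helper, with made-up feature shapes) of how mapping lambda t: t[..., -1] over the feature dict keeps only the last recycling slice of every tensor.

import torch

def tensor_tree_map_sketch(fn, tree):
    # Recursively apply fn to every tensor leaf of a nested feature dict.
    if isinstance(tree, dict):
        return {k: tensor_tree_map_sketch(fn, v) for k, v in tree.items()}
    return fn(tree)

features = {
    "aatype": torch.zeros(8, 4),       # toy [N_res, n_recycle]
    "msa_mask": torch.ones(16, 8, 4),  # toy [N_seq, N_res, n_recycle]
}
# Drop the trailing recycling axis, keeping only the final recycling iteration.
features = tensor_tree_map_sketch(lambda t: t[..., -1], features)
print(features["aatype"].shape, features["msa_mask"].shape)
# torch.Size([8]) torch.Size([16, 8])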