Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
2c565664
"...git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "16310b269f866e6f4b7968ba6780e54a4f7b76f6"
Commit
2c565664
authored
Feb 15, 2024
by
Geoffrey Yu
Browse files
restore to the verison on main
parent
aa18a56b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
3 deletions
+5
-3
openfold/utils/multi_chain_permutation.py
openfold/utils/multi_chain_permutation.py
+5
-3
No files found.
openfold/utils/multi_chain_permutation.py
View file @
2c565664
...
@@ -105,12 +105,13 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
...
@@ -105,12 +105,13 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
"""
"""
entity_2_asym_list
=
get_entity_2_asym_list
(
batch
)
entity_2_asym_list
=
get_entity_2_asym_list
(
batch
)
unique_entity_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"entity_id"
])
if
i
!=
0
]
# if entity_id is 0, that means this entity_id comes from padding
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_asym_count
=
{}
entity_asym_count
=
{}
entity_length
=
{}
entity_length
=
{}
for
entity_id
in
unique_entity_ids
:
for
entity_id
in
unique_entity_ids
:
asym_ids
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
entity_id
])
asym_ids
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
entity_id
])
# Make sure some asym IDs associated with ground truth entity ID exist in cropped prediction
# Make sure some asym IDs associated with ground truth entity ID exist in cropped prediction
asym_ids_in_pred
=
[
a
for
a
in
asym_ids
if
a
in
input_asym_id
]
asym_ids_in_pred
=
[
a
for
a
in
asym_ids
if
a
in
input_asym_id
]
if
not
asym_ids_in_pred
:
if
not
asym_ids_in_pred
:
...
@@ -121,8 +122,10 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
...
@@ -121,8 +122,10 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
# Calculate entity length
# Calculate entity length
entity_mask
=
(
batch
[
"entity_id"
]
==
entity_id
)
entity_mask
=
(
batch
[
"entity_id"
]
==
entity_id
)
entity_length
[
int
(
entity_id
)]
=
entity_mask
.
sum
().
item
()
entity_length
[
int
(
entity_id
)]
=
entity_mask
.
sum
().
item
()
min_asym_count
=
min
(
entity_asym_count
.
values
())
min_asym_count
=
min
(
entity_asym_count
.
values
())
least_asym_entities
=
[
entity
for
entity
,
count
in
entity_asym_count
.
items
()
if
count
==
min_asym_count
]
least_asym_entities
=
[
entity
for
entity
,
count
in
entity_asym_count
.
items
()
if
count
==
min_asym_count
]
# If multiple entities have the least asym_id count, return those with the longest length
# If multiple entities have the least asym_id count, return those with the longest length
if
len
(
least_asym_entities
)
>
1
:
if
len
(
least_asym_entities
)
>
1
:
max_length
=
max
([
entity_length
[
entity
]
for
entity
in
least_asym_entities
])
max_length
=
max
([
entity_length
[
entity
]
for
entity
in
least_asym_entities
])
...
@@ -137,6 +140,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
...
@@ -137,6 +140,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
anchor_gt_asym_id
=
random
.
choice
(
entity_2_asym_list
[
least_asym_entities
])
anchor_gt_asym_id
=
random
.
choice
(
entity_2_asym_list
[
least_asym_entities
])
anchor_pred_asym_ids
=
[
asym_id
for
asym_id
in
entity_2_asym_list
[
least_asym_entities
]
if
asym_id
in
input_asym_id
]
anchor_pred_asym_ids
=
[
asym_id
for
asym_id
in
entity_2_asym_list
[
least_asym_entities
]
if
asym_id
in
input_asym_id
]
return
anchor_gt_asym_id
,
anchor_pred_asym_ids
return
anchor_gt_asym_id
,
anchor_pred_asym_ids
...
@@ -156,7 +160,6 @@ def greedy_align(
...
@@ -156,7 +160,6 @@ def greedy_align(
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
align
=
[]
align
=
[]
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
for
cur_asym_id
in
unique_asym_ids
:
for
cur_asym_id
in
unique_asym_ids
:
i
=
int
(
cur_asym_id
-
1
)
i
=
int
(
cur_asym_id
-
1
)
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
...
@@ -346,7 +349,6 @@ def compute_permutation_alignment(out, features, ground_truth):
...
@@ -346,7 +349,6 @@ def compute_permutation_alignment(out, features, ground_truth):
# First select anchors from predicted structures and ground truths
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym_ids
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
anchor_gt_asym
,
anchor_pred_asym_ids
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
features
[
'asym_id'
])
features
[
'asym_id'
])
entity_2_asym_list
=
get_entity_2_asym_list
(
ground_truth
)
entity_2_asym_list
=
get_entity_2_asym_list
(
ground_truth
)
labels
=
split_ground_truth_labels
(
ground_truth
)
labels
=
split_ground_truth_labels
(
ground_truth
)
assert
isinstance
(
labels
,
list
)
assert
isinstance
(
labels
,
list
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment