Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
a0f8a057
Commit
a0f8a057
authored
Jun 22, 2023
by
Geoffrey Yu
Browse files
now used the new way of selecting anchors
parent
54755901
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
57 additions
and
57 deletions
+57
-57
tests/test_permutation.py
tests/test_permutation.py
+5
-5
tests/unifold_permutation.py
tests/unifold_permutation.py
+52
-52
No files found.
tests/test_permutation.py
View file @
a0f8a057
...
...
@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
In the test case, use PDB ID 1e4k as the label
"""
self
.
test_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data"
)
self
.
label_ids
=
[
'label_1'
,
'label_2'
]
self
.
label_ids
=
[
'label_1'
,
'label_2'
,
'label_2'
]
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
+
13
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
consts
.
model
,
train
=
True
)
...
...
@@ -83,12 +83,12 @@ class TestPermutation(unittest.TestCase):
# Modify asym_id, entity_id and sym_id so that it encodes
# 2 chains
# #
asym_id
=
[
1
]
*
9
+
[
2
]
*
13
asym_id
=
[
1
]
*
9
+
[
2
]
*
13
+
[
3
]
*
13
batch
[
"asym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch
[
'entity_id'
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
'entity_id'
]
=
torch
.
tensor
(
[
1
]
*
9
+
[
2
]
*
26
,
dtype
=
torch
.
float64
)
batch
[
"sym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"num_sym"
]
=
torch
.
tensor
([
2
]
*
22
,
dtype
=
torch
.
int64
)
# currently there are just 2 chains
batch
[
"num_sym"
]
=
torch
.
tensor
([
1
]
*
9
+
[
2
]
*
26
,
dtype
=
torch
.
int64
)
# currently there are just 2 chains
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
...
...
tests/unifold_permutation.py
View file @
a0f8a057
...
...
@@ -3,6 +3,7 @@ from openfold.np import residue_constants as rc
import
logging
logger
=
logging
.
getLogger
(
__name__
)
import
sys
import
random
def
kabsch_rotation
(
P
,
Q
):
"""
...
...
@@ -119,10 +120,13 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
batch
[
"residue_index"
][
asym_mask
]
anchor_gt_asym
,
anchor_pred_asym
=
get_anchor_candidates
(
batch
,
per_asym_residue_index
,
true_ca_masks
)
print
(
f
"anchor_gt_asym is
{
anchor_gt_asym
}
, anchor_pred_asym is
{
anchor_pred_asym
}
"
)
# anchor_gt_asym, anchor_pred_asym = get_anchor_candidates(
# batch, per_asym_residue_index, true_ca_masks
# )
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym is
{
anchor_gt_asym
}
"
)
import
sys
sys
.
exit
()
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
best_rmsd
=
1e9
...
...
@@ -134,7 +138,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
print
(
f
"entity_2_asym_list is
{
entity_2_asym_list
}
"
)
for
cur_asym_id
in
anchor_pred_asym
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
anchor_residue_idx
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
...
...
@@ -148,9 +152,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
anchor_pred_pos
,
(
anchor_true_mask
.
to
(
'cpu'
)
*
anchor_pred_mask
.
to
(
'cpu'
)).
bool
(),
)
print
(
f
"finished getting optimal transform"
)
import
sys
sys
.
exit
()
aligned_true_ca_poses
=
[
ca
.
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
+
x
.
to
(
'cpu'
)
for
ca
in
true_ca_poses
]
# apply transforms
for
_
in
range
(
shuffle_times
):
shuffle_idx
=
torch
.
randperm
(
...
...
@@ -167,57 +169,61 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
aligned_true_ca_poses
,
true_ca_masks
,
)
merged_labels
=
merge_labels
(
batch
,
per_asym_residue_index
,
labels
,
align
,
)
rmsd
=
kabsch_rmsd
(
merged_labels
[
"all_atom_positions"
][...,
ca_idx
,
:]
@
r
+
x
,
merged_labels
[
"all_atom_positions"
][...,
ca_idx
,
:]
.
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
+
x
.
to
(
'cpu'
)
,
pred_ca_pos
,
(
pred_ca_mask
*
merged_labels
[
"all_atom_mask"
][...,
ca_idx
]).
bool
(),
(
pred_ca_mask
.
to
(
'cpu'
)
*
merged_labels
[
"all_atom_mask"
][...,
ca_idx
]
.
to
(
'cpu'
)
).
bool
(),
)
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
best_labels
=
merged_labels
print
(
f
"finished kabsh_rmsd"
)
return
best_labels
def
get_anchor_candidates
(
batch
,
per_asym_residue_index
,
true_masks
):
def
find_by_num_sym
(
min_num_sym
):
best_len
=
-
1
best_gt_asym
=
None
asym_ids
=
batch
[
"asym_id"
][
batch
[
"num_sym"
]
==
min_num_sym
]
asym_ids
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"num_sym"
]
==
min_num_sym
])
for
cur_asym_id
in
asym_ids
:
assert
cur_asym_id
>
0
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
j
=
int
(
cur_asym_id
-
1
)
cur_true_mask
=
true_masks
[
j
][
cur_residue_index
]
cur_len
=
cur_true_mask
.
shape
[
0
]
if
cur_len
>
best_len
:
best_len
=
cur_len
best_gt_asym
=
cur_asym_id
return
best_gt_asym
,
best_len
sorted_num_sym
=
batch
[
"num_sym"
][
batch
[
"num_sym"
]
>
0
].
sort
()[
0
]
best_gt_asym
=
None
best_len
=
-
1
for
cur_num_sym
in
sorted_num_sym
:
if
cur_num_sym
<=
0
:
continue
cur_gt_sym
,
cur_len
=
find_by_num_sym
(
cur_num_sym
)
if
cur_len
>
best_len
:
best_len
=
cur_len
best_gt_asym
=
cur_gt_sym
if
best_len
>=
3
:
break
best_entity
=
batch
[
"entity_id"
][
batch
[
"asym_id"
]
==
best_gt_asym
][
0
]
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
best_entity
])
return
best_gt_asym
,
best_pred_asym
def
get_least_asym_entity_or_longest_length
(
batch
):
"""
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
"""
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_asym_count
=
{}
entity_length
=
{}
for
entity_id
in
unique_entity_ids
:
asym_ids
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
entity_id
])
entity_asym_count
[
int
(
entity_id
)]
=
len
(
asym_ids
)
# Calculate entity length
entity_mask
=
(
batch
[
"entity_id"
]
==
entity_id
)
entity_length
[
int
(
entity_id
)]
=
entity_mask
.
sum
().
item
()
min_asym_count
=
min
(
entity_asym_count
.
values
())
least_asym_entities
=
[
entity
for
entity
,
count
in
entity_asym_count
.
items
()
if
count
==
min_asym_count
]
# If multiple entities have the least asym_id count, return those with the shortest length
if
len
(
least_asym_entities
)
>
1
:
max_length
=
max
([
entity_length
[
entity
]
for
entity
in
least_asym_entities
])
least_asym_entities
=
[
entity
for
entity
in
least_asym_entities
if
entity_length
[
entity
]
==
max_length
]
# If still multiple entities, return a random one
if
len
(
least_asym_entities
)
>
1
:
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
print
(
f
"line 249 least_asym_entities is
{
least_asym_entities
}
and entity_length is
{
entity_length
}
"
)
assert
len
(
least_asym_entities
)
==
1
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
least_asym_entities
[
0
]])
return
least_asym_entities
[
0
],
best_pred_asym
def
greedy_align
(
...
...
@@ -251,27 +257,21 @@ def greedy_align(
cur_asym_list
=
entity_2_asym_list
[
int
(
cur_entity_ids
)]
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
cur_pred_pos
=
pred_ca_pos
[
asym_mask
[:,
0
],:]
# only need the first 1 column of asym_mask
print
(
f
"line 266 pred_ca_pos shape:
{
pred_ca_pos
.
shape
}
cur_pred_pos shape:
{
cur_pred_pos
.
shape
}
and pred_ca_pos is
{
pred_ca_pos
.
shape
}
"
)
cur_pred_mask
=
pred_ca_mask
[
asym_mask
[:,
0
]]
# only need the first column of asym_mask
cur_pred_pos
=
pred_ca_pos
[
asym_mask
]
# only need the first 1 column of asym_mask
cur_pred_mask
=
pred_ca_mask
[
asym_mask
]
for
next_asym_id
in
cur_asym_list
:
if
next_asym_id
==
0
:
continue
j
=
int
(
next_asym_id
-
1
)
if
not
used
[
j
]:
# possible candidate
print
(
f
"line 265 curr_residue_index is
{
cur_residue_index
}
and j is
{
j
}
"
)
print
(
f
"true_ca_poses shape:
{
true_ca_poses
[
j
].
shape
}
"
)
cropped_pos
=
true_ca_poses
[
j
]
mask
=
true_ca_masks
[
j
][
cur_residue_index
[:,
0
]]
print
(
f
"line 278 cur_pred_mask shape:
{
cur_pred_mask
.
shape
}
\n
mask shape:
{
mask
.
shape
}
"
)
print
(
f
"cropped_pos shape
{
cropped_pos
.
shape
}
and cur_pred_pos shape
{
cur_pred_pos
.
shape
}
"
)
mask
=
true_ca_masks
[
j
][
cur_residue_index
]
rmsd
=
compute_rmsd
(
cropped_pos
,
cur_pred_pos
,
(
cur_pred_mask
.
to
(
'cpu'
)
*
mask
.
to
(
'cpu'
)).
bool
()
)
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
best_idx
=
j
print
(
f
"rmds is now
{
rmsd
}
and best_idx is
{
best_idx
}
"
)
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
align
.
append
((
i
,
best_idx
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment