Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
a0f8a057
Commit
a0f8a057
authored
Jun 22, 2023
by
Geoffrey Yu
Browse files
now used the new way of selecting anchors
parent
54755901
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
57 additions
and
57 deletions
+57
-57
tests/test_permutation.py
tests/test_permutation.py
+5
-5
tests/unifold_permutation.py
tests/unifold_permutation.py
+52
-52
No files found.
tests/test_permutation.py
View file @
a0f8a057
...
@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
...
@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
In the test case, use PDB ID 1e4k as the label
In the test case, use PDB ID 1e4k as the label
"""
"""
self
.
test_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data"
)
self
.
test_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data"
)
self
.
label_ids
=
[
'label_1'
,
'label_2'
]
self
.
label_ids
=
[
'label_1'
,
'label_2'
,
'label_2'
]
def
test_dry_run
(
self
):
def
test_dry_run
(
self
):
n_seq
=
consts
.
n_seq
n_seq
=
consts
.
n_seq
n_templ
=
consts
.
n_templ
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
n_res
=
consts
.
n_res
+
13
n_extra_seq
=
consts
.
n_extra
n_extra_seq
=
consts
.
n_extra
c
=
model_config
(
consts
.
model
,
train
=
True
)
c
=
model_config
(
consts
.
model
,
train
=
True
)
...
@@ -83,12 +83,12 @@ class TestPermutation(unittest.TestCase):
...
@@ -83,12 +83,12 @@ class TestPermutation(unittest.TestCase):
# Modify asym_id, entity_id and sym_id so that it encodes
# Modify asym_id, entity_id and sym_id so that it encodes
# 2 chains
# 2 chains
# #
# #
asym_id
=
[
1
]
*
9
+
[
2
]
*
13
asym_id
=
[
1
]
*
9
+
[
2
]
*
13
+
[
3
]
*
13
batch
[
"asym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"asym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch
[
'entity_id'
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
'entity_id'
]
=
torch
.
tensor
(
[
1
]
*
9
+
[
2
]
*
26
,
dtype
=
torch
.
float64
)
batch
[
"sym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"sym_id"
]
=
torch
.
tensor
(
asym_id
,
dtype
=
torch
.
float64
)
batch
[
"num_sym"
]
=
torch
.
tensor
([
2
]
*
22
,
dtype
=
torch
.
int64
)
# currently there are just 2 chains
batch
[
"num_sym"
]
=
torch
.
tensor
([
1
]
*
9
+
[
2
]
*
26
,
dtype
=
torch
.
int64
)
# currently there are just 2 chains
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
batch
[
"extra_deletion_matrix"
]
=
torch
.
randint
(
0
,
2
,
size
=
(
n_extra_seq
,
n_res
))
add_recycling_dims
=
lambda
t
:
(
add_recycling_dims
=
lambda
t
:
(
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
t
.
unsqueeze
(
-
1
).
expand
(
*
t
.
shape
,
c
.
data
.
common
.
max_recycling_iters
)
...
...
tests/unifold_permutation.py
View file @
a0f8a057
...
@@ -3,6 +3,7 @@ from openfold.np import residue_constants as rc
...
@@ -3,6 +3,7 @@ from openfold.np import residue_constants as rc
import
logging
import
logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
import
sys
import
sys
import
random
def
kabsch_rotation
(
P
,
Q
):
def
kabsch_rotation
(
P
,
Q
):
"""
"""
...
@@ -119,10 +120,13 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
...
@@ -119,10 +120,13 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
for
cur_asym_id
in
unique_asym_ids
:
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
batch
[
"residue_index"
][
asym_mask
]
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
batch
[
"residue_index"
][
asym_mask
]
anchor_gt_asym
,
anchor_pred_asym
=
get_anchor_candidates
(
# anchor_gt_asym, anchor_pred_asym = get_anchor_candidates(
batch
,
per_asym_residue_index
,
true_ca_masks
# batch, per_asym_residue_index, true_ca_masks
)
# )
print
(
f
"anchor_gt_asym is
{
anchor_gt_asym
}
, anchor_pred_asym is
{
anchor_pred_asym
}
"
)
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym is
{
anchor_gt_asym
}
"
)
import
sys
sys
.
exit
()
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
best_rmsd
=
1e9
best_rmsd
=
1e9
...
@@ -134,7 +138,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
...
@@ -134,7 +138,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
print
(
f
"entity_2_asym_list is
{
entity_2_asym_list
}
"
)
for
cur_asym_id
in
anchor_pred_asym
:
for
cur_asym_id
in
anchor_pred_asym
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
anchor_residue_idx
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
anchor_residue_idx
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
...
@@ -148,9 +152,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
...
@@ -148,9 +152,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
anchor_pred_pos
,
anchor_pred_pos
,
(
anchor_true_mask
.
to
(
'cpu'
)
*
anchor_pred_mask
.
to
(
'cpu'
)).
bool
(),
(
anchor_true_mask
.
to
(
'cpu'
)
*
anchor_pred_mask
.
to
(
'cpu'
)).
bool
(),
)
)
print
(
f
"finished getting optimal transform"
)
import
sys
sys
.
exit
()
aligned_true_ca_poses
=
[
ca
.
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
+
x
.
to
(
'cpu'
)
for
ca
in
true_ca_poses
]
# apply transforms
aligned_true_ca_poses
=
[
ca
.
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
+
x
.
to
(
'cpu'
)
for
ca
in
true_ca_poses
]
# apply transforms
for
_
in
range
(
shuffle_times
):
for
_
in
range
(
shuffle_times
):
shuffle_idx
=
torch
.
randperm
(
shuffle_idx
=
torch
.
randperm
(
...
@@ -167,57 +169,61 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
...
@@ -167,57 +169,61 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
aligned_true_ca_poses
,
aligned_true_ca_poses
,
true_ca_masks
,
true_ca_masks
,
)
)
merged_labels
=
merge_labels
(
merged_labels
=
merge_labels
(
batch
,
batch
,
per_asym_residue_index
,
per_asym_residue_index
,
labels
,
labels
,
align
,
align
,
)
)
rmsd
=
kabsch_rmsd
(
rmsd
=
kabsch_rmsd
(
merged_labels
[
"all_atom_positions"
][...,
ca_idx
,
:]
@
r
+
x
,
merged_labels
[
"all_atom_positions"
][...,
ca_idx
,
:]
.
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
+
x
.
to
(
'cpu'
)
,
pred_ca_pos
,
pred_ca_pos
,
(
pred_ca_mask
*
merged_labels
[
"all_atom_mask"
][...,
ca_idx
]).
bool
(),
(
pred_ca_mask
.
to
(
'cpu'
)
*
merged_labels
[
"all_atom_mask"
][...,
ca_idx
]
.
to
(
'cpu'
)
).
bool
(),
)
)
if
rmsd
<
best_rmsd
:
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
best_rmsd
=
rmsd
best_labels
=
merged_labels
best_labels
=
merged_labels
print
(
f
"finished kabsh_rmsd"
)
return
best_labels
return
best_labels
def
get_anchor_candidates
(
batch
,
per_asym_residue_index
,
true_masks
):
def
get_least_asym_entity_or_longest_length
(
batch
):
def
find_by_num_sym
(
min_num_sym
):
"""
best_len
=
-
1
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
best_gt_asym
=
None
one of the A as anchor
asym_ids
=
batch
[
"asym_id"
][
batch
[
"num_sym"
]
==
min_num_sym
]
asym_ids
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"num_sym"
]
==
min_num_sym
])
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
for
cur_asym_id
in
asym_ids
:
then choose one of the corresponding subunits as anchor
assert
cur_asym_id
>
0
"""
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
j
=
int
(
cur_asym_id
-
1
)
entity_asym_count
=
{}
cur_true_mask
=
true_masks
[
j
][
cur_residue_index
]
entity_length
=
{}
cur_len
=
cur_true_mask
.
shape
[
0
]
if
cur_len
>
best_len
:
for
entity_id
in
unique_entity_ids
:
best_len
=
cur_len
asym_ids
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
entity_id
])
best_gt_asym
=
cur_asym_id
entity_asym_count
[
int
(
entity_id
)]
=
len
(
asym_ids
)
return
best_gt_asym
,
best_len
sorted_num_sym
=
batch
[
"num_sym"
][
batch
[
"num_sym"
]
>
0
].
sort
()[
0
]
# Calculate entity length
best_gt_asym
=
None
entity_mask
=
(
batch
[
"entity_id"
]
==
entity_id
)
best_len
=
-
1
entity_length
[
int
(
entity_id
)]
=
entity_mask
.
sum
().
item
()
for
cur_num_sym
in
sorted_num_sym
:
if
cur_num_sym
<=
0
:
min_asym_count
=
min
(
entity_asym_count
.
values
())
continue
least_asym_entities
=
[
entity
for
entity
,
count
in
entity_asym_count
.
items
()
if
count
==
min_asym_count
]
cur_gt_sym
,
cur_len
=
find_by_num_sym
(
cur_num_sym
)
if
cur_len
>
best_len
:
# If multiple entities have the least asym_id count, return those with the shortest length
best_len
=
cur_len
if
len
(
least_asym_entities
)
>
1
:
best_gt_asym
=
cur_gt_sym
max_length
=
max
([
entity_length
[
entity
]
for
entity
in
least_asym_entities
])
if
best_len
>=
3
:
least_asym_entities
=
[
entity
for
entity
in
least_asym_entities
if
entity_length
[
entity
]
==
max_length
]
break
best_entity
=
batch
[
"entity_id"
][
batch
[
"asym_id"
]
==
best_gt_asym
][
0
]
# If still multiple entities, return a random one
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
best_entity
])
if
len
(
least_asym_entities
)
>
1
:
return
best_gt_asym
,
best_pred_asym
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
print
(
f
"line 249 least_asym_entities is
{
least_asym_entities
}
and entity_length is
{
entity_length
}
"
)
assert
len
(
least_asym_entities
)
==
1
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
least_asym_entities
[
0
]])
return
least_asym_entities
[
0
],
best_pred_asym
def
greedy_align
(
def
greedy_align
(
...
@@ -251,27 +257,21 @@ def greedy_align(
...
@@ -251,27 +257,21 @@ def greedy_align(
cur_asym_list
=
entity_2_asym_list
[
int
(
cur_entity_ids
)]
cur_asym_list
=
entity_2_asym_list
[
int
(
cur_entity_ids
)]
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
cur_pred_pos
=
pred_ca_pos
[
asym_mask
[:,
0
],:]
# only need the first 1 column of asym_mask
cur_pred_pos
=
pred_ca_pos
[
asym_mask
]
# only need the first 1 column of asym_mask
print
(
f
"line 266 pred_ca_pos shape:
{
pred_ca_pos
.
shape
}
cur_pred_pos shape:
{
cur_pred_pos
.
shape
}
and pred_ca_pos is
{
pred_ca_pos
.
shape
}
"
)
cur_pred_mask
=
pred_ca_mask
[
asym_mask
]
cur_pred_mask
=
pred_ca_mask
[
asym_mask
[:,
0
]]
# only need the first column of asym_mask
for
next_asym_id
in
cur_asym_list
:
for
next_asym_id
in
cur_asym_list
:
if
next_asym_id
==
0
:
if
next_asym_id
==
0
:
continue
continue
j
=
int
(
next_asym_id
-
1
)
j
=
int
(
next_asym_id
-
1
)
if
not
used
[
j
]:
# possible candidate
if
not
used
[
j
]:
# possible candidate
print
(
f
"line 265 curr_residue_index is
{
cur_residue_index
}
and j is
{
j
}
"
)
print
(
f
"true_ca_poses shape:
{
true_ca_poses
[
j
].
shape
}
"
)
cropped_pos
=
true_ca_poses
[
j
]
cropped_pos
=
true_ca_poses
[
j
]
mask
=
true_ca_masks
[
j
][
cur_residue_index
[:,
0
]]
mask
=
true_ca_masks
[
j
][
cur_residue_index
]
print
(
f
"line 278 cur_pred_mask shape:
{
cur_pred_mask
.
shape
}
\n
mask shape:
{
mask
.
shape
}
"
)
print
(
f
"cropped_pos shape
{
cropped_pos
.
shape
}
and cur_pred_pos shape
{
cur_pred_pos
.
shape
}
"
)
rmsd
=
compute_rmsd
(
rmsd
=
compute_rmsd
(
cropped_pos
,
cur_pred_pos
,
(
cur_pred_mask
.
to
(
'cpu'
)
*
mask
.
to
(
'cpu'
)).
bool
()
cropped_pos
,
cur_pred_pos
,
(
cur_pred_mask
.
to
(
'cpu'
)
*
mask
.
to
(
'cpu'
)).
bool
()
)
)
if
rmsd
<
best_rmsd
:
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
best_rmsd
=
rmsd
best_idx
=
j
best_idx
=
j
print
(
f
"rmds is now
{
rmsd
}
and best_idx is
{
best_idx
}
"
)
assert
best_idx
is
not
None
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
used
[
best_idx
]
=
True
align
.
append
((
i
,
best_idx
))
align
.
append
((
i
,
best_idx
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment