Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
61cbc7d3
Unverified
Commit
61cbc7d3
authored
Sep 06, 2023
by
Christina Floristean
Committed by
GitHub
Sep 06, 2023
Browse files
Merge pull request #343 from dingquanyu/permutation
Update multi-chain permutation
parents
8e522aca
e097da95
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
241 additions
and
14 deletions
+241
-14
openfold/utils/loss.py
openfold/utils/loss.py
+72
-14
tests/test_permutation.py
tests/test_permutation.py
+169
-0
No files found.
openfold/utils/loss.py
View file @
61cbc7d3
...
@@ -1770,8 +1770,8 @@ def get_optimal_transform(
...
@@ -1770,8 +1770,8 @@ def get_optimal_transform(
else
:
else
:
src_atoms
=
src_atoms
[
mask
,
:]
src_atoms
=
src_atoms
[
mask
,
:]
tgt_atoms
=
tgt_atoms
[
mask
,
:]
tgt_atoms
=
tgt_atoms
[
mask
,
:]
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
)
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
,
dtype
=
src_atoms
.
dtype
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
,
dtype
=
src_atoms
.
dtype
)
r
=
kabsch_rotation
(
src_atoms
,
tgt_atoms
)
r
=
kabsch_rotation
(
src_atoms
,
tgt_atoms
)
del
src_atoms
,
tgt_atoms
,
del
src_atoms
,
tgt_atoms
,
gc
.
collect
()
gc
.
collect
()
...
@@ -1792,6 +1792,12 @@ def get_least_asym_entity_or_longest_length(batch):
...
@@ -1792,6 +1792,12 @@ def get_least_asym_entity_or_longest_length(batch):
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
then choose one of the corresponding subunits as anchor
"""
"""
REQUIRED_FEATURES
=
[
'entity_id'
,
'asym_id'
]
seq_length
=
batch
[
'seq_length'
].
item
()
# remove padding part before selecting candidate
remove_padding
=
lambda
t
:
torch
.
index_select
(
t
,
dim
=
1
,
index
=
torch
.
arange
(
seq_length
,
device
=
t
.
device
))
batch
=
{
k
:
tensor_tree_map
(
remove_padding
,
batch
[
k
])
for
k
in
REQUIRED_FEATURES
}
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_asym_count
=
{}
entity_asym_count
=
{}
entity_length
=
{}
entity_length
=
{}
...
@@ -1870,7 +1876,15 @@ def greedy_align(
...
@@ -1870,7 +1876,15 @@ def greedy_align(
return
align
return
align
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
):
def pad_features(feature_tensor, nres_pad, pad_dim):
    """Pad input feature tensor.

    Appends ``nres_pad`` zero entries to ``feature_tensor`` along dimension
    ``pad_dim`` and returns the concatenated result. The zero block is built
    with ``new_zeros`` so it inherits the input tensor's dtype and device.
    """
    # Shape of the zero block: identical to the input except along pad_dim.
    zero_block_shape = list(feature_tensor.shape)
    zero_block_shape[pad_dim] = nres_pad
    zero_block = feature_tensor.new_zeros(
        zero_block_shape, device=feature_tensor.device
    )
    padded = torch.cat((feature_tensor, zero_block), dim=pad_dim)
    return padded
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
):
"""
"""
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
labels: list of original ground truth feats
labels: list of original ground truth feats
...
@@ -1897,6 +1911,10 @@ def merge_labels(per_asym_residue_index, labels, align):
...
@@ -1897,6 +1911,10 @@ def merge_labels(per_asym_residue_index, labels, align):
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
if
len
(
cur_out
)
>
0
:
if
len
(
cur_out
)
>
0
:
new_v
=
torch
.
concat
(
cur_out
,
dim
=
dimension_to_merge
)
new_v
=
torch
.
concat
(
cur_out
,
dim
=
dimension_to_merge
)
# below check whether padding is needed
if
new_v
.
shape
[
dimension_to_merge
]
!=
original_nres
:
nres_pad
=
original_nres
-
new_v
.
shape
[
dimension_to_merge
]
new_v
=
pad_features
(
new_v
,
nres_pad
,
pad_dim
=
dimension_to_merge
)
outs
[
k
]
=
new_v
outs
[
k
]
=
new_v
return
outs
return
outs
...
@@ -2019,9 +2037,35 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2019,9 +2037,35 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
def __init__(self, config):
    """Initialise the multimer loss.

    Delegates all standard loss setup to ``AlphaFoldLoss.__init__`` and
    additionally keeps the full ``config`` on the instance for use by the
    multimer-specific (chain-permutation) methods.
    """
    super(AlphaFoldMultimerLoss, self).__init__(config)
    self.config = config
@staticmethod
def determine_split_dim(batch) -> dict:
    """A method to determine which dimension to split in split_ground_truth_labels.

    Reads the padded residue count off the last dimension of ``batch['aatype']``
    and maps every feature whose shape contains that size to the first axis of
    that size. Features without such an axis are omitted from the result.

    NOTE(review): if some other axis (e.g. the batch axis) happens to equal the
    padded residue count, the first match wins — confirm upstream shapes.
    """
    n_res_padded = batch['aatype'].shape[-1]
    split_dims = {}
    for feature_name, tensor in batch.items():
        dims = list(tensor.shape)
        if n_res_padded in dims:
            split_dims[feature_name] = dims.index(n_res_padded)
    return split_dims
@staticmethod
def split_ground_truth_labels(batch, REQUIRED_FEATURES, dim_dict):
    """
    Splits ground truth features according to chains.

    Returns a list of feature dictionaries with only necessary ground truth
    features required to finish multi-chain permutation — one dict per chain.

    batch: feature dict; must contain "asym_id" plus every key in
        ``REQUIRED_FEATURES``.
    REQUIRED_FEATURES: keys of ``batch`` to carry over into the per-chain dicts.
    dim_dict: maps each feature key to the tensor dimension that indexes
        residues (as produced by ``determine_split_dim``).
    """
    # Per-chain chunk sizes: how many residues carry each asym_id.
    unique_asym_ids, asym_id_counts = torch.unique(batch["asym_id"], sorted=False, return_counts=True)
    unique_asym_ids, asym_id_counts = unique_asym_ids.tolist(), asym_id_counts.tolist()
    # NOTE(review): torch.split below slices tensors in their stored order, so
    # this assumes chains appear contiguously and in the same order as the ids
    # returned by torch.unique — TODO confirm for all input pipelines.
    if 0 in unique_asym_ids:
        # asym_id 0 marks padding; move its id/count to the end so the padding
        # residues form the last split chunk instead of the first.
        pop_idx = unique_asym_ids.index(0)
        padding_asym_id = unique_asym_ids.pop(pop_idx)
        padding_asym_counts = asym_id_counts.pop(pop_idx)
        unique_asym_ids.append(padding_asym_id)
        asym_id_counts.append(padding_asym_counts)
    # For every required feature, split along its residue dimension into
    # per-chain chunks, then transpose the nesting with zip(*...) so that each
    # element of ``labels`` is one chain's {feature: tensor} dict.
    labels = list(map(dict, zip(*[
        [(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])]
        for k, value in batch.items() if k in REQUIRED_FEATURES
    ])))
    return labels
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
True
):
"""
"""
A class method that first permutate chains in ground truth first
A class method that first permutate chains in ground truth first
before calculating the loss.
before calculating the loss.
...
@@ -2029,6 +2073,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2029,6 +2073,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
"""
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
assert
isinstance
(
labels
,
list
)
ca_idx
=
rc
.
atom_order
[
"CA"
]
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
...
@@ -2041,7 +2087,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2041,7 +2087,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks
=
[
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
]
# list([nres,])
unique_asym_ids
=
torch
.
unique
(
batch
[
"asym_id"
])
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
per_asym_residue_index
=
{}
per_asym_residue_index
=
{}
for
cur_asym_id
in
unique_asym_ids
:
for
cur_asym_id
in
unique_asym_ids
:
...
@@ -2049,6 +2095,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2049,6 +2095,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
if
permutate_chains
:
if
permutate_chains
:
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
...
@@ -2074,7 +2121,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2074,7 +2121,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del
anchor_pred_mask
del
anchor_pred_mask
del
anchor_true_mask
del
anchor_true_mask
gc
.
collect
()
gc
.
collect
()
aligned_true_ca_poses
=
[
ca
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
true_ca_poses
del
true_ca_poses
gc
.
collect
()
gc
.
collect
()
align
=
greedy_align
(
align
=
greedy_align
(
...
@@ -2099,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2099,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return
align
,
per_asym_residue_index
return
align
,
per_asym_residue_index
def
forward
(
self
,
out
,
features
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
def
forward
(
self
,
out
,
features
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
"""
"""
Overwrite AlphaFoldLoss forward function so that
Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation
it first compute multi-chain permutation
...
@@ -2108,12 +2155,23 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2108,12 +2155,23 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward()
out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure
batch: a pair of input features and its corresponding ground truth structure
"""
"""
# permutate ground truth chains before calculating the loss
# first check if it is a monomer
# align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features, labels,
is_monomer
=
len
(
torch
.
unique
(
features
[
'asym_id'
]))
==
1
or
torch
.
unique
(
features
[
'asym_id'
]).
tolist
()
==
[
0
,
1
]
# permutate_chains=permutate_chains)
if
not
is_monomer
:
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
permutate_chains
=
True
# permutated_labels.pop('aatype')
# first determin which dimension in the tensor to split into individual ground truth labels
# features.update(permutated_labels)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
# Then permutate ground truth chains before calculating the loss
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
permutate_chains
=
permutate_chains
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
features
.
update
(
labels
)
if
(
not
_return_breakdown
):
if
(
not
_return_breakdown
):
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
...
@@ -2122,4 +2180,4 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2122,4 +2180,4 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
else
:
else
:
cum_loss
,
losses
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
cum_loss
,
losses
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
print
(
f
"cum_loss:
{
cum_loss
}
losses:
{
losses
}
"
)
print
(
f
"cum_loss:
{
cum_loss
}
losses:
{
losses
}
"
)
return
cum_loss
,
losses
return
cum_loss
,
losses
\ No newline at end of file
tests/test_permutation.py
0 → 100644
View file @
61cbc7d3
# Copyright 2021 AlQuraishi Laboratory
# Dingquan Yu @ EMBL-Hamburg Kosinski group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
unittest
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
openfold.utils.loss
import
get_least_asym_entity_or_longest_length
,
merge_labels
,
pad_features
from
openfold.utils.tensor_utils
import
tensor_tree_map
import
math
class TestPermutation(unittest.TestCase):
    """Tests for the multi-chain permutation utilities on a fake A2B3 pentamer."""

    def setUp(self):
        """
        Create fake input structure features and rotation matrices for a
        hetero-pentamer A2B3 (two copies of chain A, three copies of chain B).
        """
        # Fix: fall back to CPU so the tests also run on machines without CUDA.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        theta = math.pi / 4
        self.rotation_matrix_z = torch.tensor([
            [math.cos(theta), -math.sin(theta), 0],
            [math.sin(theta), math.cos(theta), 0],
            [0, 0, 1],
        ], device=self.device)
        self.rotation_matrix_x = torch.tensor([
            [1, 0, 0],
            [0, math.cos(theta), -math.sin(theta)],
            [0, math.sin(theta), math.cos(theta)],
        ], device=self.device)
        # Fix: the (2, 1) entry was 1 in the original, which makes the matrix
        # non-orthogonal (a shear, not a rotation about the y-axis); a proper
        # y-rotation has 0 there, so "rotated" chains are true rigid copies.
        self.rotation_matrix_y = torch.tensor([
            [math.cos(theta), 0, math.sin(theta)],
            [0, 1, 0],
            [-math.sin(theta), 0, math.cos(theta)],
        ], device=self.device)
        self.chain_a_num_res = 9
        self.chain_b_num_res = 13
        # Below create default fake ground truth structures for a hetero-pentamer
        # A2B3: residue numbering restarts at 0 for every chain.
        self.residue_index = (
            list(range(self.chain_a_num_res)) * 2
            + list(range(self.chain_b_num_res)) * 3
        )
        self.num_res = self.chain_a_num_res * 2 + self.chain_b_num_res * 3
        self.asym_id = torch.tensor(
            [[1] * self.chain_a_num_res + [2] * self.chain_a_num_res
             + [3] * self.chain_b_num_res + [4] * self.chain_b_num_res
             + [5] * self.chain_b_num_res],
            device=self.device,
        )
        self.sym_id = self.asym_id
        self.entity_id = torch.tensor(
            [[1] * (self.chain_a_num_res * 2) + [2] * (self.chain_b_num_res * 3)],
            device=self.device,
        )

    def _make_fake_positions(self):
        """Create fake atom positions; a2/b2/b3 are rigid transforms of a1/b1."""
        chain_a1_pos = torch.randint(
            15, (self.chain_a_num_res, 3 * 37),
            device=self.device, dtype=torch.float,
        ).reshape(1, self.chain_a_num_res, 37, 3)
        chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10
        chain_b1_pos = torch.randint(
            low=15, high=30, size=(self.chain_b_num_res, 3 * 37),
            device=self.device, dtype=torch.float,
        ).reshape(1, self.chain_b_num_res, 37, 3)
        chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10
        chain_b3_pos = torch.matmul(
            torch.matmul(chain_b1_pos, self.rotation_matrix_z),
            self.rotation_matrix_x,
        ) + 30
        return chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos

    def _true_atom_mask(self):
        """All-ones ground-truth atom mask, concatenated chain by chain."""
        ones = lambda n: torch.ones((1, n, 37), device=self.device)
        return torch.cat(
            (ones(self.chain_a_num_res), ones(self.chain_a_num_res),
             ones(self.chain_b_num_res), ones(self.chain_b_num_res),
             ones(self.chain_b_num_res)),
            dim=1,
        )

    def test_1_selecting_anchors(self):
        self.batch = {
            'asym_id': self.asym_id,
            'sym_id': self.sym_id,
            'entity_id': self.entity_id,
            'seq_length': torch.tensor([57]),
        }
        anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(self.batch)
        # Entity 1 (chains 1 and 2) has the fewest copies, so the anchor must be
        # one of its chains; chains 3-5 belong to the more frequent entity 2.
        self.assertIn(int(anchor_gt_asym), [1, 2])
        self.assertNotIn(int(anchor_gt_asym), [3, 4, 5])
        self.assertIn(int(anchor_pred_asym), [1, 2])
        self.assertNotIn(int(anchor_pred_asym), [3, 4, 5])

    def test_2_permutation_pentamer(self):
        batch = {
            'asym_id': self.asym_id,
            'sym_id': self.sym_id,
            'entity_id': self.entity_id,
            'seq_length': torch.tensor([57]),
            'aatype': torch.randint(21, size=(1, 57)),
        }
        batch['asym_id'] = batch['asym_id'].reshape(1, self.num_res)
        batch["residue_index"] = torch.tensor([self.residue_index], device=self.device)
        # Create fake ground truth atom positions.
        (chain_a1_pos, chain_a2_pos,
         chain_b1_pos, chain_b2_pos, chain_b3_pos) = self._make_fake_positions()
        # Below permutate predicted chain positions relative to the ground truth.
        pred_atom_position = torch.cat(
            (chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos),
            dim=1,
        )
        pred_atom_mask = torch.ones((1, self.num_res, 37), device=self.device)
        out = {
            'final_atom_positions': pred_atom_position,
            'final_atom_mask': pred_atom_mask,
        }
        batch['all_atom_positions'] = torch.cat(
            (chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos),
            dim=1,
        )
        batch['all_atom_mask'] = self._true_atom_mask()
        dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
        aligns, _ = AlphaFoldMultimerLoss.multi_chain_perm_align(
            out, batch, dim_dict, permutate_chains=True
        )
        # The two A chains are identical up to a rigid transform, so either
        # pairing of them is acceptable; the B chains must map 2->3, 3->4, 4->2.
        possible_outcome = [
            [(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)],
            [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)],
        ]
        wrong_outcome = [
            [(0, 1), (1, 0), (2, 4), (3, 2), (4, 3)],
            [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)],
        ]
        self.assertIn(aligns, possible_outcome)
        self.assertNotIn(aligns, wrong_outcome)

    def test_3_merge_labels(self):
        nres_pad = 325 - 57  # suppose the cropping size is 325
        batch = {
            'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1),
            'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1),
            'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1),
            'aatype': torch.randint(21, size=(1, 325)),
            'seq_length': torch.tensor([57]),
        }
        batch['asym_id'] = batch['asym_id'].reshape(1, 325)
        batch["residue_index"] = pad_features(
            torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1
        )
        # Create fake ground truth atom positions.
        (chain_a1_pos, chain_a2_pos,
         chain_b1_pos, chain_b2_pos, chain_b3_pos) = self._make_fake_positions()
        # Below permutate predicted chain positions, then pad to the crop size.
        pred_atom_position = torch.cat(
            (chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos),
            dim=1,
        )
        pred_atom_mask = torch.ones((1, self.num_res, 37), device=self.device)
        pred_atom_position = pad_features(pred_atom_position, nres_pad, pad_dim=1)
        pred_atom_mask = pad_features(pred_atom_mask, nres_pad, pad_dim=1)
        out = {
            'final_atom_positions': pred_atom_position,
            'final_atom_mask': pred_atom_mask,
        }
        true_atom_position = torch.cat(
            (chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos),
            dim=1,
        )
        batch['all_atom_positions'] = pad_features(true_atom_position, nres_pad, pad_dim=1)
        batch['all_atom_mask'] = pad_features(
            self._true_atom_mask(), nres_pad=nres_pad, pad_dim=1
        )
        batch = tensor_tree_map(lambda t: t.to(self.device), batch)
        dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
        aligns, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(
            out, batch, dim_dict, permutate_chains=True
        )
        labels = AlphaFoldMultimerLoss.split_ground_truth_labels(
            batch,
            dim_dict=dim_dict,
            REQUIRED_FEATURES=[i for i in batch.keys() if i in dim_dict],
        )
        # Re-ordering the ground truth by the computed alignment must reproduce
        # the prediction's chain order and keep per-chain residue numbering.
        labels = merge_labels(
            per_asym_residue_index, labels, aligns,
            original_nres=batch['aatype'].shape[-1],
        )
        self.assertTrue(torch.equal(labels['residue_index'], batch['residue_index']))
        expected_permutated_gt_pos = torch.cat(
            (chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos),
            dim=1,
        )
        expected_permutated_gt_pos = pad_features(
            expected_permutated_gt_pos, nres_pad, pad_dim=1
        )
        self.assertTrue(
            torch.equal(labels['all_atom_positions'], expected_permutated_gt_pos)
        )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment