Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
260db67f
"...network/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "c0e008b4fe459a60f2bb3e8784b3cf5c16b71afd"
Commit
260db67f
authored
May 12, 2022
by
Gustaf Ahdritz
Browse files
Finish multimer inference
parent
6e68d6b0
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
205 additions
and
132 deletions
+205
-132
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+1
-0
openfold/model/embedders.py
openfold/model/embedders.py
+1
-1
openfold/model/structure_module.py
openfold/model/structure_module.py
+3
-6
openfold/np/protein.py
openfold/np/protein.py
+1
-1
openfold/utils/feats.py
openfold/utils/feats.py
+9
-9
openfold/utils/geometry/quat_rigid.py
openfold/utils/geometry/quat_rigid.py
+0
-2
openfold/utils/geometry/rigid_matrix_vector.py
openfold/utils/geometry/rigid_matrix_vector.py
+146
-98
openfold/utils/geometry/rotation_matrix.py
openfold/utils/geometry/rotation_matrix.py
+27
-9
openfold/utils/geometry/vector.py
openfold/utils/geometry/vector.py
+10
-3
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+2
-2
run_pretrained_openfold.py
run_pretrained_openfold.py
+5
-1
No files found.
openfold/data/input_pipeline_multimer.py
View file @
260db67f
...
@@ -29,6 +29,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
...
@@ -29,6 +29,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms
.
cast_to_64bit_ints
,
data_transforms
.
cast_to_64bit_ints
,
data_transforms_multimer
.
make_msa_profile
,
data_transforms_multimer
.
make_msa_profile
,
data_transforms_multimer
.
create_target_feat
,
data_transforms_multimer
.
create_target_feat
,
data_transforms
.
make_atom14_masks
,
]
]
if
(
common_cfg
.
use_templates
):
if
(
common_cfg
.
use_templates
):
...
...
openfold/model/embedders.py
View file @
260db67f
...
@@ -868,7 +868,7 @@ class TemplateEmbedderMultimer(nn.Module):
...
@@ -868,7 +868,7 @@ class TemplateEmbedderMultimer(nn.Module):
raw_atom_pos
=
single_template_feats
[
"template_all_atom_positions"
]
raw_atom_pos
=
single_template_feats
[
"template_all_atom_positions"
]
atom_pos
=
geometry
.
Vec3Array
.
from_
tensor
(
raw_atom_pos
)
atom_pos
=
geometry
.
Vec3Array
.
from_
array
(
raw_atom_pos
)
rigid
,
backbone_mask
=
all_atom_multimer
.
make_backbone_affine
(
rigid
,
backbone_mask
=
all_atom_multimer
.
make_backbone_affine
(
atom_pos
,
atom_pos
,
single_template_feats
[
"template_all_atom_mask"
],
single_template_feats
[
"template_all_atom_mask"
],
...
...
openfold/model/structure_module.py
View file @
260db67f
...
@@ -363,9 +363,6 @@ class InvariantPointAttention(nn.Module):
...
@@ -363,9 +363,6 @@ class InvariantPointAttention(nn.Module):
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
for
c
in
q_pts
:
print
(
type
(
c
))
# [*, N_res, N_res, H, P_q, 3]
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
[...,
None
,
:,
:]
-
k_pts
[...,
None
,
:,
:,
:]
pt_att
=
q_pts
[...,
None
,
:,
:]
-
k_pts
[...,
None
,
:,
:,
:]
...
@@ -669,7 +666,7 @@ class StructureModule(nn.Module):
...
@@ -669,7 +666,7 @@ class StructureModule(nn.Module):
self
,
r
,
f
# [*, N, 8] # [*, N]
self
,
r
,
f
# [*, N, 8] # [*, N]
):
):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
r
.
get_rots
().
dtype
,
r
.
get_rots
()
.
device
)
self
.
_init_residue_constants
(
r
.
dtype
,
r
.
device
)
return
frames_and_literature_positions_to_atom14_pos
(
return
frames_and_literature_positions_to_atom14_pos
(
r
,
r
,
f
,
f
,
...
@@ -818,11 +815,11 @@ class StructureModule(nn.Module):
...
@@ -818,11 +815,11 @@ class StructureModule(nn.Module):
)
)
preds
=
{
preds
=
{
"frames"
:
rigids
.
scale_translation
(
self
.
trans_scale_factor
).
to_tensor
7
(),
"frames"
:
rigids
.
scale_translation
(
self
.
trans_scale_factor
).
to_tensor
(),
"sidechain_frames"
:
all_frames_to_global
.
to_tensor_4x4
(),
"sidechain_frames"
:
all_frames_to_global
.
to_tensor_4x4
(),
"unnormalized_angles"
:
unnormalized_angles
,
"unnormalized_angles"
:
unnormalized_angles
,
"angles"
:
angles
,
"angles"
:
angles
,
"positions"
:
pred_xyz
,
"positions"
:
pred_xyz
.
to_tensor
()
,
}
}
outputs
.
append
(
preds
)
outputs
.
append
(
preds
)
...
...
openfold/np/protein.py
View file @
260db67f
...
@@ -205,7 +205,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
...
@@ -205,7 +205,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
)
def
_chain_end
(
atom_index
,
end_resname
,
chain_name
,
residue_indx
)
->
str
:
def
_chain_end
(
atom_index
,
end_resname
,
chain_name
,
residue_ind
e
x
)
->
str
:
chain_end
=
'TER'
chain_end
=
'TER'
return
(
return
(
f
'
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
end_resname
:
>
3
}
'
f
'
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
end_resname
:
>
3
}
'
...
...
openfold/utils/feats.py
View file @
260db67f
...
@@ -22,6 +22,7 @@ from typing import Dict
...
@@ -22,6 +22,7 @@ from typing import Dict
from
openfold.np
import
protein
from
openfold.np
import
protein
import
openfold.np.residue_constants
as
rc
import
openfold.np.residue_constants
as
rc
from
openfold.utils.geometry
import
rigid_matrix_vector
,
rotation_matrix
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
batched_gather
,
batched_gather
,
...
@@ -213,15 +214,14 @@ def torsion_angles_to_frames(
...
@@ -213,15 +214,14 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# This follows the original code rather than the supplement, which uses
# different indices.
# different indices.
all_rots
=
alpha
.
new_zeros
(
default_r
.
get_rots
().
get_rot_mats
().
shape
)
all_rots
=
alpha
.
new_zeros
(
default_r
.
shape
+
(
3
,
3
)
)
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
=
Rigid
(
Rotation
(
rot_mats
=
all_rots
),
None
)
all_rots
=
rotation_matrix
.
Rot3Array
.
from_array
(
all_rots
)
all_frames
=
default_r
.
compose_rotation
(
all_rots
)
all_frames
=
default_r
.
compose
(
all_rots
)
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
...
@@ -232,7 +232,7 @@ def torsion_angles_to_frames(
...
@@ -232,7 +232,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
all_frames_to_bb
=
Rigid
.
cat
(
all_frames_to_bb
=
rigid_matrix_vector
.
Rigid3Array
.
cat
(
[
[
all_frames
[...,
:
5
],
all_frames
[...,
:
5
],
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
...
@@ -241,14 +241,14 @@ def torsion_angles_to_frames(
...
@@ -241,14 +241,14 @@ def torsion_angles_to_frames(
],
],
dim
=-
1
,
dim
=-
1
,
)
)
all_frames_to_global
=
r
[...,
None
].
compose
(
all_frames_to_bb
)
all_frames_to_global
=
r
[...,
None
].
compose
(
all_frames_to_bb
)
return
all_frames_to_global
return
all_frames_to_global
def
frames_and_literature_positions_to_atom14_pos
(
def
frames_and_literature_positions_to_atom14_pos
(
r
:
Rigid
,
r
:
rigid_matrix_vector
.
Rigid3Array
,
aatype
:
torch
.
Tensor
,
aatype
:
torch
.
Tensor
,
default_frames
,
default_frames
,
group_idx
,
group_idx
,
...
@@ -275,8 +275,8 @@ def frames_and_literature_positions_to_atom14_pos(
...
@@ -275,8 +275,8 @@ def frames_and_literature_positions_to_atom14_pos(
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
)
)
# [*, N, 14
, 1
]
# [*, N, 14]
atom_mask
=
atom_mask
[
aatype
,
...]
.
unsqueeze
(
-
1
)
atom_mask
=
atom_mask
[
aatype
,
...]
# [*, N, 14, 3]
# [*, N, 14, 3]
lit_positions
=
lit_positions
[
aatype
,
...]
lit_positions
=
lit_positions
[
aatype
,
...]
...
...
openfold/utils/geometry/quat_rigid.py
View file @
260db67f
...
@@ -22,8 +22,6 @@ class QuatRigid(nn.Module):
...
@@ -22,8 +22,6 @@ class QuatRigid(nn.Module):
# NOTE: During training, this needs to be run in higher precision
# NOTE: During training, this needs to be run in higher precision
rigid_flat
=
self
.
linear
(
activations
.
to
(
torch
.
float32
))
rigid_flat
=
self
.
linear
(
activations
.
to
(
torch
.
float32
))
print
(
rigid_flat
.
shape
)
rigid_flat
=
torch
.
unbind
(
rigid_flat
,
dim
=-
1
)
rigid_flat
=
torch
.
unbind
(
rigid_flat
,
dim
=-
1
)
if
(
self
.
full_quat
):
if
(
self
.
full_quat
):
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
...
...
openfold/utils/geometry/rigid_matrix_vector.py
View file @
260db67f
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -29,100 +29,148 @@ Float = Union[float, torch.Tensor]
...
@@ -29,100 +29,148 @@ Float = Union[float, torch.Tensor]
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
Rigid3Array
:
class
Rigid3Array
:
"""Rigid Transformation, i.e. element of special euclidean group."""
"""Rigid Transformation, i.e. element of special euclidean group."""
rotation
:
rotation_matrix
.
Rot3Array
rotation
:
rotation_matrix
.
Rot3Array
translation
:
vector
.
Vec3Array
translation
:
vector
.
Vec3Array
def
__matmul__
(
self
,
other
:
Rigid3Array
)
->
Rigid3Array
:
def
__matmul__
(
self
,
other
:
Rigid3Array
)
->
Rigid3Array
:
new_rotation
=
self
.
rotation
@
other
.
rotation
# __matmul__
new_rotation
=
self
.
rotation
@
other
.
rotation
# __matmul__
new_translation
=
self
.
apply_to_point
(
other
.
translation
)
new_translation
=
self
.
apply_to_point
(
other
.
translation
)
return
Rigid3Array
(
new_rotation
,
new_translation
)
return
Rigid3Array
(
new_rotation
,
new_translation
)
def
__getitem__
(
self
,
index
)
->
Rigid3Array
:
def
__getitem__
(
self
,
index
)
->
Rigid3Array
:
return
Rigid3Array
(
return
Rigid3Array
(
self
.
rotation
[
index
],
self
.
rotation
[
index
],
self
.
translation
[
index
],
self
.
translation
[
index
],
)
)
def
__mul__
(
self
,
other
:
torch
.
Tensor
)
->
Rigid3Array
:
def
__mul__
(
self
,
other
:
torch
.
Tensor
)
->
Rigid3Array
:
return
Rigid3Array
(
return
Rigid3Array
(
self
.
rotation
*
other
,
self
.
rotation
*
other
,
self
.
translation
*
other
,
self
.
translation
*
other
,
)
)
def
map_tensor_fn
(
self
,
fn
)
->
Rigid3Array
:
def
map_tensor_fn
(
self
,
fn
)
->
Rigid3Array
:
return
Rigid3Array
(
return
Rigid3Array
(
self
.
rotation
.
map_tensor_fn
(
fn
),
self
.
rotation
.
map_tensor_fn
(
fn
),
self
.
translation
.
map_tensor_fn
(
fn
),
self
.
translation
.
map_tensor_fn
(
fn
),
)
)
def
inverse
(
self
)
->
Rigid3Array
:
def
inverse
(
self
)
->
Rigid3Array
:
"""Return Rigid3Array corresponding to inverse transform."""
"""Return Rigid3Array corresponding to inverse transform."""
inv_rotation
=
self
.
rotation
.
inverse
()
inv_rotation
=
self
.
rotation
.
inverse
()
inv_translation
=
inv_rotation
.
apply_to_point
(
-
self
.
translation
)
inv_translation
=
inv_rotation
.
apply_to_point
(
-
self
.
translation
)
return
Rigid3Array
(
inv_rotation
,
inv_translation
)
return
Rigid3Array
(
inv_rotation
,
inv_translation
)
def
apply_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
def
apply_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply Rigid3Array transform to point."""
"""Apply Rigid3Array transform to point."""
return
self
.
rotation
.
apply_to_point
(
point
)
+
self
.
translation
return
self
.
rotation
.
apply_to_point
(
point
)
+
self
.
translation
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
def
apply
(
self
,
point
:
torch
.
Tensor
)
->
vector
.
Vec3Array
:
"""Apply inverse Rigid3Array transform to point."""
return
self
.
apply_to_point
(
vector
.
Vec3Array
.
from_array
(
point
))
new_point
=
point
-
self
.
translation
return
self
.
rotation
.
apply_inverse_to_point
(
new_point
)
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply inverse Rigid3Array transform to point."""
def
compose_rotation
(
self
,
other_rotation
):
new_point
=
point
-
self
.
translation
rot
=
self
.
rotation
@
other_rotation
return
self
.
rotation
.
apply_inverse_to_point
(
new_point
)
return
Rigid3Array
(
rot
,
trans
.
clone
())
def
compose_rotation
(
self
,
other_rotation
):
@
classmethod
rot
=
self
.
rotation
@
other_rotation
def
identity
(
cls
,
shape
,
device
)
->
Rigid3Array
:
return
Rigid3Array
(
rot
,
self
.
translation
.
clone
())
"""Return identity Rigid3Array of given shape."""
return
cls
(
def
compose
(
self
,
other_rigid
):
rotation_matrix
.
Rot3Array
.
identity
(
shape
,
device
),
return
self
@
other_rigid
vector
.
Vec3Array
.
zeros
(
shape
,
device
)
)
def
unsqueeze
(
self
,
dim
:
int
):
return
Rigid3Array
(
@
classmethod
self
.
rotation
.
unsqueeze
(
dim
),
def
cat
(
cls
,
rigids
:
List
[
Rigid3Array
],
dim
:
int
)
->
Rigid3Array
:
self
.
translation
.
unsqueeze
(
dim
),
return
cls
(
)
Rot3Array
.
cat
([
r
.
rotation
for
r
in
rigids
],
dim
=
dim
),
Vec3Array
.
cat
([
r
.
translation
for
r
in
rigids
],
dim
=
dim
),
@
property
)
def
shape
(
self
)
->
torch
.
Size
:
return
self
.
rotation
.
xx
.
shape
def
scale_translation
(
self
,
factor
:
Float
)
->
Rigid3Array
:
"""Scale translation in Rigid3Array by 'factor'."""
@
property
return
Rigid3Array
(
self
.
rotation
,
self
.
translation
*
factor
)
def
dtype
(
self
)
->
torch
.
dtype
:
return
self
.
rotation
.
xx
.
dtype
def
to_array
(
self
):
rot_array
=
self
.
rotation
.
to_array
()
@
property
vec_array
=
self
.
translation
.
to_array
()
def
device
(
self
)
->
torch
.
device
:
return
torch
.
cat
([
rot_array
,
vec_array
[...,
None
]],
dim
=-
1
)
return
self
.
rotation
.
xx
.
device
def
reshape
(
self
,
new_shape
)
->
Rigid3Array
:
@
classmethod
rots
=
self
.
rotation
.
reshape
(
new_shape
)
def
identity
(
cls
,
shape
,
device
)
->
Rigid3Array
:
trans
=
self
.
translation
.
reshape
(
new_shape
)
"""Return identity Rigid3Array of given shape."""
return
Rigid3Aray
(
rots
,
trans
)
return
cls
(
rotation_matrix
.
Rot3Array
.
identity
(
shape
,
device
),
@
classmethod
vector
.
Vec3Array
.
zeros
(
shape
,
device
)
def
from_array
(
cls
,
array
):
)
rot
=
rotation_matrix
.
Rot3Array
.
from_array
(
array
[...,
:
3
])
vec
=
vector
.
Vec3Array
.
from_array
(
array
[...,
-
1
])
@
classmethod
return
cls
(
rot
,
vec
)
def
cat
(
cls
,
rigids
:
List
[
Rigid3Array
],
dim
:
int
)
->
Rigid3Array
:
return
cls
(
@
classmethod
rotation_matrix
.
Rot3Array
.
cat
(
def
from_tensor_4x4
(
cls
,
array
):
[
r
.
rotation
for
r
in
rigids
],
dim
=
dim
return
cls
.
from_array
(
array
)
),
vector
.
Vec3Array
.
cat
(
@
classmethod
[
r
.
translation
for
r
in
rigids
],
dim
=
dim
def
from_array4x4
(
cls
,
array
:
torch
.
tensor
)
->
Rigid3Array
:
),
"""Construct Rigid3Array from homogeneous 4x4 array."""
)
rotation
=
rotation_matrix
.
Rot3Array
(
array
[...,
0
,
0
],
array
[...,
0
,
1
],
array
[...,
0
,
2
],
def
scale_translation
(
self
,
factor
:
Float
)
->
Rigid3Array
:
array
[...,
1
,
0
],
array
[...,
1
,
1
],
array
[...,
1
,
2
],
"""Scale translation in Rigid3Array by 'factor'."""
array
[...,
2
,
0
],
array
[...,
2
,
1
],
array
[...,
2
,
2
]
return
Rigid3Array
(
self
.
rotation
,
self
.
translation
*
factor
)
)
translation
=
vector
.
Vec3Array
(
def
to_tensor
(
self
)
->
torch
.
Tensor
:
array
[...,
0
,
3
],
array
[...,
1
,
3
],
array
[...,
2
,
3
])
rot_array
=
self
.
rotation
.
to_tensor
()
return
cls
(
rotation
,
translation
)
vec_array
=
self
.
translation
.
to_tensor
()
array
=
torch
.
zeros
(
rot_array
.
shape
[:
-
2
]
+
(
4
,
4
),
device
=
rot_array
.
device
,
dtype
=
rot_array
.
dtype
)
array
[...,
:
3
,
:
3
]
=
rot_array
array
[...,
:
3
,
3
]
=
vec_array
array
[...,
3
,
3
]
=
1.
return
array
def
to_tensor_4x4
(
self
)
->
torch
.
Tensor
:
return
self
.
to_tensor
()
def
reshape
(
self
,
new_shape
)
->
Rigid3Array
:
rots
=
self
.
rotation
.
reshape
(
new_shape
)
trans
=
self
.
translation
.
reshape
(
new_shape
)
return
Rigid3Aray
(
rots
,
trans
)
def
stop_rot_gradient
(
self
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
.
stop_gradient
(),
self
.
translation
,
)
@
classmethod
def
from_array
(
cls
,
array
):
rot
=
rotation_matrix
.
Rot3Array
.
from_array
(
array
[...,
:
3
,
:
3
],
)
vec
=
vector
.
Vec3Array
.
from_array
(
array
[...,
:
3
,
3
])
return
cls
(
rot
,
vec
)
@
classmethod
def
from_tensor_4x4
(
cls
,
array
):
return
cls
.
from_array
(
array
)
@
classmethod
def
from_array4x4
(
cls
,
array
:
torch
.
tensor
)
->
Rigid3Array
:
"""Construct Rigid3Array from homogeneous 4x4 array."""
rotation
=
rotation_matrix
.
Rot3Array
(
array
[...,
0
,
0
],
array
[...,
0
,
1
],
array
[...,
0
,
2
],
array
[...,
1
,
0
],
array
[...,
1
,
1
],
array
[...,
1
,
2
],
array
[...,
2
,
0
],
array
[...,
2
,
1
],
array
[...,
2
,
2
]
)
translation
=
vector
.
Vec3Array
(
array
[...,
0
,
3
],
array
[...,
1
,
3
],
array
[...,
2
,
3
]
)
return
cls
(
rotation
,
translation
)
openfold/utils/geometry/rotation_matrix.py
View file @
260db67f
...
@@ -22,6 +22,7 @@ import numpy as np
...
@@ -22,6 +22,7 @@ import numpy as np
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
utils
from
openfold.utils.geometry
import
utils
from
openfold.utils.geometry
import
vector
from
openfold.utils.geometry
import
vector
from
openfold.utils.tensor_utils
import
tensor_tree_map
COMPONENTS
=
[
'xx'
,
'xy'
,
'xz'
,
'yx'
,
'yy'
,
'yz'
,
'zx'
,
'zy'
,
'zz'
]
COMPONENTS
=
[
'xx'
,
'xy'
,
'xz'
,
'yx'
,
'yy'
,
'yz'
,
'zx'
,
'zy'
,
'zz'
]
...
@@ -58,6 +59,13 @@ class Rot3Array:
...
@@ -58,6 +59,13 @@ class Rot3Array:
for
name
in
field_names
for
name
in
field_names
}
}
)
)
def
__matmul__
(
self
,
other
:
Rot3Array
)
->
Rot3Array
:
"""Composes two Rot3Arrays."""
c0
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xx
,
other
.
yx
,
other
.
zx
))
c1
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xy
,
other
.
yy
,
other
.
zy
))
c2
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xz
,
other
.
yz
,
other
.
zz
))
return
Rot3Array
(
c0
.
x
,
c1
.
x
,
c2
.
x
,
c0
.
y
,
c1
.
y
,
c2
.
y
,
c0
.
z
,
c1
.
z
,
c2
.
z
)
def
map_tensor_fn
(
self
,
fn
)
->
Rot3Array
:
def
map_tensor_fn
(
self
,
fn
)
->
Rot3Array
:
field_names
=
utils
.
get_field_names
(
Rot3Array
)
field_names
=
utils
.
get_field_names
(
Rot3Array
)
...
@@ -88,12 +96,19 @@ class Rot3Array:
...
@@ -88,12 +96,19 @@ class Rot3Array:
"""Applies inverse Rot3Array to point."""
"""Applies inverse Rot3Array to point."""
return
self
.
inverse
().
apply_to_point
(
point
)
return
self
.
inverse
().
apply_to_point
(
point
)
def
__matmul__
(
self
,
other
:
Rot3Array
)
->
Rot3Array
:
"""Composes two Rot3Arrays."""
def
unsqueeze
(
self
,
dim
:
int
):
c0
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xx
,
other
.
yx
,
other
.
zx
))
return
Rot3Array
(
c1
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xy
,
other
.
yy
,
other
.
zy
))
*
tensor_tree_map
(
c2
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xz
,
other
.
yz
,
other
.
zz
))
lambda
t
:
t
.
unsqueeze
(
dim
),
return
Rot3Array
(
c0
.
x
,
c1
.
x
,
c2
.
x
,
c0
.
y
,
c1
.
y
,
c2
.
y
,
c0
.
z
,
c1
.
z
,
c2
.
z
)
[
getattr
(
self
,
c
)
for
c
in
COMPONENTS
]
)
)
def
stop_gradient
(
self
)
->
Rot3Array
:
return
Rot3Array
(
*
[
getattr
(
self
,
c
).
detach
()
for
c
in
COMPONENTS
]
)
@
classmethod
@
classmethod
def
identity
(
cls
,
shape
,
device
)
->
Rot3Array
:
def
identity
(
cls
,
shape
,
device
)
->
Rot3Array
:
...
@@ -130,9 +145,11 @@ class Rot3Array:
...
@@ -130,9 +145,11 @@ class Rot3Array:
@
classmethod
@
classmethod
def
from_array
(
cls
,
array
:
torch
.
Tensor
)
->
Rot3Array
:
def
from_array
(
cls
,
array
:
torch
.
Tensor
)
->
Rot3Array
:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
return
cls
(
torch
.
unbind
(
array
,
dim
=-
2
))
rows
=
torch
.
unbind
(
array
,
dim
=-
2
)
rc
=
[
torch
.
unbind
(
e
,
dim
=-
1
)
for
e
in
rows
]
return
cls
(
*
[
e
for
row
in
rc
for
e
in
row
])
def
to_
array
(
self
)
->
torch
.
Tensor
:
def
to_
tensor
(
self
)
->
torch
.
Tensor
:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return
torch
.
stack
(
return
torch
.
stack
(
[
[
...
@@ -140,7 +157,8 @@ class Rot3Array:
...
@@ -140,7 +157,8 @@ class Rot3Array:
torch
.
stack
([
self
.
yx
,
self
.
yy
,
self
.
yz
],
dim
=-
1
),
torch
.
stack
([
self
.
yx
,
self
.
yy
,
self
.
yz
],
dim
=-
1
),
torch
.
stack
([
self
.
zx
,
self
.
zy
,
self
.
zz
],
dim
=-
1
)
torch
.
stack
([
self
.
zx
,
self
.
zy
,
self
.
zz
],
dim
=-
1
)
],
],
dim
=-
2
)
dim
=-
2
)
@
classmethod
@
classmethod
def
from_quaternion
(
cls
,
def
from_quaternion
(
cls
,
...
...
openfold/utils/geometry/vector.py
View file @
260db67f
...
@@ -134,13 +134,20 @@ class Vec3Array:
...
@@ -134,13 +134,20 @@ class Vec3Array:
return
Vec3Array
(
x
,
y
,
z
)
return
Vec3Array
(
x
,
y
,
z
)
def
sum
(
self
,
dim
)
->
Vec3Array
:
def
sum
(
self
,
dim
:
int
)
->
Vec3Array
:
return
Vec3Array
(
return
Vec3Array
(
torch
.
sum
(
self
.
x
,
dim
=
dim
),
torch
.
sum
(
self
.
x
,
dim
=
dim
),
torch
.
sum
(
self
.
y
,
dim
=
dim
),
torch
.
sum
(
self
.
y
,
dim
=
dim
),
torch
.
sum
(
self
.
z
,
dim
=
dim
),
torch
.
sum
(
self
.
z
,
dim
=
dim
),
)
)
def
unsqueeze
(
self
,
dim
:
int
):
return
Vec3Array
(
self
.
x
.
unsqueeze
(
dim
),
self
.
y
.
unsqueeze
(
dim
),
self
.
z
.
unsqueeze
(
dim
),
)
@
classmethod
@
classmethod
def
zeros
(
cls
,
shape
,
device
=
"cpu"
):
def
zeros
(
cls
,
shape
,
device
=
"cpu"
):
"""Return Vec3Array corresponding to zeros of given shape."""
"""Return Vec3Array corresponding to zeros of given shape."""
...
@@ -150,11 +157,11 @@ class Vec3Array:
...
@@ -150,11 +157,11 @@ class Vec3Array:
torch
.
zeros
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
)
torch
.
zeros
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
)
)
)
def
to_
array
(
self
)
->
torch
.
Tensor
:
def
to_
tensor
(
self
)
->
torch
.
Tensor
:
return
torch
.
stack
([
self
.
x
,
self
.
y
,
self
.
z
],
dim
=-
1
)
return
torch
.
stack
([
self
.
x
,
self
.
y
,
self
.
z
],
dim
=-
1
)
@
classmethod
@
classmethod
def
from_
tensor
(
cls
,
tensor
):
def
from_
array
(
cls
,
tensor
):
return
cls
(
*
torch
.
unbind
(
tensor
,
dim
=-
1
))
return
cls
(
*
torch
.
unbind
(
tensor
,
dim
=-
1
))
@
classmethod
@
classmethod
...
...
openfold/utils/import_weights.py
View file @
260db67f
...
@@ -623,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
...
@@ -623,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
flat_keys
=
list
(
flat
.
keys
())
flat_keys
=
list
(
flat
.
keys
())
incorrect
=
[
k
for
k
in
flat_keys
if
k
not
in
keys
]
incorrect
=
[
k
for
k
in
flat_keys
if
k
not
in
keys
]
missing
=
[
k
for
k
in
keys
if
k
not
in
flat_keys
]
missing
=
[
k
for
k
in
keys
if
k
not
in
flat_keys
]
print
(
f
"Incorrect:
{
incorrect
}
"
)
#
print(f"Incorrect: {incorrect}")
print
(
f
"Missing:
{
missing
}
"
)
#
print(f"Missing: {missing}")
assert
len
(
incorrect
)
==
0
assert
len
(
incorrect
)
==
0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
...
...
run_pretrained_openfold.py
View file @
260db67f
...
@@ -217,7 +217,8 @@ def main(args):
...
@@ -217,7 +217,8 @@ def main(args):
unrelaxed_protein
=
protein
.
from_prediction
(
unrelaxed_protein
=
protein
.
from_prediction
(
features
=
batch
,
features
=
batch
,
result
=
out
,
result
=
out
,
b_factors
=
plddt_b_factors
b_factors
=
plddt_b_factors
,
remove_leading_feature_dimension
=
not
is_multimer
,
)
)
# Save the unrelaxed PDB.
# Save the unrelaxed PDB.
...
@@ -227,6 +228,9 @@ def main(args):
...
@@ -227,6 +228,9 @@ def main(args):
with
open
(
unrelaxed_output_path
,
'w'
)
as
f
:
with
open
(
unrelaxed_output_path
,
'w'
)
as
f
:
f
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
f
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
print
(
unrelaxed_output_path
)
print
(
"asdjfh klasjdhf lkasjdhf lkjasdhflkjasdh fkl jasdhfklj hasdkljf hasldkjfh lkasjdfh lkajsdhflk asd"
)
amber_relaxer
=
relax
.
AmberRelaxation
(
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
**
config
.
relax
,
**
config
.
relax
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment