Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
260db67f
Commit
260db67f
authored
May 12, 2022
by
Gustaf Ahdritz
Browse files
Finish multimer inference
parent
6e68d6b0
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
205 additions
and
132 deletions
+205
-132
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+1
-0
openfold/model/embedders.py
openfold/model/embedders.py
+1
-1
openfold/model/structure_module.py
openfold/model/structure_module.py
+3
-6
openfold/np/protein.py
openfold/np/protein.py
+1
-1
openfold/utils/feats.py
openfold/utils/feats.py
+9
-9
openfold/utils/geometry/quat_rigid.py
openfold/utils/geometry/quat_rigid.py
+0
-2
openfold/utils/geometry/rigid_matrix_vector.py
openfold/utils/geometry/rigid_matrix_vector.py
+146
-98
openfold/utils/geometry/rotation_matrix.py
openfold/utils/geometry/rotation_matrix.py
+27
-9
openfold/utils/geometry/vector.py
openfold/utils/geometry/vector.py
+10
-3
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+2
-2
run_pretrained_openfold.py
run_pretrained_openfold.py
+5
-1
No files found.
openfold/data/input_pipeline_multimer.py
View file @
260db67f
...
@@ -29,6 +29,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
...
@@ -29,6 +29,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms
.
cast_to_64bit_ints
,
data_transforms
.
cast_to_64bit_ints
,
data_transforms_multimer
.
make_msa_profile
,
data_transforms_multimer
.
make_msa_profile
,
data_transforms_multimer
.
create_target_feat
,
data_transforms_multimer
.
create_target_feat
,
data_transforms
.
make_atom14_masks
,
]
]
if
(
common_cfg
.
use_templates
):
if
(
common_cfg
.
use_templates
):
...
...
openfold/model/embedders.py
View file @
260db67f
...
@@ -868,7 +868,7 @@ class TemplateEmbedderMultimer(nn.Module):
...
@@ -868,7 +868,7 @@ class TemplateEmbedderMultimer(nn.Module):
raw_atom_pos
=
single_template_feats
[
"template_all_atom_positions"
]
raw_atom_pos
=
single_template_feats
[
"template_all_atom_positions"
]
atom_pos
=
geometry
.
Vec3Array
.
from_
tensor
(
raw_atom_pos
)
atom_pos
=
geometry
.
Vec3Array
.
from_
array
(
raw_atom_pos
)
rigid
,
backbone_mask
=
all_atom_multimer
.
make_backbone_affine
(
rigid
,
backbone_mask
=
all_atom_multimer
.
make_backbone_affine
(
atom_pos
,
atom_pos
,
single_template_feats
[
"template_all_atom_mask"
],
single_template_feats
[
"template_all_atom_mask"
],
...
...
openfold/model/structure_module.py
View file @
260db67f
...
@@ -363,9 +363,6 @@ class InvariantPointAttention(nn.Module):
...
@@ -363,9 +363,6 @@ class InvariantPointAttention(nn.Module):
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
for
c
in
q_pts
:
print
(
type
(
c
))
# [*, N_res, N_res, H, P_q, 3]
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
[...,
None
,
:,
:]
-
k_pts
[...,
None
,
:,
:,
:]
pt_att
=
q_pts
[...,
None
,
:,
:]
-
k_pts
[...,
None
,
:,
:,
:]
...
@@ -669,7 +666,7 @@ class StructureModule(nn.Module):
...
@@ -669,7 +666,7 @@ class StructureModule(nn.Module):
self
,
r
,
f
# [*, N, 8] # [*, N]
self
,
r
,
f
# [*, N, 8] # [*, N]
):
):
# Lazily initialize the residue constants on the correct device
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
r
.
get_rots
().
dtype
,
r
.
get_rots
()
.
device
)
self
.
_init_residue_constants
(
r
.
dtype
,
r
.
device
)
return
frames_and_literature_positions_to_atom14_pos
(
return
frames_and_literature_positions_to_atom14_pos
(
r
,
r
,
f
,
f
,
...
@@ -818,11 +815,11 @@ class StructureModule(nn.Module):
...
@@ -818,11 +815,11 @@ class StructureModule(nn.Module):
)
)
preds
=
{
preds
=
{
"frames"
:
rigids
.
scale_translation
(
self
.
trans_scale_factor
).
to_tensor
7
(),
"frames"
:
rigids
.
scale_translation
(
self
.
trans_scale_factor
).
to_tensor
(),
"sidechain_frames"
:
all_frames_to_global
.
to_tensor_4x4
(),
"sidechain_frames"
:
all_frames_to_global
.
to_tensor_4x4
(),
"unnormalized_angles"
:
unnormalized_angles
,
"unnormalized_angles"
:
unnormalized_angles
,
"angles"
:
angles
,
"angles"
:
angles
,
"positions"
:
pred_xyz
,
"positions"
:
pred_xyz
.
to_tensor
()
,
}
}
outputs
.
append
(
preds
)
outputs
.
append
(
preds
)
...
...
openfold/np/protein.py
View file @
260db67f
...
@@ -205,7 +205,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
...
@@ -205,7 +205,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
)
def
_chain_end
(
atom_index
,
end_resname
,
chain_name
,
residue_indx
)
->
str
:
def
_chain_end
(
atom_index
,
end_resname
,
chain_name
,
residue_ind
e
x
)
->
str
:
chain_end
=
'TER'
chain_end
=
'TER'
return
(
return
(
f
'
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
end_resname
:
>
3
}
'
f
'
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
end_resname
:
>
3
}
'
...
...
openfold/utils/feats.py
View file @
260db67f
...
@@ -22,6 +22,7 @@ from typing import Dict
...
@@ -22,6 +22,7 @@ from typing import Dict
from
openfold.np
import
protein
from
openfold.np
import
protein
import
openfold.np.residue_constants
as
rc
import
openfold.np.residue_constants
as
rc
from
openfold.utils.geometry
import
rigid_matrix_vector
,
rotation_matrix
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
batched_gather
,
batched_gather
,
...
@@ -213,15 +214,14 @@ def torsion_angles_to_frames(
...
@@ -213,15 +214,14 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# This follows the original code rather than the supplement, which uses
# different indices.
# different indices.
all_rots
=
alpha
.
new_zeros
(
default_r
.
get_rots
().
get_rot_mats
().
shape
)
all_rots
=
alpha
.
new_zeros
(
default_r
.
shape
+
(
3
,
3
)
)
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
=
Rigid
(
Rotation
(
rot_mats
=
all_rots
),
None
)
all_rots
=
rotation_matrix
.
Rot3Array
.
from_array
(
all_rots
)
all_frames
=
default_r
.
compose_rotation
(
all_rots
)
all_frames
=
default_r
.
compose
(
all_rots
)
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
...
@@ -232,7 +232,7 @@ def torsion_angles_to_frames(
...
@@ -232,7 +232,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
all_frames_to_bb
=
Rigid
.
cat
(
all_frames_to_bb
=
rigid_matrix_vector
.
Rigid3Array
.
cat
(
[
[
all_frames
[...,
:
5
],
all_frames
[...,
:
5
],
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
...
@@ -248,7 +248,7 @@ def torsion_angles_to_frames(
...
@@ -248,7 +248,7 @@ def torsion_angles_to_frames(
def
frames_and_literature_positions_to_atom14_pos
(
def
frames_and_literature_positions_to_atom14_pos
(
r
:
Rigid
,
r
:
rigid_matrix_vector
.
Rigid3Array
,
aatype
:
torch
.
Tensor
,
aatype
:
torch
.
Tensor
,
default_frames
,
default_frames
,
group_idx
,
group_idx
,
...
@@ -275,8 +275,8 @@ def frames_and_literature_positions_to_atom14_pos(
...
@@ -275,8 +275,8 @@ def frames_and_literature_positions_to_atom14_pos(
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
)
)
# [*, N, 14
, 1
]
# [*, N, 14]
atom_mask
=
atom_mask
[
aatype
,
...]
.
unsqueeze
(
-
1
)
atom_mask
=
atom_mask
[
aatype
,
...]
# [*, N, 14, 3]
# [*, N, 14, 3]
lit_positions
=
lit_positions
[
aatype
,
...]
lit_positions
=
lit_positions
[
aatype
,
...]
...
...
openfold/utils/geometry/quat_rigid.py
View file @
260db67f
...
@@ -22,8 +22,6 @@ class QuatRigid(nn.Module):
...
@@ -22,8 +22,6 @@ class QuatRigid(nn.Module):
# NOTE: During training, this needs to be run in higher precision
# NOTE: During training, this needs to be run in higher precision
rigid_flat
=
self
.
linear
(
activations
.
to
(
torch
.
float32
))
rigid_flat
=
self
.
linear
(
activations
.
to
(
torch
.
float32
))
print
(
rigid_flat
.
shape
)
rigid_flat
=
torch
.
unbind
(
rigid_flat
,
dim
=-
1
)
rigid_flat
=
torch
.
unbind
(
rigid_flat
,
dim
=-
1
)
if
(
self
.
full_quat
):
if
(
self
.
full_quat
):
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
...
...
openfold/utils/geometry/rigid_matrix_vector.py
View file @
260db67f
...
@@ -67,6 +67,9 @@ class Rigid3Array:
...
@@ -67,6 +67,9 @@ class Rigid3Array:
"""Apply Rigid3Array transform to point."""
"""Apply Rigid3Array transform to point."""
return
self
.
rotation
.
apply_to_point
(
point
)
+
self
.
translation
return
self
.
rotation
.
apply_to_point
(
point
)
+
self
.
translation
def
apply
(
self
,
point
:
torch
.
Tensor
)
->
vector
.
Vec3Array
:
return
self
.
apply_to_point
(
vector
.
Vec3Array
.
from_array
(
point
))
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply inverse Rigid3Array transform to point."""
"""Apply inverse Rigid3Array transform to point."""
new_point
=
point
-
self
.
translation
new_point
=
point
-
self
.
translation
...
@@ -74,7 +77,28 @@ class Rigid3Array:
...
@@ -74,7 +77,28 @@ class Rigid3Array:
def
compose_rotation
(
self
,
other_rotation
):
def
compose_rotation
(
self
,
other_rotation
):
rot
=
self
.
rotation
@
other_rotation
rot
=
self
.
rotation
@
other_rotation
return
Rigid3Array
(
rot
,
trans
.
clone
())
return
Rigid3Array
(
rot
,
self
.
translation
.
clone
())
def
compose
(
self
,
other_rigid
):
return
self
@
other_rigid
def
unsqueeze
(
self
,
dim
:
int
):
return
Rigid3Array
(
self
.
rotation
.
unsqueeze
(
dim
),
self
.
translation
.
unsqueeze
(
dim
),
)
@
property
def
shape
(
self
)
->
torch
.
Size
:
return
self
.
rotation
.
xx
.
shape
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
return
self
.
rotation
.
xx
.
dtype
@
property
def
device
(
self
)
->
torch
.
device
:
return
self
.
rotation
.
xx
.
device
@
classmethod
@
classmethod
def
identity
(
cls
,
shape
,
device
)
->
Rigid3Array
:
def
identity
(
cls
,
shape
,
device
)
->
Rigid3Array
:
...
@@ -87,28 +111,51 @@ class Rigid3Array:
...
@@ -87,28 +111,51 @@ class Rigid3Array:
@
classmethod
@
classmethod
def
cat
(
cls
,
rigids
:
List
[
Rigid3Array
],
dim
:
int
)
->
Rigid3Array
:
def
cat
(
cls
,
rigids
:
List
[
Rigid3Array
],
dim
:
int
)
->
Rigid3Array
:
return
cls
(
return
cls
(
Rot3Array
.
cat
([
r
.
rotation
for
r
in
rigids
],
dim
=
dim
),
rotation_matrix
.
Rot3Array
.
cat
(
Vec3Array
.
cat
([
r
.
translation
for
r
in
rigids
],
dim
=
dim
),
[
r
.
rotation
for
r
in
rigids
],
dim
=
dim
),
vector
.
Vec3Array
.
cat
(
[
r
.
translation
for
r
in
rigids
],
dim
=
dim
),
)
)
def
scale_translation
(
self
,
factor
:
Float
)
->
Rigid3Array
:
def
scale_translation
(
self
,
factor
:
Float
)
->
Rigid3Array
:
"""Scale translation in Rigid3Array by 'factor'."""
"""Scale translation in Rigid3Array by 'factor'."""
return
Rigid3Array
(
self
.
rotation
,
self
.
translation
*
factor
)
return
Rigid3Array
(
self
.
rotation
,
self
.
translation
*
factor
)
def
to_array
(
self
):
def
to_tensor
(
self
)
->
torch
.
Tensor
:
rot_array
=
self
.
rotation
.
to_array
()
rot_array
=
self
.
rotation
.
to_tensor
()
vec_array
=
self
.
translation
.
to_array
()
vec_array
=
self
.
translation
.
to_tensor
()
return
torch
.
cat
([
rot_array
,
vec_array
[...,
None
]],
dim
=-
1
)
array
=
torch
.
zeros
(
rot_array
.
shape
[:
-
2
]
+
(
4
,
4
),
device
=
rot_array
.
device
,
dtype
=
rot_array
.
dtype
)
array
[...,
:
3
,
:
3
]
=
rot_array
array
[...,
:
3
,
3
]
=
vec_array
array
[...,
3
,
3
]
=
1.
return
array
def
to_tensor_4x4
(
self
)
->
torch
.
Tensor
:
return
self
.
to_tensor
()
def
reshape
(
self
,
new_shape
)
->
Rigid3Array
:
def
reshape
(
self
,
new_shape
)
->
Rigid3Array
:
rots
=
self
.
rotation
.
reshape
(
new_shape
)
rots
=
self
.
rotation
.
reshape
(
new_shape
)
trans
=
self
.
translation
.
reshape
(
new_shape
)
trans
=
self
.
translation
.
reshape
(
new_shape
)
return
Rigid3Aray
(
rots
,
trans
)
return
Rigid3Aray
(
rots
,
trans
)
def
stop_rot_gradient
(
self
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
.
stop_gradient
(),
self
.
translation
,
)
@
classmethod
@
classmethod
def
from_array
(
cls
,
array
):
def
from_array
(
cls
,
array
):
rot
=
rotation_matrix
.
Rot3Array
.
from_array
(
array
[...,
:
3
])
rot
=
rotation_matrix
.
Rot3Array
.
from_array
(
vec
=
vector
.
Vec3Array
.
from_array
(
array
[...,
-
1
])
array
[...,
:
3
,
:
3
],
)
vec
=
vector
.
Vec3Array
.
from_array
(
array
[...,
:
3
,
3
])
return
cls
(
rot
,
vec
)
return
cls
(
rot
,
vec
)
@
classmethod
@
classmethod
...
@@ -124,5 +171,6 @@ class Rigid3Array:
...
@@ -124,5 +171,6 @@ class Rigid3Array:
array
[...,
2
,
0
],
array
[...,
2
,
1
],
array
[...,
2
,
2
]
array
[...,
2
,
0
],
array
[...,
2
,
1
],
array
[...,
2
,
2
]
)
)
translation
=
vector
.
Vec3Array
(
translation
=
vector
.
Vec3Array
(
array
[...,
0
,
3
],
array
[...,
1
,
3
],
array
[...,
2
,
3
])
array
[...,
0
,
3
],
array
[...,
1
,
3
],
array
[...,
2
,
3
]
)
return
cls
(
rotation
,
translation
)
return
cls
(
rotation
,
translation
)
openfold/utils/geometry/rotation_matrix.py
View file @
260db67f
...
@@ -22,6 +22,7 @@ import numpy as np
...
@@ -22,6 +22,7 @@ import numpy as np
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
utils
from
openfold.utils.geometry
import
utils
from
openfold.utils.geometry
import
vector
from
openfold.utils.geometry
import
vector
from
openfold.utils.tensor_utils
import
tensor_tree_map
COMPONENTS
=
[
'xx'
,
'xy'
,
'xz'
,
'yx'
,
'yy'
,
'yz'
,
'zx'
,
'zy'
,
'zz'
]
COMPONENTS
=
[
'xx'
,
'xy'
,
'xz'
,
'yx'
,
'yy'
,
'yz'
,
'zx'
,
'zy'
,
'zz'
]
...
@@ -59,6 +60,13 @@ class Rot3Array:
...
@@ -59,6 +60,13 @@ class Rot3Array:
}
}
)
)
def
__matmul__
(
self
,
other
:
Rot3Array
)
->
Rot3Array
:
"""Composes two Rot3Arrays."""
c0
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xx
,
other
.
yx
,
other
.
zx
))
c1
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xy
,
other
.
yy
,
other
.
zy
))
c2
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xz
,
other
.
yz
,
other
.
zz
))
return
Rot3Array
(
c0
.
x
,
c1
.
x
,
c2
.
x
,
c0
.
y
,
c1
.
y
,
c2
.
y
,
c0
.
z
,
c1
.
z
,
c2
.
z
)
def
map_tensor_fn
(
self
,
fn
)
->
Rot3Array
:
def
map_tensor_fn
(
self
,
fn
)
->
Rot3Array
:
field_names
=
utils
.
get_field_names
(
Rot3Array
)
field_names
=
utils
.
get_field_names
(
Rot3Array
)
return
Rot3Array
(
return
Rot3Array
(
...
@@ -88,12 +96,19 @@ class Rot3Array:
...
@@ -88,12 +96,19 @@ class Rot3Array:
"""Applies inverse Rot3Array to point."""
"""Applies inverse Rot3Array to point."""
return
self
.
inverse
().
apply_to_point
(
point
)
return
self
.
inverse
().
apply_to_point
(
point
)
def
__matmul__
(
self
,
other
:
Rot3Array
)
->
Rot3Array
:
"""Composes two Rot3Arrays."""
def
unsqueeze
(
self
,
dim
:
int
):
c0
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xx
,
other
.
yx
,
other
.
zx
))
return
Rot3Array
(
c1
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xy
,
other
.
yy
,
other
.
zy
))
*
tensor_tree_map
(
c2
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xz
,
other
.
yz
,
other
.
zz
))
lambda
t
:
t
.
unsqueeze
(
dim
),
return
Rot3Array
(
c0
.
x
,
c1
.
x
,
c2
.
x
,
c0
.
y
,
c1
.
y
,
c2
.
y
,
c0
.
z
,
c1
.
z
,
c2
.
z
)
[
getattr
(
self
,
c
)
for
c
in
COMPONENTS
]
)
)
def
stop_gradient
(
self
)
->
Rot3Array
:
return
Rot3Array
(
*
[
getattr
(
self
,
c
).
detach
()
for
c
in
COMPONENTS
]
)
@
classmethod
@
classmethod
def
identity
(
cls
,
shape
,
device
)
->
Rot3Array
:
def
identity
(
cls
,
shape
,
device
)
->
Rot3Array
:
...
@@ -130,9 +145,11 @@ class Rot3Array:
...
@@ -130,9 +145,11 @@ class Rot3Array:
@
classmethod
@
classmethod
def
from_array
(
cls
,
array
:
torch
.
Tensor
)
->
Rot3Array
:
def
from_array
(
cls
,
array
:
torch
.
Tensor
)
->
Rot3Array
:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
return
cls
(
torch
.
unbind
(
array
,
dim
=-
2
))
rows
=
torch
.
unbind
(
array
,
dim
=-
2
)
rc
=
[
torch
.
unbind
(
e
,
dim
=-
1
)
for
e
in
rows
]
return
cls
(
*
[
e
for
row
in
rc
for
e
in
row
])
def
to_
array
(
self
)
->
torch
.
Tensor
:
def
to_
tensor
(
self
)
->
torch
.
Tensor
:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return
torch
.
stack
(
return
torch
.
stack
(
[
[
...
@@ -140,7 +157,8 @@ class Rot3Array:
...
@@ -140,7 +157,8 @@ class Rot3Array:
torch
.
stack
([
self
.
yx
,
self
.
yy
,
self
.
yz
],
dim
=-
1
),
torch
.
stack
([
self
.
yx
,
self
.
yy
,
self
.
yz
],
dim
=-
1
),
torch
.
stack
([
self
.
zx
,
self
.
zy
,
self
.
zz
],
dim
=-
1
)
torch
.
stack
([
self
.
zx
,
self
.
zy
,
self
.
zz
],
dim
=-
1
)
],
],
dim
=-
2
)
dim
=-
2
)
@
classmethod
@
classmethod
def
from_quaternion
(
cls
,
def
from_quaternion
(
cls
,
...
...
openfold/utils/geometry/vector.py
View file @
260db67f
...
@@ -134,13 +134,20 @@ class Vec3Array:
...
@@ -134,13 +134,20 @@ class Vec3Array:
return
Vec3Array
(
x
,
y
,
z
)
return
Vec3Array
(
x
,
y
,
z
)
def
sum
(
self
,
dim
)
->
Vec3Array
:
def
sum
(
self
,
dim
:
int
)
->
Vec3Array
:
return
Vec3Array
(
return
Vec3Array
(
torch
.
sum
(
self
.
x
,
dim
=
dim
),
torch
.
sum
(
self
.
x
,
dim
=
dim
),
torch
.
sum
(
self
.
y
,
dim
=
dim
),
torch
.
sum
(
self
.
y
,
dim
=
dim
),
torch
.
sum
(
self
.
z
,
dim
=
dim
),
torch
.
sum
(
self
.
z
,
dim
=
dim
),
)
)
def
unsqueeze
(
self
,
dim
:
int
):
return
Vec3Array
(
self
.
x
.
unsqueeze
(
dim
),
self
.
y
.
unsqueeze
(
dim
),
self
.
z
.
unsqueeze
(
dim
),
)
@
classmethod
@
classmethod
def
zeros
(
cls
,
shape
,
device
=
"cpu"
):
def
zeros
(
cls
,
shape
,
device
=
"cpu"
):
"""Return Vec3Array corresponding to zeros of given shape."""
"""Return Vec3Array corresponding to zeros of given shape."""
...
@@ -150,11 +157,11 @@ class Vec3Array:
...
@@ -150,11 +157,11 @@ class Vec3Array:
torch
.
zeros
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
)
torch
.
zeros
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
)
)
)
def
to_
array
(
self
)
->
torch
.
Tensor
:
def
to_
tensor
(
self
)
->
torch
.
Tensor
:
return
torch
.
stack
([
self
.
x
,
self
.
y
,
self
.
z
],
dim
=-
1
)
return
torch
.
stack
([
self
.
x
,
self
.
y
,
self
.
z
],
dim
=-
1
)
@
classmethod
@
classmethod
def
from_
tensor
(
cls
,
tensor
):
def
from_
array
(
cls
,
tensor
):
return
cls
(
*
torch
.
unbind
(
tensor
,
dim
=-
1
))
return
cls
(
*
torch
.
unbind
(
tensor
,
dim
=-
1
))
@
classmethod
@
classmethod
...
...
openfold/utils/import_weights.py
View file @
260db67f
...
@@ -623,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
...
@@ -623,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
flat_keys
=
list
(
flat
.
keys
())
flat_keys
=
list
(
flat
.
keys
())
incorrect
=
[
k
for
k
in
flat_keys
if
k
not
in
keys
]
incorrect
=
[
k
for
k
in
flat_keys
if
k
not
in
keys
]
missing
=
[
k
for
k
in
keys
if
k
not
in
flat_keys
]
missing
=
[
k
for
k
in
keys
if
k
not
in
flat_keys
]
print
(
f
"Incorrect:
{
incorrect
}
"
)
#
print(f"Incorrect: {incorrect}")
print
(
f
"Missing:
{
missing
}
"
)
#
print(f"Missing: {missing}")
assert
len
(
incorrect
)
==
0
assert
len
(
incorrect
)
==
0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
...
...
run_pretrained_openfold.py
View file @
260db67f
...
@@ -217,7 +217,8 @@ def main(args):
...
@@ -217,7 +217,8 @@ def main(args):
unrelaxed_protein
=
protein
.
from_prediction
(
unrelaxed_protein
=
protein
.
from_prediction
(
features
=
batch
,
features
=
batch
,
result
=
out
,
result
=
out
,
b_factors
=
plddt_b_factors
b_factors
=
plddt_b_factors
,
remove_leading_feature_dimension
=
not
is_multimer
,
)
)
# Save the unrelaxed PDB.
# Save the unrelaxed PDB.
...
@@ -227,6 +228,9 @@ def main(args):
...
@@ -227,6 +228,9 @@ def main(args):
with
open
(
unrelaxed_output_path
,
'w'
)
as
f
:
with
open
(
unrelaxed_output_path
,
'w'
)
as
f
:
f
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
f
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
print
(
unrelaxed_output_path
)
print
(
"asdjfh klasjdhf lkasjdhf lkjasdhflkjasdh fkl jasdhfklj hasdkljf hasldkjfh lkasjdfh lkajsdhflk asd"
)
amber_relaxer
=
relax
.
AmberRelaxation
(
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
**
config
.
relax
,
**
config
.
relax
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment