Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
260db67f
Commit
260db67f
authored
May 12, 2022
by
Gustaf Ahdritz
Browse files
Finish multimer inference
parent
6e68d6b0
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
205 additions
and
132 deletions
+205
-132
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+1
-0
openfold/model/embedders.py
openfold/model/embedders.py
+1
-1
openfold/model/structure_module.py
openfold/model/structure_module.py
+3
-6
openfold/np/protein.py
openfold/np/protein.py
+1
-1
openfold/utils/feats.py
openfold/utils/feats.py
+9
-9
openfold/utils/geometry/quat_rigid.py
openfold/utils/geometry/quat_rigid.py
+0
-2
openfold/utils/geometry/rigid_matrix_vector.py
openfold/utils/geometry/rigid_matrix_vector.py
+146
-98
openfold/utils/geometry/rotation_matrix.py
openfold/utils/geometry/rotation_matrix.py
+27
-9
openfold/utils/geometry/vector.py
openfold/utils/geometry/vector.py
+10
-3
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+2
-2
run_pretrained_openfold.py
run_pretrained_openfold.py
+5
-1
No files found.
openfold/data/input_pipeline_multimer.py
View file @
260db67f
...
...
@@ -29,6 +29,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms
.
cast_to_64bit_ints
,
data_transforms_multimer
.
make_msa_profile
,
data_transforms_multimer
.
create_target_feat
,
data_transforms
.
make_atom14_masks
,
]
if
(
common_cfg
.
use_templates
):
...
...
openfold/model/embedders.py
View file @
260db67f
...
...
@@ -868,7 +868,7 @@ class TemplateEmbedderMultimer(nn.Module):
raw_atom_pos
=
single_template_feats
[
"template_all_atom_positions"
]
atom_pos
=
geometry
.
Vec3Array
.
from_
tensor
(
raw_atom_pos
)
atom_pos
=
geometry
.
Vec3Array
.
from_
array
(
raw_atom_pos
)
rigid
,
backbone_mask
=
all_atom_multimer
.
make_backbone_affine
(
atom_pos
,
single_template_feats
[
"template_all_atom_mask"
],
...
...
openfold/model/structure_module.py
View file @
260db67f
...
...
@@ -363,9 +363,6 @@ class InvariantPointAttention(nn.Module):
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
for
c
in
q_pts
:
print
(
type
(
c
))
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
[...,
None
,
:,
:]
-
k_pts
[...,
None
,
:,
:,
:]
...
...
@@ -669,7 +666,7 @@ class StructureModule(nn.Module):
self
,
r
,
f
# [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
r
.
get_rots
().
dtype
,
r
.
get_rots
()
.
device
)
self
.
_init_residue_constants
(
r
.
dtype
,
r
.
device
)
return
frames_and_literature_positions_to_atom14_pos
(
r
,
f
,
...
...
@@ -818,11 +815,11 @@ class StructureModule(nn.Module):
)
preds
=
{
"frames"
:
rigids
.
scale_translation
(
self
.
trans_scale_factor
).
to_tensor
7
(),
"frames"
:
rigids
.
scale_translation
(
self
.
trans_scale_factor
).
to_tensor
(),
"sidechain_frames"
:
all_frames_to_global
.
to_tensor_4x4
(),
"unnormalized_angles"
:
unnormalized_angles
,
"angles"
:
angles
,
"positions"
:
pred_xyz
,
"positions"
:
pred_xyz
.
to_tensor
()
,
}
outputs
.
append
(
preds
)
...
...
openfold/np/protein.py
View file @
260db67f
...
...
@@ -205,7 +205,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
def
_chain_end
(
atom_index
,
end_resname
,
chain_name
,
residue_indx
)
->
str
:
def
_chain_end
(
atom_index
,
end_resname
,
chain_name
,
residue_ind
e
x
)
->
str
:
chain_end
=
'TER'
return
(
f
'
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
end_resname
:
>
3
}
'
...
...
openfold/utils/feats.py
View file @
260db67f
...
...
@@ -22,6 +22,7 @@ from typing import Dict
from
openfold.np
import
protein
import
openfold.np.residue_constants
as
rc
from
openfold.utils.geometry
import
rigid_matrix_vector
,
rotation_matrix
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
batched_gather
,
...
...
@@ -213,15 +214,14 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots
=
alpha
.
new_zeros
(
default_r
.
get_rots
().
get_rot_mats
().
shape
)
all_rots
=
alpha
.
new_zeros
(
default_r
.
shape
+
(
3
,
3
)
)
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
=
Rigid
(
Rotation
(
rot_mats
=
all_rots
),
None
)
all_frames
=
default_r
.
compose
(
all_rots
)
all_rots
=
rotation_matrix
.
Rot3Array
.
from_array
(
all_rots
)
all_frames
=
default_r
.
compose_rotation
(
all_rots
)
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
...
...
@@ -232,7 +232,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
all_frames_to_bb
=
Rigid
.
cat
(
all_frames_to_bb
=
rigid_matrix_vector
.
Rigid3Array
.
cat
(
[
all_frames
[...,
:
5
],
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
...
...
@@ -241,14 +241,14 @@ def torsion_angles_to_frames(
],
dim
=-
1
,
)
all_frames_to_global
=
r
[...,
None
].
compose
(
all_frames_to_bb
)
return
all_frames_to_global
def
frames_and_literature_positions_to_atom14_pos
(
r
:
Rigid
,
r
:
rigid_matrix_vector
.
Rigid3Array
,
aatype
:
torch
.
Tensor
,
default_frames
,
group_idx
,
...
...
@@ -275,8 +275,8 @@ def frames_and_literature_positions_to_atom14_pos(
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
)
# [*, N, 14
, 1
]
atom_mask
=
atom_mask
[
aatype
,
...]
.
unsqueeze
(
-
1
)
# [*, N, 14]
atom_mask
=
atom_mask
[
aatype
,
...]
# [*, N, 14, 3]
lit_positions
=
lit_positions
[
aatype
,
...]
...
...
openfold/utils/geometry/quat_rigid.py
View file @
260db67f
...
...
@@ -22,8 +22,6 @@ class QuatRigid(nn.Module):
# NOTE: During training, this needs to be run in higher precision
rigid_flat
=
self
.
linear
(
activations
.
to
(
torch
.
float32
))
print
(
rigid_flat
.
shape
)
rigid_flat
=
torch
.
unbind
(
rigid_flat
,
dim
=-
1
)
if
(
self
.
full_quat
):
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
...
...
openfold/utils/geometry/rigid_matrix_vector.py
View file @
260db67f
...
...
@@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -29,100 +29,148 @@ Float = Union[float, torch.Tensor]
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
Rigid3Array
:
"""Rigid Transformation, i.e. element of special euclidean group."""
rotation
:
rotation_matrix
.
Rot3Array
translation
:
vector
.
Vec3Array
def
__matmul__
(
self
,
other
:
Rigid3Array
)
->
Rigid3Array
:
new_rotation
=
self
.
rotation
@
other
.
rotation
# __matmul__
new_translation
=
self
.
apply_to_point
(
other
.
translation
)
return
Rigid3Array
(
new_rotation
,
new_translation
)
def
__getitem__
(
self
,
index
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
[
index
],
self
.
translation
[
index
],
)
def
__mul__
(
self
,
other
:
torch
.
Tensor
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
*
other
,
self
.
translation
*
other
,
)
def
map_tensor_fn
(
self
,
fn
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
.
map_tensor_fn
(
fn
),
self
.
translation
.
map_tensor_fn
(
fn
),
)
def
inverse
(
self
)
->
Rigid3Array
:
"""Return Rigid3Array corresponding to inverse transform."""
inv_rotation
=
self
.
rotation
.
inverse
()
inv_translation
=
inv_rotation
.
apply_to_point
(
-
self
.
translation
)
return
Rigid3Array
(
inv_rotation
,
inv_translation
)
def
apply_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply Rigid3Array transform to point."""
return
self
.
rotation
.
apply_to_point
(
point
)
+
self
.
translation
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply inverse Rigid3Array transform to point."""
new_point
=
point
-
self
.
translation
return
self
.
rotation
.
apply_inverse_to_point
(
new_point
)
def
compose_rotation
(
self
,
other_rotation
):
rot
=
self
.
rotation
@
other_rotation
return
Rigid3Array
(
rot
,
trans
.
clone
())
@
classmethod
def
identity
(
cls
,
shape
,
device
)
->
Rigid3Array
:
"""Return identity Rigid3Array of given shape."""
return
cls
(
rotation_matrix
.
Rot3Array
.
identity
(
shape
,
device
),
vector
.
Vec3Array
.
zeros
(
shape
,
device
)
)
@
classmethod
def
cat
(
cls
,
rigids
:
List
[
Rigid3Array
],
dim
:
int
)
->
Rigid3Array
:
return
cls
(
Rot3Array
.
cat
([
r
.
rotation
for
r
in
rigids
],
dim
=
dim
),
Vec3Array
.
cat
([
r
.
translation
for
r
in
rigids
],
dim
=
dim
),
)
def
scale_translation
(
self
,
factor
:
Float
)
->
Rigid3Array
:
"""Scale translation in Rigid3Array by 'factor'."""
return
Rigid3Array
(
self
.
rotation
,
self
.
translation
*
factor
)
def
to_array
(
self
):
rot_array
=
self
.
rotation
.
to_array
()
vec_array
=
self
.
translation
.
to_array
()
return
torch
.
cat
([
rot_array
,
vec_array
[...,
None
]],
dim
=-
1
)
def
reshape
(
self
,
new_shape
)
->
Rigid3Array
:
rots
=
self
.
rotation
.
reshape
(
new_shape
)
trans
=
self
.
translation
.
reshape
(
new_shape
)
return
Rigid3Aray
(
rots
,
trans
)
@
classmethod
def
from_array
(
cls
,
array
):
rot
=
rotation_matrix
.
Rot3Array
.
from_array
(
array
[...,
:
3
])
vec
=
vector
.
Vec3Array
.
from_array
(
array
[...,
-
1
])
return
cls
(
rot
,
vec
)
@
classmethod
def
from_tensor_4x4
(
cls
,
array
):
return
cls
.
from_array
(
array
)
@
classmethod
def
from_array4x4
(
cls
,
array
:
torch
.
tensor
)
->
Rigid3Array
:
"""Construct Rigid3Array from homogeneous 4x4 array."""
rotation
=
rotation_matrix
.
Rot3Array
(
array
[...,
0
,
0
],
array
[...,
0
,
1
],
array
[...,
0
,
2
],
array
[...,
1
,
0
],
array
[...,
1
,
1
],
array
[...,
1
,
2
],
array
[...,
2
,
0
],
array
[...,
2
,
1
],
array
[...,
2
,
2
]
)
translation
=
vector
.
Vec3Array
(
array
[...,
0
,
3
],
array
[...,
1
,
3
],
array
[...,
2
,
3
])
return
cls
(
rotation
,
translation
)
"""Rigid Transformation, i.e. element of special euclidean group."""
rotation
:
rotation_matrix
.
Rot3Array
translation
:
vector
.
Vec3Array
def
__matmul__
(
self
,
other
:
Rigid3Array
)
->
Rigid3Array
:
new_rotation
=
self
.
rotation
@
other
.
rotation
# __matmul__
new_translation
=
self
.
apply_to_point
(
other
.
translation
)
return
Rigid3Array
(
new_rotation
,
new_translation
)
def
__getitem__
(
self
,
index
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
[
index
],
self
.
translation
[
index
],
)
def
__mul__
(
self
,
other
:
torch
.
Tensor
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
*
other
,
self
.
translation
*
other
,
)
def
map_tensor_fn
(
self
,
fn
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
.
map_tensor_fn
(
fn
),
self
.
translation
.
map_tensor_fn
(
fn
),
)
def
inverse
(
self
)
->
Rigid3Array
:
"""Return Rigid3Array corresponding to inverse transform."""
inv_rotation
=
self
.
rotation
.
inverse
()
inv_translation
=
inv_rotation
.
apply_to_point
(
-
self
.
translation
)
return
Rigid3Array
(
inv_rotation
,
inv_translation
)
def
apply_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply Rigid3Array transform to point."""
return
self
.
rotation
.
apply_to_point
(
point
)
+
self
.
translation
def
apply
(
self
,
point
:
torch
.
Tensor
)
->
vector
.
Vec3Array
:
return
self
.
apply_to_point
(
vector
.
Vec3Array
.
from_array
(
point
))
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply inverse Rigid3Array transform to point."""
new_point
=
point
-
self
.
translation
return
self
.
rotation
.
apply_inverse_to_point
(
new_point
)
def
compose_rotation
(
self
,
other_rotation
):
rot
=
self
.
rotation
@
other_rotation
return
Rigid3Array
(
rot
,
self
.
translation
.
clone
())
def
compose
(
self
,
other_rigid
):
return
self
@
other_rigid
def
unsqueeze
(
self
,
dim
:
int
):
return
Rigid3Array
(
self
.
rotation
.
unsqueeze
(
dim
),
self
.
translation
.
unsqueeze
(
dim
),
)
@
property
def
shape
(
self
)
->
torch
.
Size
:
return
self
.
rotation
.
xx
.
shape
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
return
self
.
rotation
.
xx
.
dtype
@
property
def
device
(
self
)
->
torch
.
device
:
return
self
.
rotation
.
xx
.
device
@
classmethod
def
identity
(
cls
,
shape
,
device
)
->
Rigid3Array
:
"""Return identity Rigid3Array of given shape."""
return
cls
(
rotation_matrix
.
Rot3Array
.
identity
(
shape
,
device
),
vector
.
Vec3Array
.
zeros
(
shape
,
device
)
)
@
classmethod
def
cat
(
cls
,
rigids
:
List
[
Rigid3Array
],
dim
:
int
)
->
Rigid3Array
:
return
cls
(
rotation_matrix
.
Rot3Array
.
cat
(
[
r
.
rotation
for
r
in
rigids
],
dim
=
dim
),
vector
.
Vec3Array
.
cat
(
[
r
.
translation
for
r
in
rigids
],
dim
=
dim
),
)
def
scale_translation
(
self
,
factor
:
Float
)
->
Rigid3Array
:
"""Scale translation in Rigid3Array by 'factor'."""
return
Rigid3Array
(
self
.
rotation
,
self
.
translation
*
factor
)
def
to_tensor
(
self
)
->
torch
.
Tensor
:
rot_array
=
self
.
rotation
.
to_tensor
()
vec_array
=
self
.
translation
.
to_tensor
()
array
=
torch
.
zeros
(
rot_array
.
shape
[:
-
2
]
+
(
4
,
4
),
device
=
rot_array
.
device
,
dtype
=
rot_array
.
dtype
)
array
[...,
:
3
,
:
3
]
=
rot_array
array
[...,
:
3
,
3
]
=
vec_array
array
[...,
3
,
3
]
=
1.
return
array
def
to_tensor_4x4
(
self
)
->
torch
.
Tensor
:
return
self
.
to_tensor
()
def
reshape
(
self
,
new_shape
)
->
Rigid3Array
:
rots
=
self
.
rotation
.
reshape
(
new_shape
)
trans
=
self
.
translation
.
reshape
(
new_shape
)
return
Rigid3Aray
(
rots
,
trans
)
def
stop_rot_gradient
(
self
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
.
stop_gradient
(),
self
.
translation
,
)
@
classmethod
def
from_array
(
cls
,
array
):
rot
=
rotation_matrix
.
Rot3Array
.
from_array
(
array
[...,
:
3
,
:
3
],
)
vec
=
vector
.
Vec3Array
.
from_array
(
array
[...,
:
3
,
3
])
return
cls
(
rot
,
vec
)
@
classmethod
def
from_tensor_4x4
(
cls
,
array
):
return
cls
.
from_array
(
array
)
@
classmethod
def
from_array4x4
(
cls
,
array
:
torch
.
tensor
)
->
Rigid3Array
:
"""Construct Rigid3Array from homogeneous 4x4 array."""
rotation
=
rotation_matrix
.
Rot3Array
(
array
[...,
0
,
0
],
array
[...,
0
,
1
],
array
[...,
0
,
2
],
array
[...,
1
,
0
],
array
[...,
1
,
1
],
array
[...,
1
,
2
],
array
[...,
2
,
0
],
array
[...,
2
,
1
],
array
[...,
2
,
2
]
)
translation
=
vector
.
Vec3Array
(
array
[...,
0
,
3
],
array
[...,
1
,
3
],
array
[...,
2
,
3
]
)
return
cls
(
rotation
,
translation
)
openfold/utils/geometry/rotation_matrix.py
View file @
260db67f
...
...
@@ -22,6 +22,7 @@ import numpy as np
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
utils
from
openfold.utils.geometry
import
vector
from
openfold.utils.tensor_utils
import
tensor_tree_map
COMPONENTS
=
[
'xx'
,
'xy'
,
'xz'
,
'yx'
,
'yy'
,
'yz'
,
'zx'
,
'zy'
,
'zz'
]
...
...
@@ -58,6 +59,13 @@ class Rot3Array:
for
name
in
field_names
}
)
def
__matmul__
(
self
,
other
:
Rot3Array
)
->
Rot3Array
:
"""Composes two Rot3Arrays."""
c0
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xx
,
other
.
yx
,
other
.
zx
))
c1
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xy
,
other
.
yy
,
other
.
zy
))
c2
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xz
,
other
.
yz
,
other
.
zz
))
return
Rot3Array
(
c0
.
x
,
c1
.
x
,
c2
.
x
,
c0
.
y
,
c1
.
y
,
c2
.
y
,
c0
.
z
,
c1
.
z
,
c2
.
z
)
def
map_tensor_fn
(
self
,
fn
)
->
Rot3Array
:
field_names
=
utils
.
get_field_names
(
Rot3Array
)
...
...
@@ -88,12 +96,19 @@ class Rot3Array:
"""Applies inverse Rot3Array to point."""
return
self
.
inverse
().
apply_to_point
(
point
)
def
__matmul__
(
self
,
other
:
Rot3Array
)
->
Rot3Array
:
"""Composes two Rot3Arrays."""
c0
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xx
,
other
.
yx
,
other
.
zx
))
c1
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xy
,
other
.
yy
,
other
.
zy
))
c2
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xz
,
other
.
yz
,
other
.
zz
))
return
Rot3Array
(
c0
.
x
,
c1
.
x
,
c2
.
x
,
c0
.
y
,
c1
.
y
,
c2
.
y
,
c0
.
z
,
c1
.
z
,
c2
.
z
)
def
unsqueeze
(
self
,
dim
:
int
):
return
Rot3Array
(
*
tensor_tree_map
(
lambda
t
:
t
.
unsqueeze
(
dim
),
[
getattr
(
self
,
c
)
for
c
in
COMPONENTS
]
)
)
def
stop_gradient
(
self
)
->
Rot3Array
:
return
Rot3Array
(
*
[
getattr
(
self
,
c
).
detach
()
for
c
in
COMPONENTS
]
)
@
classmethod
def
identity
(
cls
,
shape
,
device
)
->
Rot3Array
:
...
...
@@ -130,9 +145,11 @@ class Rot3Array:
@
classmethod
def
from_array
(
cls
,
array
:
torch
.
Tensor
)
->
Rot3Array
:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
return
cls
(
torch
.
unbind
(
array
,
dim
=-
2
))
rows
=
torch
.
unbind
(
array
,
dim
=-
2
)
rc
=
[
torch
.
unbind
(
e
,
dim
=-
1
)
for
e
in
rows
]
return
cls
(
*
[
e
for
row
in
rc
for
e
in
row
])
def
to_
array
(
self
)
->
torch
.
Tensor
:
def
to_
tensor
(
self
)
->
torch
.
Tensor
:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return
torch
.
stack
(
[
...
...
@@ -140,7 +157,8 @@ class Rot3Array:
torch
.
stack
([
self
.
yx
,
self
.
yy
,
self
.
yz
],
dim
=-
1
),
torch
.
stack
([
self
.
zx
,
self
.
zy
,
self
.
zz
],
dim
=-
1
)
],
dim
=-
2
)
dim
=-
2
)
@
classmethod
def
from_quaternion
(
cls
,
...
...
openfold/utils/geometry/vector.py
View file @
260db67f
...
...
@@ -134,13 +134,20 @@ class Vec3Array:
return
Vec3Array
(
x
,
y
,
z
)
def
sum
(
self
,
dim
)
->
Vec3Array
:
def
sum
(
self
,
dim
:
int
)
->
Vec3Array
:
return
Vec3Array
(
torch
.
sum
(
self
.
x
,
dim
=
dim
),
torch
.
sum
(
self
.
y
,
dim
=
dim
),
torch
.
sum
(
self
.
z
,
dim
=
dim
),
)
def
unsqueeze
(
self
,
dim
:
int
):
return
Vec3Array
(
self
.
x
.
unsqueeze
(
dim
),
self
.
y
.
unsqueeze
(
dim
),
self
.
z
.
unsqueeze
(
dim
),
)
@
classmethod
def
zeros
(
cls
,
shape
,
device
=
"cpu"
):
"""Return Vec3Array corresponding to zeros of given shape."""
...
...
@@ -150,11 +157,11 @@ class Vec3Array:
torch
.
zeros
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
)
)
def
to_
array
(
self
)
->
torch
.
Tensor
:
def
to_
tensor
(
self
)
->
torch
.
Tensor
:
return
torch
.
stack
([
self
.
x
,
self
.
y
,
self
.
z
],
dim
=-
1
)
@
classmethod
def
from_
tensor
(
cls
,
tensor
):
def
from_
array
(
cls
,
tensor
):
return
cls
(
*
torch
.
unbind
(
tensor
,
dim
=-
1
))
@
classmethod
...
...
openfold/utils/import_weights.py
View file @
260db67f
...
...
@@ -623,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
flat_keys
=
list
(
flat
.
keys
())
incorrect
=
[
k
for
k
in
flat_keys
if
k
not
in
keys
]
missing
=
[
k
for
k
in
keys
if
k
not
in
flat_keys
]
print
(
f
"Incorrect:
{
incorrect
}
"
)
print
(
f
"Missing:
{
missing
}
"
)
#
print(f"Incorrect: {incorrect}")
#
print(f"Missing: {missing}")
assert
len
(
incorrect
)
==
0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
...
...
run_pretrained_openfold.py
View file @
260db67f
...
...
@@ -217,7 +217,8 @@ def main(args):
unrelaxed_protein
=
protein
.
from_prediction
(
features
=
batch
,
result
=
out
,
b_factors
=
plddt_b_factors
b_factors
=
plddt_b_factors
,
remove_leading_feature_dimension
=
not
is_multimer
,
)
# Save the unrelaxed PDB.
...
...
@@ -227,6 +228,9 @@ def main(args):
with
open
(
unrelaxed_output_path
,
'w'
)
as
f
:
f
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
print
(
unrelaxed_output_path
)
print
(
"asdjfh klasjdhf lkasjdhf lkjasdhflkjasdh fkl jasdhfklj hasdkljf hasldkjfh lkasjdfh lkajsdhflk asd"
)
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
**
config
.
relax
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment