Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
260db67f
"lib/llm/vscode:/vscode.git/clone" did not exist on "42ce6931a788c31a9f1f99b51336ba15e189f617"
Commit
260db67f
authored
May 12, 2022
by
Gustaf Ahdritz
Browse files
Finish multimer inference
parent
6e68d6b0
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
205 additions
and
132 deletions
+205
-132
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+1
-0
openfold/model/embedders.py
openfold/model/embedders.py
+1
-1
openfold/model/structure_module.py
openfold/model/structure_module.py
+3
-6
openfold/np/protein.py
openfold/np/protein.py
+1
-1
openfold/utils/feats.py
openfold/utils/feats.py
+9
-9
openfold/utils/geometry/quat_rigid.py
openfold/utils/geometry/quat_rigid.py
+0
-2
openfold/utils/geometry/rigid_matrix_vector.py
openfold/utils/geometry/rigid_matrix_vector.py
+146
-98
openfold/utils/geometry/rotation_matrix.py
openfold/utils/geometry/rotation_matrix.py
+27
-9
openfold/utils/geometry/vector.py
openfold/utils/geometry/vector.py
+10
-3
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+2
-2
run_pretrained_openfold.py
run_pretrained_openfold.py
+5
-1
No files found.
openfold/data/input_pipeline_multimer.py
View file @
260db67f
...
...
@@ -29,6 +29,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms
.
cast_to_64bit_ints
,
data_transforms_multimer
.
make_msa_profile
,
data_transforms_multimer
.
create_target_feat
,
data_transforms
.
make_atom14_masks
,
]
if
(
common_cfg
.
use_templates
):
...
...
openfold/model/embedders.py
View file @
260db67f
...
...
@@ -868,7 +868,7 @@ class TemplateEmbedderMultimer(nn.Module):
raw_atom_pos
=
single_template_feats
[
"template_all_atom_positions"
]
atom_pos
=
geometry
.
Vec3Array
.
from_
tensor
(
raw_atom_pos
)
atom_pos
=
geometry
.
Vec3Array
.
from_
array
(
raw_atom_pos
)
rigid
,
backbone_mask
=
all_atom_multimer
.
make_backbone_affine
(
atom_pos
,
single_template_feats
[
"template_all_atom_mask"
],
...
...
openfold/model/structure_module.py
View file @
260db67f
...
...
@@ -363,9 +363,6 @@ class InvariantPointAttention(nn.Module):
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
for
c
in
q_pts
:
print
(
type
(
c
))
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
[...,
None
,
:,
:]
-
k_pts
[...,
None
,
:,
:,
:]
...
...
@@ -669,7 +666,7 @@ class StructureModule(nn.Module):
self
,
r
,
f
# [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
r
.
get_rots
().
dtype
,
r
.
get_rots
()
.
device
)
self
.
_init_residue_constants
(
r
.
dtype
,
r
.
device
)
return
frames_and_literature_positions_to_atom14_pos
(
r
,
f
,
...
...
@@ -818,11 +815,11 @@ class StructureModule(nn.Module):
)
preds
=
{
"frames"
:
rigids
.
scale_translation
(
self
.
trans_scale_factor
).
to_tensor
7
(),
"frames"
:
rigids
.
scale_translation
(
self
.
trans_scale_factor
).
to_tensor
(),
"sidechain_frames"
:
all_frames_to_global
.
to_tensor_4x4
(),
"unnormalized_angles"
:
unnormalized_angles
,
"angles"
:
angles
,
"positions"
:
pred_xyz
,
"positions"
:
pred_xyz
.
to_tensor
()
,
}
outputs
.
append
(
preds
)
...
...
openfold/np/protein.py
View file @
260db67f
...
...
@@ -205,7 +205,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
def
_chain_end
(
atom_index
,
end_resname
,
chain_name
,
residue_indx
)
->
str
:
def
_chain_end
(
atom_index
,
end_resname
,
chain_name
,
residue_ind
e
x
)
->
str
:
chain_end
=
'TER'
return
(
f
'
{
chain_end
:
<
6
}{
atom_index
:
>
5
}
{
end_resname
:
>
3
}
'
...
...
openfold/utils/feats.py
View file @
260db67f
...
...
@@ -22,6 +22,7 @@ from typing import Dict
from
openfold.np
import
protein
import
openfold.np.residue_constants
as
rc
from
openfold.utils.geometry
import
rigid_matrix_vector
,
rotation_matrix
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.tensor_utils
import
(
batched_gather
,
...
...
@@ -213,15 +214,14 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots
=
alpha
.
new_zeros
(
default_r
.
get_rots
().
get_rot_mats
().
shape
)
all_rots
=
alpha
.
new_zeros
(
default_r
.
shape
+
(
3
,
3
)
)
all_rots
[...,
0
,
0
]
=
1
all_rots
[...,
1
,
1
]
=
alpha
[...,
1
]
all_rots
[...,
1
,
2
]
=
-
alpha
[...,
0
]
all_rots
[...,
2
,
1
:]
=
alpha
all_rots
=
Rigid
(
Rotation
(
rot_mats
=
all_rots
),
None
)
all_frames
=
default_r
.
compose
(
all_rots
)
all_rots
=
rotation_matrix
.
Rot3Array
.
from_array
(
all_rots
)
all_frames
=
default_r
.
compose_rotation
(
all_rots
)
chi2_frame_to_frame
=
all_frames
[...,
5
]
chi3_frame_to_frame
=
all_frames
[...,
6
]
...
...
@@ -232,7 +232,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb
=
chi2_frame_to_bb
.
compose
(
chi3_frame_to_frame
)
chi4_frame_to_bb
=
chi3_frame_to_bb
.
compose
(
chi4_frame_to_frame
)
all_frames_to_bb
=
Rigid
.
cat
(
all_frames_to_bb
=
rigid_matrix_vector
.
Rigid3Array
.
cat
(
[
all_frames
[...,
:
5
],
chi2_frame_to_bb
.
unsqueeze
(
-
1
),
...
...
@@ -248,7 +248,7 @@ def torsion_angles_to_frames(
def
frames_and_literature_positions_to_atom14_pos
(
r
:
Rigid
,
r
:
rigid_matrix_vector
.
Rigid3Array
,
aatype
:
torch
.
Tensor
,
default_frames
,
group_idx
,
...
...
@@ -275,8 +275,8 @@ def frames_and_literature_positions_to_atom14_pos(
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
)
# [*, N, 14
, 1
]
atom_mask
=
atom_mask
[
aatype
,
...]
.
unsqueeze
(
-
1
)
# [*, N, 14]
atom_mask
=
atom_mask
[
aatype
,
...]
# [*, N, 14, 3]
lit_positions
=
lit_positions
[
aatype
,
...]
...
...
openfold/utils/geometry/quat_rigid.py
View file @
260db67f
...
...
@@ -22,8 +22,6 @@ class QuatRigid(nn.Module):
# NOTE: During training, this needs to be run in higher precision
rigid_flat
=
self
.
linear
(
activations
.
to
(
torch
.
float32
))
print
(
rigid_flat
.
shape
)
rigid_flat
=
torch
.
unbind
(
rigid_flat
,
dim
=-
1
)
if
(
self
.
full_quat
):
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
...
...
openfold/utils/geometry/rigid_matrix_vector.py
View file @
260db67f
...
...
@@ -67,6 +67,9 @@ class Rigid3Array:
"""Apply Rigid3Array transform to point."""
return
self
.
rotation
.
apply_to_point
(
point
)
+
self
.
translation
def
apply
(
self
,
point
:
torch
.
Tensor
)
->
vector
.
Vec3Array
:
return
self
.
apply_to_point
(
vector
.
Vec3Array
.
from_array
(
point
))
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply inverse Rigid3Array transform to point."""
new_point
=
point
-
self
.
translation
...
...
@@ -74,7 +77,28 @@ class Rigid3Array:
def
compose_rotation
(
self
,
other_rotation
):
rot
=
self
.
rotation
@
other_rotation
return
Rigid3Array
(
rot
,
trans
.
clone
())
return
Rigid3Array
(
rot
,
self
.
translation
.
clone
())
def
compose
(
self
,
other_rigid
):
return
self
@
other_rigid
def
unsqueeze
(
self
,
dim
:
int
):
return
Rigid3Array
(
self
.
rotation
.
unsqueeze
(
dim
),
self
.
translation
.
unsqueeze
(
dim
),
)
@
property
def
shape
(
self
)
->
torch
.
Size
:
return
self
.
rotation
.
xx
.
shape
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
return
self
.
rotation
.
xx
.
dtype
@
property
def
device
(
self
)
->
torch
.
device
:
return
self
.
rotation
.
xx
.
device
@
classmethod
def
identity
(
cls
,
shape
,
device
)
->
Rigid3Array
:
...
...
@@ -87,28 +111,51 @@ class Rigid3Array:
@
classmethod
def
cat
(
cls
,
rigids
:
List
[
Rigid3Array
],
dim
:
int
)
->
Rigid3Array
:
return
cls
(
Rot3Array
.
cat
([
r
.
rotation
for
r
in
rigids
],
dim
=
dim
),
Vec3Array
.
cat
([
r
.
translation
for
r
in
rigids
],
dim
=
dim
),
rotation_matrix
.
Rot3Array
.
cat
(
[
r
.
rotation
for
r
in
rigids
],
dim
=
dim
),
vector
.
Vec3Array
.
cat
(
[
r
.
translation
for
r
in
rigids
],
dim
=
dim
),
)
def
scale_translation
(
self
,
factor
:
Float
)
->
Rigid3Array
:
"""Scale translation in Rigid3Array by 'factor'."""
return
Rigid3Array
(
self
.
rotation
,
self
.
translation
*
factor
)
def
to_array
(
self
):
rot_array
=
self
.
rotation
.
to_array
()
vec_array
=
self
.
translation
.
to_array
()
return
torch
.
cat
([
rot_array
,
vec_array
[...,
None
]],
dim
=-
1
)
def
to_tensor
(
self
)
->
torch
.
Tensor
:
rot_array
=
self
.
rotation
.
to_tensor
()
vec_array
=
self
.
translation
.
to_tensor
()
array
=
torch
.
zeros
(
rot_array
.
shape
[:
-
2
]
+
(
4
,
4
),
device
=
rot_array
.
device
,
dtype
=
rot_array
.
dtype
)
array
[...,
:
3
,
:
3
]
=
rot_array
array
[...,
:
3
,
3
]
=
vec_array
array
[...,
3
,
3
]
=
1.
return
array
def
to_tensor_4x4
(
self
)
->
torch
.
Tensor
:
return
self
.
to_tensor
()
def
reshape
(
self
,
new_shape
)
->
Rigid3Array
:
rots
=
self
.
rotation
.
reshape
(
new_shape
)
trans
=
self
.
translation
.
reshape
(
new_shape
)
return
Rigid3Aray
(
rots
,
trans
)
def
stop_rot_gradient
(
self
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
.
stop_gradient
(),
self
.
translation
,
)
@
classmethod
def
from_array
(
cls
,
array
):
rot
=
rotation_matrix
.
Rot3Array
.
from_array
(
array
[...,
:
3
])
vec
=
vector
.
Vec3Array
.
from_array
(
array
[...,
-
1
])
rot
=
rotation_matrix
.
Rot3Array
.
from_array
(
array
[...,
:
3
,
:
3
],
)
vec
=
vector
.
Vec3Array
.
from_array
(
array
[...,
:
3
,
3
])
return
cls
(
rot
,
vec
)
@
classmethod
...
...
@@ -124,5 +171,6 @@ class Rigid3Array:
array
[...,
2
,
0
],
array
[...,
2
,
1
],
array
[...,
2
,
2
]
)
translation
=
vector
.
Vec3Array
(
array
[...,
0
,
3
],
array
[...,
1
,
3
],
array
[...,
2
,
3
])
array
[...,
0
,
3
],
array
[...,
1
,
3
],
array
[...,
2
,
3
]
)
return
cls
(
rotation
,
translation
)
openfold/utils/geometry/rotation_matrix.py
View file @
260db67f
...
...
@@ -22,6 +22,7 @@ import numpy as np
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
utils
from
openfold.utils.geometry
import
vector
from
openfold.utils.tensor_utils
import
tensor_tree_map
COMPONENTS
=
[
'xx'
,
'xy'
,
'xz'
,
'yx'
,
'yy'
,
'yz'
,
'zx'
,
'zy'
,
'zz'
]
...
...
@@ -59,6 +60,13 @@ class Rot3Array:
}
)
def
__matmul__
(
self
,
other
:
Rot3Array
)
->
Rot3Array
:
"""Composes two Rot3Arrays."""
c0
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xx
,
other
.
yx
,
other
.
zx
))
c1
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xy
,
other
.
yy
,
other
.
zy
))
c2
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xz
,
other
.
yz
,
other
.
zz
))
return
Rot3Array
(
c0
.
x
,
c1
.
x
,
c2
.
x
,
c0
.
y
,
c1
.
y
,
c2
.
y
,
c0
.
z
,
c1
.
z
,
c2
.
z
)
def
map_tensor_fn
(
self
,
fn
)
->
Rot3Array
:
field_names
=
utils
.
get_field_names
(
Rot3Array
)
return
Rot3Array
(
...
...
@@ -88,12 +96,19 @@ class Rot3Array:
"""Applies inverse Rot3Array to point."""
return
self
.
inverse
().
apply_to_point
(
point
)
def
__matmul__
(
self
,
other
:
Rot3Array
)
->
Rot3Array
:
"""Composes two Rot3Arrays."""
c0
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xx
,
other
.
yx
,
other
.
zx
))
c1
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xy
,
other
.
yy
,
other
.
zy
))
c2
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xz
,
other
.
yz
,
other
.
zz
))
return
Rot3Array
(
c0
.
x
,
c1
.
x
,
c2
.
x
,
c0
.
y
,
c1
.
y
,
c2
.
y
,
c0
.
z
,
c1
.
z
,
c2
.
z
)
def
unsqueeze
(
self
,
dim
:
int
):
return
Rot3Array
(
*
tensor_tree_map
(
lambda
t
:
t
.
unsqueeze
(
dim
),
[
getattr
(
self
,
c
)
for
c
in
COMPONENTS
]
)
)
def
stop_gradient
(
self
)
->
Rot3Array
:
return
Rot3Array
(
*
[
getattr
(
self
,
c
).
detach
()
for
c
in
COMPONENTS
]
)
@
classmethod
def
identity
(
cls
,
shape
,
device
)
->
Rot3Array
:
...
...
@@ -130,9 +145,11 @@ class Rot3Array:
@
classmethod
def
from_array
(
cls
,
array
:
torch
.
Tensor
)
->
Rot3Array
:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
return
cls
(
torch
.
unbind
(
array
,
dim
=-
2
))
rows
=
torch
.
unbind
(
array
,
dim
=-
2
)
rc
=
[
torch
.
unbind
(
e
,
dim
=-
1
)
for
e
in
rows
]
return
cls
(
*
[
e
for
row
in
rc
for
e
in
row
])
def
to_
array
(
self
)
->
torch
.
Tensor
:
def
to_
tensor
(
self
)
->
torch
.
Tensor
:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return
torch
.
stack
(
[
...
...
@@ -140,7 +157,8 @@ class Rot3Array:
torch
.
stack
([
self
.
yx
,
self
.
yy
,
self
.
yz
],
dim
=-
1
),
torch
.
stack
([
self
.
zx
,
self
.
zy
,
self
.
zz
],
dim
=-
1
)
],
dim
=-
2
)
dim
=-
2
)
@
classmethod
def
from_quaternion
(
cls
,
...
...
openfold/utils/geometry/vector.py
View file @
260db67f
...
...
@@ -134,13 +134,20 @@ class Vec3Array:
return
Vec3Array
(
x
,
y
,
z
)
def
sum
(
self
,
dim
)
->
Vec3Array
:
def
sum
(
self
,
dim
:
int
)
->
Vec3Array
:
return
Vec3Array
(
torch
.
sum
(
self
.
x
,
dim
=
dim
),
torch
.
sum
(
self
.
y
,
dim
=
dim
),
torch
.
sum
(
self
.
z
,
dim
=
dim
),
)
def
unsqueeze
(
self
,
dim
:
int
):
return
Vec3Array
(
self
.
x
.
unsqueeze
(
dim
),
self
.
y
.
unsqueeze
(
dim
),
self
.
z
.
unsqueeze
(
dim
),
)
@
classmethod
def
zeros
(
cls
,
shape
,
device
=
"cpu"
):
"""Return Vec3Array corresponding to zeros of given shape."""
...
...
@@ -150,11 +157,11 @@ class Vec3Array:
torch
.
zeros
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
)
)
def
to_
array
(
self
)
->
torch
.
Tensor
:
def
to_
tensor
(
self
)
->
torch
.
Tensor
:
return
torch
.
stack
([
self
.
x
,
self
.
y
,
self
.
z
],
dim
=-
1
)
@
classmethod
def
from_
tensor
(
cls
,
tensor
):
def
from_
array
(
cls
,
tensor
):
return
cls
(
*
torch
.
unbind
(
tensor
,
dim
=-
1
))
@
classmethod
...
...
openfold/utils/import_weights.py
View file @
260db67f
...
...
@@ -623,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
flat_keys
=
list
(
flat
.
keys
())
incorrect
=
[
k
for
k
in
flat_keys
if
k
not
in
keys
]
missing
=
[
k
for
k
in
keys
if
k
not
in
flat_keys
]
print
(
f
"Incorrect:
{
incorrect
}
"
)
print
(
f
"Missing:
{
missing
}
"
)
#
print(f"Incorrect: {incorrect}")
#
print(f"Missing: {missing}")
assert
len
(
incorrect
)
==
0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
...
...
run_pretrained_openfold.py
View file @
260db67f
...
...
@@ -217,7 +217,8 @@ def main(args):
unrelaxed_protein
=
protein
.
from_prediction
(
features
=
batch
,
result
=
out
,
b_factors
=
plddt_b_factors
b_factors
=
plddt_b_factors
,
remove_leading_feature_dimension
=
not
is_multimer
,
)
# Save the unrelaxed PDB.
...
...
@@ -227,6 +228,9 @@ def main(args):
with
open
(
unrelaxed_output_path
,
'w'
)
as
f
:
f
.
write
(
protein
.
to_pdb
(
unrelaxed_protein
))
print
(
unrelaxed_output_path
)
print
(
"asdjfh klasjdhf lkasjdhf lkjasdhflkjasdh fkl jasdhfklj hasdkljf hasldkjfh lkasjdfh lkajsdhflk asd"
)
amber_relaxer
=
relax
.
AmberRelaxation
(
use_gpu
=
(
args
.
model_device
!=
"cpu"
),
**
config
.
relax
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment