OpenDAS / FastFold · Commits

Commit b14e47f4, authored Apr 26, 2023 by zhuwenwen
Merge branch 'main' of https://github.com/hpcaitech/FastFold
Parents: 490cb6f5, 05681304
Pipeline #234: failed (0 seconds) · Changes: 188 · Pipelines: 1

Showing 20 changed files with 4125 additions and 0 deletions (+4125, -0)
fastfold/utils/geometry/__init__.py              +29    -0
fastfold/utils/geometry/quat_rigid.py            +38    -0
fastfold/utils/geometry/rigid_matrix_vector.py   +175   -0
fastfold/utils/geometry/rotation_matrix.py       +208   -0
fastfold/utils/geometry/test_utils.py            +97    -0
fastfold/utils/geometry/utils.py                 +22    -0
fastfold/utils/geometry/vector.py                +263   -0
fastfold/utils/import_weights.py                 +627   -0
fastfold/utils/inject_fastnn.py                  +423   -0
fastfold/utils/rigid_utils.py                    +1416  -0
fastfold/utils/superimposition.py                +100   -0
fastfold/utils/tensor_utils.py                   +415   -0
fastfold/utils/test_utils.py                     +33    -0
fastfold/utils/validation_utils.py               +127   -0
fastfold/workflow/__init__.py                    +1     -0
fastfold/workflow/factory/__init__.py            +6     -0
fastfold/workflow/factory/hhblits.py             +29    -0
fastfold/workflow/factory/hhfilter.py            +33    -0
fastfold/workflow/factory/hhsearch.py            +42    -0
fastfold/workflow/factory/hmmsearch.py           +41    -0

fastfold/utils/geometry/__init__.py (new file, mode 100644)

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from fastfold.utils.geometry import rigid_matrix_vector
from fastfold.utils.geometry import rotation_matrix
from fastfold.utils.geometry import vector

Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross

fastfold/utils/geometry/quat_rigid.py (new file, mode 100644)

import torch
import torch.nn as nn

from fastfold.model.nn.primitives import Linear
from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from fastfold.utils.geometry.rotation_matrix import Rot3Array
from fastfold.utils.geometry.vector import Vec3Array


class QuatRigid(nn.Module):
    def __init__(self, c_hidden, full_quat):
        super().__init__()
        self.full_quat = full_quat
        if self.full_quat:
            rigid_dim = 7
        else:
            rigid_dim = 6

        self.linear = Linear(c_hidden, rigid_dim)

    def forward(self, activations: torch.Tensor) -> Rigid3Array:
        # NOTE: During training, this needs to be run in higher precision
        rigid_flat = self.linear(activations.to(torch.float32))

        rigid_flat = torch.unbind(rigid_flat, dim=-1)
        if self.full_quat:
            qw, qx, qy, qz = rigid_flat[:4]
            translation = rigid_flat[4:]
        else:
            qx, qy, qz = rigid_flat[:3]
            qw = torch.ones_like(qx)
            translation = rigid_flat[3:]

        rotation = Rot3Array.from_quaternion(
            qw, qx, qy, qz, normalize=True,
        )
        translation = Vec3Array(*translation)

        return Rigid3Array(rotation, translation)
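
For orientation, `QuatRigid` maps a hidden representation to a backbone frame update. A small sketch (assuming `fastfold` and its `Linear` primitive are importable; the sizes are arbitrary):

import torch
from fastfold.utils.geometry.quat_rigid import QuatRigid

layer = QuatRigid(c_hidden=16, full_quat=False)   # 6 outputs: 3 quaternion + 3 translation
activations = torch.randn(2, 16)
rigid = layer(activations)                        # a Rigid3Array with batch shape (2,)
print(rigid.shape)                                # torch.Size([2])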

fastfold/utils/geometry/rigid_matrix_vector.py (new file, mode 100644)

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations
import dataclasses
from typing import Union, List

import torch

from fastfold.utils.geometry import rotation_matrix
from fastfold.utils.geometry import vector


Float = Union[float, torch.Tensor]


@dataclasses.dataclass(frozen=True)
class Rigid3Array:
    """Rigid Transformation, i.e. element of special euclidean group."""

    rotation: rotation_matrix.Rot3Array
    translation: vector.Vec3Array

    def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
        new_rotation = self.rotation @ other.rotation  # __matmul__
        new_translation = self.apply_to_point(other.translation)
        return Rigid3Array(new_rotation, new_translation)

    def __getitem__(self, index) -> Rigid3Array:
        return Rigid3Array(
            self.rotation[index],
            self.translation[index],
        )

    def __mul__(self, other: torch.Tensor) -> Rigid3Array:
        return Rigid3Array(
            self.rotation * other,
            self.translation * other,
        )

    def map_tensor_fn(self, fn) -> Rigid3Array:
        return Rigid3Array(
            self.rotation.map_tensor_fn(fn),
            self.translation.map_tensor_fn(fn),
        )

    def inverse(self) -> Rigid3Array:
        """Return Rigid3Array corresponding to inverse transform."""
        inv_rotation = self.rotation.inverse()
        inv_translation = inv_rotation.apply_to_point(-self.translation)
        return Rigid3Array(inv_rotation, inv_translation)

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply Rigid3Array transform to point."""
        return self.rotation.apply_to_point(point) + self.translation

    def apply(self, point: torch.Tensor) -> vector.Vec3Array:
        return self.apply_to_point(vector.Vec3Array.from_array(point))

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply inverse Rigid3Array transform to point."""
        new_point = point - self.translation
        return self.rotation.apply_inverse_to_point(new_point)

    def compose_rotation(self, other_rotation):
        rot = self.rotation @ other_rotation
        return Rigid3Array(rot, self.translation.clone())

    def compose(self, other_rigid):
        return self @ other_rigid

    def unsqueeze(self, dim: int):
        return Rigid3Array(
            self.rotation.unsqueeze(dim),
            self.translation.unsqueeze(dim),
        )

    @property
    def shape(self) -> torch.Size:
        return self.rotation.xx.shape

    @property
    def dtype(self) -> torch.dtype:
        return self.rotation.xx.dtype

    @property
    def device(self) -> torch.device:
        return self.rotation.xx.device

    @classmethod
    def identity(cls, shape, device) -> Rigid3Array:
        """Return identity Rigid3Array of given shape."""
        return cls(
            rotation_matrix.Rot3Array.identity(shape, device),
            vector.Vec3Array.zeros(shape, device)
        )

    @classmethod
    def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
        return cls(
            rotation_matrix.Rot3Array.cat([r.rotation for r in rigids], dim=dim),
            vector.Vec3Array.cat([r.translation for r in rigids], dim=dim),
        )

    def scale_translation(self, factor: Float) -> Rigid3Array:
        """Scale translation in Rigid3Array by 'factor'."""
        return Rigid3Array(self.rotation, self.translation * factor)

    def to_tensor(self) -> torch.Tensor:
        rot_array = self.rotation.to_tensor()
        vec_array = self.translation.to_tensor()
        array = torch.zeros(
            rot_array.shape[:-2] + (4, 4),
            device=rot_array.device,
            dtype=rot_array.dtype,
        )
        array[..., :3, :3] = rot_array
        array[..., :3, 3] = vec_array
        array[..., 3, 3] = 1.
        return array

    def to_tensor_4x4(self) -> torch.Tensor:
        return self.to_tensor()

    def reshape(self, new_shape) -> Rigid3Array:
        rots = self.rotation.reshape(new_shape)
        trans = self.translation.reshape(new_shape)
        return Rigid3Array(rots, trans)

    def stop_rot_gradient(self) -> Rigid3Array:
        return Rigid3Array(
            self.rotation.stop_gradient(),
            self.translation,
        )

    @classmethod
    def from_array(cls, array):
        rot = rotation_matrix.Rot3Array.from_array(
            array[..., :3, :3],
        )
        vec = vector.Vec3Array.from_array(array[..., :3, 3])
        return cls(rot, vec)

    @classmethod
    def from_tensor_4x4(cls, array):
        return cls.from_array(array)

    @classmethod
    def from_array4x4(cls, array: torch.tensor) -> Rigid3Array:
        """Construct Rigid3Array from homogeneous 4x4 array."""
        rotation = rotation_matrix.Rot3Array(
            array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
            array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
            array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
        )
        translation = vector.Vec3Array(
            array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
        )
        return cls(rotation, translation)
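
A quick sanity check of the `Rigid3Array` API above (a usage sketch, not part of the diff; only methods defined in this file and its dependencies are used):

import torch
from fastfold.utils.geometry import rigid_matrix_vector, vector

rigid = rigid_matrix_vector.Rigid3Array.identity((8,), device="cpu")
points = vector.Vec3Array.from_array(torch.randn(8, 3))
moved = rigid.apply_to_point(points)            # identity transform leaves points unchanged
recovered = rigid.apply_inverse_to_point(moved)
print(rigid.to_tensor().shape)                  # torch.Size([8, 4, 4])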

fastfold/utils/geometry/rotation_matrix.py (new file, mode 100644)

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rot3Array Matrix Class."""
from __future__ import annotations
import dataclasses

import torch
import numpy as np

from fastfold.utils.geometry import utils
from fastfold.utils.geometry import vector
from fastfold.utils.tensor_utils import tensor_tree_map


COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']


@dataclasses.dataclass(frozen=True)
class Rot3Array:
    """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""

    xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
    xy: torch.Tensor
    xz: torch.Tensor
    yx: torch.Tensor
    yy: torch.Tensor
    yz: torch.Tensor
    zx: torch.Tensor
    zy: torch.Tensor
    zz: torch.Tensor

    __array_ufunc__ = None

    def __getitem__(self, index):
        field_names = utils.get_field_names(Rot3Array)
        return Rot3Array(
            **{name: getattr(self, name)[index] for name in field_names}
        )

    def __mul__(self, other: torch.Tensor):
        field_names = utils.get_field_names(Rot3Array)
        return Rot3Array(
            **{name: getattr(self, name) * other for name in field_names}
        )

    def __matmul__(self, other: Rot3Array) -> Rot3Array:
        """Composes two Rot3Arrays."""
        c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
        c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
        c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
        return Rot3Array(
            c0.x, c1.x, c2.x,
            c0.y, c1.y, c2.y,
            c0.z, c1.z, c2.z
        )

    def map_tensor_fn(self, fn) -> Rot3Array:
        field_names = utils.get_field_names(Rot3Array)
        return Rot3Array(
            **{name: fn(getattr(self, name)) for name in field_names}
        )

    def inverse(self) -> Rot3Array:
        """Returns inverse of Rot3Array."""
        return Rot3Array(
            self.xx, self.yx, self.zx,
            self.xy, self.yy, self.zy,
            self.xz, self.yz, self.zz
        )

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Applies Rot3Array to point."""
        return vector.Vec3Array(
            self.xx * point.x + self.xy * point.y + self.xz * point.z,
            self.yx * point.x + self.yy * point.y + self.yz * point.z,
            self.zx * point.x + self.zy * point.y + self.zz * point.z
        )

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Applies inverse Rot3Array to point."""
        return self.inverse().apply_to_point(point)

    def unsqueeze(self, dim: int):
        return Rot3Array(
            *tensor_tree_map(
                lambda t: t.unsqueeze(dim),
                [getattr(self, c) for c in COMPONENTS]
            )
        )

    def stop_gradient(self) -> Rot3Array:
        return Rot3Array(
            *[getattr(self, c).detach() for c in COMPONENTS]
        )

    @classmethod
    def identity(cls, shape, device) -> Rot3Array:
        """Returns identity of given shape."""
        ones = torch.ones(shape, dtype=torch.float32, device=device)
        zeros = torch.zeros(shape, dtype=torch.float32, device=device)
        return cls(
            ones, zeros, zeros,
            zeros, ones, zeros,
            zeros, zeros, ones
        )

    @classmethod
    def from_two_vectors(cls, e0: vector.Vec3Array, e1: vector.Vec3Array) -> Rot3Array:
        """Construct Rot3Array from two Vectors.

        Rot3Array is constructed such that in the corresponding frame 'e0' lies on
        the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.

        Args:
            e0: Vector
            e1: Vector
        Returns:
            Rot3Array
        """
        # Normalize the unit vector for the x-axis, e0.
        e0 = e0.normalized()
        # Make e1 perpendicular to e0.
        c = e1.dot(e0)
        e1 = (e1 - c * e0).normalized()
        # Compute e2 as cross product of e0 and e1.
        e2 = e0.cross(e1)
        return cls(
            e0.x, e1.x, e2.x,
            e0.y, e1.y, e2.y,
            e0.z, e1.z, e2.z
        )

    @classmethod
    def from_array(cls, array: torch.Tensor) -> Rot3Array:
        """Construct Rot3Array Matrix from array of shape [..., 3, 3]."""
        rows = torch.unbind(array, dim=-2)
        rc = [torch.unbind(e, dim=-1) for e in rows]
        return cls(*[e for row in rc for e in row])

    def to_tensor(self) -> torch.Tensor:
        """Convert Rot3Array to array of shape [..., 3, 3]."""
        return torch.stack(
            [
                torch.stack([self.xx, self.xy, self.xz], dim=-1),
                torch.stack([self.yx, self.yy, self.yz], dim=-1),
                torch.stack([self.zx, self.zy, self.zz], dim=-1)
            ],
            dim=-2
        )

    @classmethod
    def from_quaternion(
        cls,
        w: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        normalize: bool = True,
        eps: float = 1e-6
    ) -> Rot3Array:
        """Construct Rot3Array from components of quaternion."""
        if normalize:
            inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2)
            w *= inv_norm
            x *= inv_norm
            y *= inv_norm
            z *= inv_norm
        xx = 1 - 2 * (y**2 + z**2)
        xy = 2 * (x * y - w * z)
        xz = 2 * (x * z + w * y)
        yx = 2 * (x * y + w * z)
        yy = 1 - 2 * (x**2 + z**2)
        yz = 2 * (y * z - w * x)
        zx = 2 * (x * z - w * y)
        zy = 2 * (y * z + w * x)
        zz = 1 - 2 * (x**2 + y**2)
        return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)

    def reshape(self, new_shape):
        field_names = utils.get_field_names(Rot3Array)
        reshape_fn = lambda t: t.reshape(new_shape)
        return Rot3Array(
            **{name: reshape_fn(getattr(self, name)) for name in field_names}
        )

    @classmethod
    def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
        field_names = utils.get_field_names(Rot3Array)
        cat_fn = lambda l: torch.cat(l, dim=dim)
        return cls(
            **{name: cat_fn([getattr(r, name) for r in rots]) for name in field_names}
        )
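
A short sketch of `from_quaternion` and composition with the inverse (illustrative values only; assumes `fastfold` is importable):

import torch
from fastfold.utils.geometry import rotation_matrix

w, zero = torch.ones(5), torch.zeros(5)
rot = rotation_matrix.Rot3Array.from_quaternion(w, zero, zero, zero, normalize=True)
composed = rot @ rot.inverse()                  # composing with the inverse gives ~identity
print(torch.allclose(composed.xx, torch.ones(5), atol=1e-5))  # True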

fastfold/utils/geometry/test_utils.py (new file, mode 100644)

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utils for tests."""
import dataclasses

from fastfold.utils.geometry import rigid_matrix_vector
from fastfold.utils.geometry import rotation_matrix
from fastfold.utils.geometry import vector
import numpy as np


def assert_rotation_matrix_equal(
    matrix1: rotation_matrix.Rot3Array, matrix2: rotation_matrix.Rot3Array
):
    for field in dataclasses.fields(rotation_matrix.Rot3Array):
        field = field.name
        np.testing.assert_array_equal(
            getattr(matrix1, field), getattr(matrix2, field))


def assert_rotation_matrix_close(
    mat1: rotation_matrix.Rot3Array, mat2: rotation_matrix.Rot3Array
):
    np.testing.assert_array_almost_equal(mat1.to_array(), mat2.to_array(), 6)


def assert_array_equal_to_rotation_matrix(
    array: np.ndarray, matrix: rotation_matrix.Rot3Array
):
    """Check that array and Matrix match."""
    np.testing.assert_array_equal(matrix.xx, array[..., 0, 0])
    np.testing.assert_array_equal(matrix.xy, array[..., 0, 1])
    np.testing.assert_array_equal(matrix.xz, array[..., 0, 2])
    np.testing.assert_array_equal(matrix.yx, array[..., 1, 0])
    np.testing.assert_array_equal(matrix.yy, array[..., 1, 1])
    np.testing.assert_array_equal(matrix.yz, array[..., 1, 2])
    np.testing.assert_array_equal(matrix.zx, array[..., 2, 0])
    np.testing.assert_array_equal(matrix.zy, array[..., 2, 1])
    np.testing.assert_array_equal(matrix.zz, array[..., 2, 2])


def assert_array_close_to_rotation_matrix(
    array: np.ndarray, matrix: rotation_matrix.Rot3Array
):
    np.testing.assert_array_almost_equal(matrix.to_array(), array, 6)


def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
    np.testing.assert_array_equal(vec1.x, vec2.x)
    np.testing.assert_array_equal(vec1.y, vec2.y)
    np.testing.assert_array_equal(vec1.z, vec2.z)


def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
    np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.)
    np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.)
    np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.)


def assert_array_close_to_vector(array: np.ndarray, vec: vector.Vec3Array):
    np.testing.assert_allclose(vec.to_array(), array, atol=1e-6, rtol=0.)


def assert_array_equal_to_vector(array: np.ndarray, vec: vector.Vec3Array):
    np.testing.assert_array_equal(vec.to_array(), array)


def assert_rigid_equal_to_rigid(
    rigid1: rigid_matrix_vector.Rigid3Array,
    rigid2: rigid_matrix_vector.Rigid3Array,
):
    assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)


def assert_rigid_close_to_rigid(
    rigid1: rigid_matrix_vector.Rigid3Array,
    rigid2: rigid_matrix_vector.Rigid3Array,
):
    assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)


def assert_rot_trans_equal_to_rigid(
    rot: rotation_matrix.Rot3Array,
    trans: vector.Vec3Array,
    rigid: rigid_matrix_vector.Rigid3Array,
):
    assert_rotation_matrix_equal(rot, rigid.rotation)
    assert_vectors_equal(trans, rigid.translation)


def assert_rot_trans_close_to_rigid(
    rot: rotation_matrix.Rot3Array,
    trans: vector.Vec3Array,
    rigid: rigid_matrix_vector.Rigid3Array,
):
    assert_rotation_matrix_close(rot, rigid.rotation)
    assert_vectors_close(trans, rigid.translation)

fastfold/utils/geometry/utils.py (new file, mode 100644)

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for geometry library."""
import dataclasses


def get_field_names(cls):
    fields = dataclasses.fields(cls)
    field_names = [f.name for f in fields]
    return field_names

fastfold/utils/geometry/vector.py (new file, mode 100644)

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vec3Array Class."""
from __future__ import annotations
import dataclasses
from typing import Union, List

import torch

from fastfold.utils.geometry import utils

Float = Union[float, torch.Tensor]


@dataclasses.dataclass(frozen=True)
class Vec3Array:
    x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
    y: torch.Tensor
    z: torch.Tensor

    def __post_init__(self):
        if hasattr(self.x, 'dtype'):
            assert self.x.dtype == self.y.dtype
            assert self.x.dtype == self.z.dtype
            assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
            assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])

    def __add__(self, other: Vec3Array) -> Vec3Array:
        return Vec3Array(
            self.x + other.x,
            self.y + other.y,
            self.z + other.z,
        )

    def __sub__(self, other: Vec3Array) -> Vec3Array:
        return Vec3Array(
            self.x - other.x,
            self.y - other.y,
            self.z - other.z,
        )

    def __mul__(self, other: Float) -> Vec3Array:
        return Vec3Array(
            self.x * other,
            self.y * other,
            self.z * other,
        )

    def __rmul__(self, other: Float) -> Vec3Array:
        return self * other

    def __truediv__(self, other: Float) -> Vec3Array:
        return Vec3Array(
            self.x / other,
            self.y / other,
            self.z / other,
        )

    def __neg__(self) -> Vec3Array:
        return self * -1

    def __pos__(self) -> Vec3Array:
        return self * 1

    def __getitem__(self, index) -> Vec3Array:
        return Vec3Array(
            self.x[index],
            self.y[index],
            self.z[index],
        )

    def __iter__(self):
        return iter((self.x, self.y, self.z))

    @property
    def shape(self):
        return self.x.shape

    def map_tensor_fn(self, fn) -> Vec3Array:
        return Vec3Array(
            fn(self.x),
            fn(self.y),
            fn(self.z),
        )

    def cross(self, other: Vec3Array) -> Vec3Array:
        """Compute cross product between 'self' and 'other'."""
        new_x = self.y * other.z - self.z * other.y
        new_y = self.z * other.x - self.x * other.z
        new_z = self.x * other.y - self.y * other.x
        return Vec3Array(new_x, new_y, new_z)

    def dot(self, other: Vec3Array) -> Float:
        """Compute dot product between 'self' and 'other'."""
        return self.x * other.x + self.y * other.y + self.z * other.z

    def norm(self, epsilon: float = 1e-6) -> Float:
        """Compute Norm of Vec3Array, clipped to epsilon."""
        # To avoid NaN on the backward pass, we must use maximum before the sqrt
        norm2 = self.dot(self)
        if epsilon:
            norm2 = torch.clamp(norm2, min=epsilon**2)
        return torch.sqrt(norm2)

    def norm2(self):
        return self.dot(self)

    def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
        """Return unit vector with optional clipping."""
        return self / self.norm(epsilon)

    def clone(self) -> Vec3Array:
        return Vec3Array(
            self.x.clone(),
            self.y.clone(),
            self.z.clone(),
        )

    def reshape(self, new_shape) -> Vec3Array:
        x = self.x.reshape(new_shape)
        y = self.y.reshape(new_shape)
        z = self.z.reshape(new_shape)

        return Vec3Array(x, y, z)

    def sum(self, dim: int) -> Vec3Array:
        return Vec3Array(
            torch.sum(self.x, dim=dim),
            torch.sum(self.y, dim=dim),
            torch.sum(self.z, dim=dim),
        )

    def unsqueeze(self, dim: int):
        return Vec3Array(
            self.x.unsqueeze(dim),
            self.y.unsqueeze(dim),
            self.z.unsqueeze(dim),
        )

    @classmethod
    def zeros(cls, shape, device="cpu"):
        """Return Vec3Array corresponding to zeros of given shape."""
        return cls(
            torch.zeros(shape, dtype=torch.float32, device=device),
            torch.zeros(shape, dtype=torch.float32, device=device),
            torch.zeros(shape, dtype=torch.float32, device=device)
        )

    def to_tensor(self) -> torch.Tensor:
        return torch.stack([self.x, self.y, self.z], dim=-1)

    @classmethod
    def from_array(cls, tensor):
        return cls(*torch.unbind(tensor, dim=-1))

    @classmethod
    def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
        return cls(
            torch.cat([v.x for v in vecs], dim=dim),
            torch.cat([v.y for v in vecs], dim=dim),
            torch.cat([v.z for v in vecs], dim=dim),
        )


def square_euclidean_distance(
    vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6
) -> Float:
    """Computes square of euclidean distance between 'vec1' and 'vec2'.

    Args:
        vec1: Vec3Array to compute distance to
        vec2: Vec3Array to compute distance from, should be
            broadcast compatible with 'vec1'
        epsilon: distance is clipped from below to be at least epsilon

    Returns:
        Array of square euclidean distances;
        shape will be result of broadcasting 'vec1' and 'vec2'
    """
    difference = vec1 - vec2
    distance = difference.dot(difference)
    if epsilon:
        distance = torch.maximum(distance, epsilon)
    return distance


def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
    return vector1.dot(vector2)


def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float:
    return vector1.cross(vector2)


def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
    return vector.norm(epsilon)


def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
    return vector.normalized(epsilon)


def euclidean_distance(
    vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6
) -> Float:
    """Computes euclidean distance between 'vec1' and 'vec2'.

    Args:
        vec1: Vec3Array to compute euclidean distance to
        vec2: Vec3Array to compute euclidean distance from, should be
            broadcast compatible with 'vec1'
        epsilon: distance is clipped from below to be at least epsilon

    Returns:
        Array of euclidean distances;
        shape will be result of broadcasting 'vec1' and 'vec2'
    """
    distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
    distance = torch.sqrt(distance_sq)
    return distance


def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array) -> Float:
    """Computes torsion angle for a quadruple of points.

    For points (a, b, c, d), this is the angle between the planes defined by
    points (a, b, c) and (b, c, d). It is also known as the dihedral angle.

    Arguments:
        a: A Vec3Array of coordinates.
        b: A Vec3Array of coordinates.
        c: A Vec3Array of coordinates.
        d: A Vec3Array of coordinates.

    Returns:
        A tensor of angles in radians: [-pi, pi].
    """
    v1 = a - b
    v2 = b - c
    v3 = d - c

    c1 = v1.cross(v2)
    c2 = v3.cross(v2)
    c3 = c2.cross(c1)

    v2_mag = v2.norm()
    return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
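
`dihedral_angle` is the helper most downstream code relies on for torsion angles; a small worked example with four points whose torsion about the b-c axis is a quarter turn (values are arbitrary and only illustrative):

import torch
from fastfold.utils.geometry import vector

t = lambda v: torch.tensor([float(v)])
a = vector.Vec3Array(t(1), t(0), t(0))
b = vector.Vec3Array(t(0), t(0), t(0))
c = vector.Vec3Array(t(0), t(1), t(0))
d = vector.Vec3Array(t(0), t(1), t(1))
print(vector.dihedral_angle(a, b, c, d))   # tensor([-1.5708]), i.e. -pi/2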

fastfold/utils/import_weights.py (new file, mode 100644)

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from dataclasses import dataclass
from functools import partial
import numpy as np
import torch
from typing import Union, List

from fastfold.model.nn.triangular_multiplicative_update import is_fused_triangle_multiplication

_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"


# With Param, a poor man's enum with attributes (Rust-style)
class ParamType(Enum):
    LinearWeight = partial(  # hack: partial prevents fns from becoming methods
        lambda w: w.transpose(-1, -2)
    )
    LinearWeightMHA = partial(
        lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
    )
    LinearMHAOutputWeight = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
    LinearWeightOPM = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
    LinearWeightMultimer = partial(
        lambda w: w.unsqueeze(-1)
        if len(w.shape) == 1
        else w.reshape(w.shape[0], -1).transpose(-1, -2)
    )
    LinearBiasMultimer = partial(lambda w: w.reshape(-1))
    Other = partial(lambda w: w)

    def __init__(self, fn):
        self.transformation = fn


@dataclass
class Param:
    param: Union[torch.Tensor, List[torch.Tensor]]
    param_type: ParamType = ParamType.Other
    stacked: bool = False


def _process_translations_dict(d, top_layer=True):
    flat = {}
    for k, v in d.items():
        if type(v) == dict:
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            sub_flat = {
                (prefix + "/".join([k, k_prime])): v_prime
                for k_prime, v_prime in _process_translations_dict(
                    v, top_layer=False
                ).items()
            }
            flat.update(sub_flat)
        else:
            k = "/" + k if not top_layer else k
            flat[k] = v

    return flat


def stacked(param_dict_list, out=None):
    """
    Args:
        param_dict_list:
            A list of (nested) Param dicts to stack. The structure of
            each dict must be the identical (down to the ParamTypes of
            "parallel" Params). There must be at least one dict
            in the list.
    """
    if out is None:
        out = {}
    template = param_dict_list[0]
    for k, _ in template.items():
        v = [d[k] for d in param_dict_list]
        if type(v[0]) is dict:
            out[k] = {}
            stacked(v, out=out[k])
        elif type(v[0]) is Param:
            stacked_param = Param(
                param=[param.param for param in v],
                param_type=v[0].param_type,
                stacked=True,
            )

            out[k] = stacked_param

    return out


def assign(translation_dict, orig_weights):
    for k, param in translation_dict.items():
        with torch.no_grad():
            weights = torch.as_tensor(orig_weights[k])
            ref, param_type = param.param, param.param_type
            if param.stacked:
                weights = torch.unbind(weights, 0)
            else:
                weights = [weights]
                ref = [ref]

            try:
                weights = list(map(param_type.transformation, weights))
                for p, w in zip(ref, weights):
                    p.copy_(w)
            except:
                print(k)
                print(ref[0].shape)
                print(weights[0].shape)
                raise


def get_translation_dict(model, version):
    is_multimer = "multimer" in version

    #######################
    # Some templates
    #######################
    LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
    LinearBias = lambda l: (Param(l))
    LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
    LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
    LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
    LinearWeightMultimer = lambda l: (
        Param(l, param_type=ParamType.LinearWeightMultimer)
    )
    LinearBiasMultimer = lambda l: (Param(l, param_type=ParamType.LinearBiasMultimer))

    LinearParams = lambda l: {
        "weights": LinearWeight(l.weight),
        "bias": LinearBias(l.bias),
    }

    LinearParamsMultimer = lambda l: {
        "weights": LinearWeightMultimer(l.weight),
        "bias": LinearBiasMultimer(l.bias),
    }

    LayerNormParams = lambda l: {
        "scale": Param(l.weight),
        "offset": Param(l.bias),
    }

    AttentionParams = lambda att: {
        "query_w": LinearWeightMHA(att.linear_q.weight),
        "key_w": LinearWeightMHA(att.linear_k.weight),
        "value_w": LinearWeightMHA(att.linear_v.weight),
        "output_w": Param(
            att.linear_o.weight,
            param_type=ParamType.LinearMHAOutputWeight,
        ),
        "output_b": LinearBias(att.linear_o.bias),
    }

    AttentionGatedParams = lambda att: dict(
        **AttentionParams(att),
        **{
            "gating_w": LinearWeightMHA(att.linear_g.weight),
            "gating_b": LinearBiasMHA(att.linear_g.bias),
        },
    )

    GlobalAttentionParams = lambda att: dict(
        AttentionGatedParams(att),
        key_w=LinearWeight(att.linear_k.weight),
        value_w=LinearWeight(att.linear_v.weight),
    )

    TriAttParams = lambda tri_att: {
        "query_norm": LayerNormParams(tri_att.layer_norm),
        "feat_2d_weights": LinearWeight(tri_att.linear.weight),
        "attention": AttentionGatedParams(tri_att.mha),
    }

    if is_fused_triangle_multiplication():
        TriMulOutParams = lambda tri_mul: {
            "left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
            "projection": LinearParams(tri_mul.linear_p),
            "gate": LinearParams(tri_mul.linear_g),
            "center_norm": LayerNormParams(tri_mul.layer_norm_out),
            "output_projection": LinearParams(tri_mul.linear_z),
            "gating_linear": LinearParams(tri_mul.linear_gate),
        }

        # see commit b88f8da on the Alphafold repo
        # Alphafold swaps the pseudocode's a and b between the incoming/outcoming
        # iterations of triangle multiplication, which is confusing and not
        # reproduced in our implementation.
        TriMulInParams = lambda tri_mul: {
            "left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
            "projection": LinearParams(tri_mul.linear_p),
            "gate": LinearParams(tri_mul.linear_g),
            "center_norm": LayerNormParams(tri_mul.layer_norm_out),
            "output_projection": LinearParams(tri_mul.linear_z),
            "gating_linear": LinearParams(tri_mul.linear_gate),
        }
    else:
        TriMulOutParams = lambda tri_mul: {
            "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
            "left_projection": LinearParams(tri_mul.linear_a_p),
            "right_projection": LinearParams(tri_mul.linear_b_p),
            "left_gate": LinearParams(tri_mul.linear_a_g),
            "right_gate": LinearParams(tri_mul.linear_b_g),
            "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
            "output_projection": LinearParams(tri_mul.linear_z),
            "gating_linear": LinearParams(tri_mul.linear_g),
        }

        # see commit b88f8da on the Alphafold repo
        # Alphafold swaps the pseudocode's a and b between the incoming/outcoming
        # iterations of triangle multiplication, which is confusing and not
        # reproduced in our implementation.
        TriMulInParams = lambda tri_mul: {
            "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
            "left_projection": LinearParams(tri_mul.linear_b_p),
            "right_projection": LinearParams(tri_mul.linear_a_p),
            "left_gate": LinearParams(tri_mul.linear_b_g),
            "right_gate": LinearParams(tri_mul.linear_a_g),
            "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
            "output_projection": LinearParams(tri_mul.linear_z),
            "gating_linear": LinearParams(tri_mul.linear_g),
        }

    PairTransitionParams = lambda pt: {
        "input_layer_norm": LayerNormParams(pt.layer_norm),
        "transition1": LinearParams(pt.linear_1),
        "transition2": LinearParams(pt.linear_2),
    }

    MSAAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": AttentionGatedParams(matt.mha),
    }

    MSAColAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
        "attention": AttentionGatedParams(matt._msa_att.mha),
    }

    MSAGlobalAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": GlobalAttentionParams(matt.global_attention),
    }

    MSAAttPairBiasParams = lambda matt: dict(
        **MSAAttParams(matt),
        **{
            "feat_2d_norm": LayerNormParams(matt.layer_norm_z),
            "feat_2d_weights": LinearWeight(matt.linear_z.weight),
        },
    )

    IPAParams = lambda ipa: {
        "q_scalar": LinearParams(ipa.linear_q),
        "kv_scalar": LinearParams(ipa.linear_kv),
        "q_point_local": LinearParams(ipa.linear_q_points),
        # New style IPA param
        # "q_point_local": LinearParams(ipa.linear_q_points.linear),
        "kv_point_local": LinearParams(ipa.linear_kv_points),
        # New style IPA param
        # "kv_point_local": LinearParams(ipa.linear_kv_points.linear),
        "trainable_point_weights": Param(param=ipa.head_weights, param_type=ParamType.Other),
        "attention_2d": LinearParams(ipa.linear_b),
        "output_projection": LinearParams(ipa.linear_out),
    }

    PointProjectionParams = lambda pp: {
        "point_projection": LinearParamsMultimer(pp.linear),
    }

    IPAParamsMultimer = lambda ipa: {
        "q_scalar_projection": {"weights": LinearWeightMultimer(ipa.linear_q.weight)},
        "k_scalar_projection": {"weights": LinearWeightMultimer(ipa.linear_k.weight)},
        "v_scalar_projection": {"weights": LinearWeightMultimer(ipa.linear_v.weight)},
        "q_point_projection": PointProjectionParams(ipa.linear_q_points),
        "k_point_projection": PointProjectionParams(ipa.linear_k_points),
        "v_point_projection": PointProjectionParams(ipa.linear_v_points),
        "trainable_point_weights": Param(param=ipa.head_weights, param_type=ParamType.Other),
        "attention_2d": LinearParams(ipa.linear_b),
        "output_projection": LinearParams(ipa.linear_out),
    }

    TemplatePairBlockParams = lambda b: {
        "triangle_attention_starting_node": TriAttParams(b.tri_att_start),
        "triangle_attention_ending_node": TriAttParams(b.tri_att_end),
        "triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
        "triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
        "pair_transition": PairTransitionParams(b.pair_transition),
    }

    MSATransitionParams = lambda m: {
        "input_layer_norm": LayerNormParams(m.layer_norm),
        "transition1": LinearParams(m.linear_1),
        "transition2": LinearParams(m.linear_2),
    }

    OuterProductMeanParams = lambda o: {
        "layer_norm_input": LayerNormParams(o.layer_norm),
        "left_projection": LinearParams(o.linear_1),
        "right_projection": LinearParams(o.linear_2),
        "output_w": LinearWeightOPM(o.linear_out.weight),
        "output_b": LinearBias(o.linear_out.bias),
    }

    def EvoformerBlockParams(b, is_extra_msa=False):
        if is_extra_msa:
            col_att_name = "msa_column_global_attention"
            msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
        else:
            col_att_name = "msa_column_attention"
            msa_col_att_params = MSAColAttParams(b.msa_att_col)

        d = {
            "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(b.msa_att_row),
            col_att_name: msa_col_att_params,
            "msa_transition": MSATransitionParams(b.core.msa_transition),
            "outer_product_mean": OuterProductMeanParams(b.core.outer_product_mean),
            "triangle_multiplication_outgoing": TriMulOutParams(b.core.tri_mul_out),
            "triangle_multiplication_incoming": TriMulInParams(b.core.tri_mul_in),
            "triangle_attention_starting_node": TriAttParams(b.core.tri_att_start),
            "triangle_attention_ending_node": TriAttParams(b.core.tri_att_end),
            "pair_transition": PairTransitionParams(b.core.pair_transition),
        }

        return d

    ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)

    def FoldIterationParams(sm):
        d = {
            "invariant_point_attention": IPAParamsMultimer(sm.ipa) if is_multimer else IPAParams(sm.ipa),
            "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
            "transition": LinearParams(sm.transition.layers[0].linear_1),
            "transition_1": LinearParams(sm.transition.layers[0].linear_2),
            "transition_2": LinearParams(sm.transition.layers[0].linear_3),
            "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
            "affine_update": LinearParams(sm.bb_update.linear),
            "rigid_sidechain": {
                "input_projection": LinearParams(sm.angle_resnet.linear_in),
                "input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
                "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
                "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
                "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
                "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
                "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
            },
        }

        if is_multimer:
            d.pop("affine_update")
            d["quat_rigid"] = {"rigid": LinearParams(sm.bb_update.linear)}

        return d

    ############################
    # translations dict overflow
    ############################

    tps_blocks = model.template_embedder.template_pair_stack.blocks
    tps_blocks_params = stacked([TemplatePairBlockParams(b) for b in tps_blocks])

    ems_blocks = model.extra_msa_stack.blocks
    ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

    evo_blocks = model.evoformer.blocks
    evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

    if not is_multimer:
        translations = {
            "evoformer": {
                "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
                "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
                "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
                "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
                "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
                "prev_msa_first_row_norm": LayerNormParams(model.recycling_embedder.layer_norm_m),
                "prev_pair_norm": LayerNormParams(model.recycling_embedder.layer_norm_z),
                "pair_activiations": LinearParams(model.input_embedder.linear_relpos),
                "template_embedding": {
                    "single_template_embedding": {
                        "embedding2d": LinearParams(
                            model.template_embedder.template_pair_embedder.linear
                        ),
                        "template_pair_stack": {
                            "__layer_stack_no_state": tps_blocks_params,
                        },
                        "output_layer_norm": LayerNormParams(
                            model.template_embedder.template_pair_stack.layer_norm
                        ),
                    },
                    "attention": AttentionParams(
                        model.template_embedder.template_pointwise_att.mha
                    ),
                },
                "extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
                "extra_msa_stack": ems_blocks_params,
                "template_single_embedding": LinearParams(
                    model.template_embedder.template_angle_embedder.linear_1
                ),
                "template_projection": LinearParams(
                    model.template_embedder.template_angle_embedder.linear_2
                ),
                "evoformer_iteration": evo_blocks_params,
                "single_activations": LinearParams(model.evoformer.linear),
            },
            "structure_module": {
                "single_layer_norm": LayerNormParams(model.structure_module.layer_norm_s),
                "initial_projection": LinearParams(model.structure_module.linear_in),
                "pair_layer_norm": LayerNormParams(model.structure_module.layer_norm_z),
                "fold_iteration": FoldIterationParams(model.structure_module),
            },
            "predicted_lddt_head": {
                "input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
                "act_0": LinearParams(model.aux_heads.plddt.linear_1),
                "act_1": LinearParams(model.aux_heads.plddt.linear_2),
                "logits": LinearParams(model.aux_heads.plddt.linear_3),
            },
            "distogram_head": {
                "half_logits": LinearParams(model.aux_heads.distogram.linear),
            },
            "experimentally_resolved_head": {
                "logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
            },
            "masked_msa_head": {
                "logits": LinearParams(model.aux_heads.masked_msa.linear),
            },
        }
    else:
        temp_embedder = model.template_embedder
        translations = {
            "evoformer": {
                "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
                "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
                "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
                "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
                "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
                "prev_msa_first_row_norm": LayerNormParams(model.recycling_embedder.layer_norm_m),
                "prev_pair_norm": LayerNormParams(model.recycling_embedder.layer_norm_z),
                "~_relative_encoding": {
                    "position_activations": LinearParams(model.input_embedder.linear_relpos),
                },
                "template_embedding": {
                    "single_template_embedding": {
                        "query_embedding_norm": LayerNormParams(
                            temp_embedder.template_pair_embedder.query_embedding_layer_norm
                        ),
                        "template_pair_embedding_0": LinearParams(
                            temp_embedder.template_pair_embedder.dgram_linear
                        ),
                        "template_pair_embedding_1": LinearParamsMultimer(
                            temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
                        ),
                        "template_pair_embedding_2": LinearParams(
                            temp_embedder.template_pair_embedder.aatype_linear_1
                        ),
                        "template_pair_embedding_3": LinearParams(
                            temp_embedder.template_pair_embedder.aatype_linear_2
                        ),
                        "template_pair_embedding_4": LinearParamsMultimer(
                            temp_embedder.template_pair_embedder.x_linear
                        ),
                        "template_pair_embedding_5": LinearParamsMultimer(
                            temp_embedder.template_pair_embedder.y_linear
                        ),
                        "template_pair_embedding_6": LinearParamsMultimer(
                            temp_embedder.template_pair_embedder.z_linear
                        ),
                        "template_pair_embedding_7": LinearParamsMultimer(
                            temp_embedder.template_pair_embedder.backbone_mask_linear
                        ),
                        "template_pair_embedding_8": LinearParams(
                            temp_embedder.template_pair_embedder.query_embedding_linear
                        ),
                        "template_embedding_iteration": tps_blocks_params,
                        "output_layer_norm": LayerNormParams(
                            model.template_embedder.template_pair_stack.layer_norm
                        ),
                    },
                    "output_linear": LinearParams(temp_embedder.linear_t),
                },
                "template_projection": LinearParams(
                    temp_embedder.template_single_embedder.template_projector,
                ),
                "template_single_embedding": LinearParams(
                    temp_embedder.template_single_embedder.template_single_embedder,
                ),
                "extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
                "extra_msa_stack": ems_blocks_params,
                "evoformer_iteration": evo_blocks_params,
                "single_activations": LinearParams(model.evoformer.linear),
            },
            "structure_module": {
                "single_layer_norm": LayerNormParams(model.structure_module.layer_norm_s),
                "initial_projection": LinearParams(model.structure_module.linear_in),
                "pair_layer_norm": LayerNormParams(model.structure_module.layer_norm_z),
                "fold_iteration": FoldIterationParams(model.structure_module),
            },
            "predicted_lddt_head": {
                "input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
                "act_0": LinearParams(model.aux_heads.plddt.linear_1),
                "act_1": LinearParams(model.aux_heads.plddt.linear_2),
                "logits": LinearParams(model.aux_heads.plddt.linear_3),
            },
            "distogram_head": {
                "half_logits": LinearParams(model.aux_heads.distogram.linear),
            },
            "experimentally_resolved_head": {
                "logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
            },
            "masked_msa_head": {
                "logits": LinearParams(model.aux_heads.masked_msa.linear),
            },
        }

    no_templ = [
        "model_3",
        "model_4",
        "model_5",
        "model_3_ptm",
        "model_4_ptm",
        "model_5_ptm",
    ]
    if version in no_templ:
        evo_dict = translations["evoformer"]
        keys = list(evo_dict.keys())
        for k in keys:
            if "template_" in k:
                evo_dict.pop(k)

    if "_ptm" in version or is_multimer:
        translations["predicted_aligned_error_head"] = {
            "logits": LinearParams(model.aux_heads.tm.linear)
        }

    return translations


def import_jax_weights_(model, npz_path, version="model_1"):
    data = np.load(npz_path)

    translations = get_translation_dict(model, version)

    # Flatten keys and insert missing key prefixes
    flat = _process_translations_dict(translations)

    # Sanity check
    keys = list(data.keys())
    flat_keys = list(flat.keys())
    incorrect = [k for k in flat_keys if k not in keys]
    missing = [k for k in keys if k not in flat_keys]
    # print(f"Incorrect: {incorrect}")
    # print(f"Missing: {missing}")

    assert len(incorrect) == 0
    # assert(sorted(list(flat.keys())) == sorted(list(data.keys())))

    # Set weights
    assign(flat, data)

    if is_fused_triangle_multiplication():
        # NOTE: in multimer v3, AlphaFold uses fused triangle multiplication, so the
        # left/right halves of the projections need to be swapped here.
        for b in model.template_embedder.template_pair_stack.blocks:
            _change_tri_mul_in_left_right(b.tri_mul_in)
        for b in model.extra_msa_stack.blocks:
            _change_tri_mul_in_left_right(b.core.tri_mul_in)
        for b in model.evoformer.blocks:
            _change_tri_mul_in_left_right(b.core.tri_mul_in)


def _change_tri_mul_in_left_right(module):
    def _change_para(para):
        left_right_para = para.clone().chunk(2, dim=0)
        return torch.cat((left_right_para[1], left_right_para[0]), dim=0)

    with torch.no_grad():
        module.linear_p.weight.copy_(_change_para(module.linear_p.weight))
        module.linear_p.bias.copy_(_change_para(module.linear_p.bias))
        module.linear_g.weight.copy_(_change_para(module.linear_g.weight))
        module.linear_g.bias.copy_(_change_para(module.linear_g.bias))
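
The two dict helpers above (`stacked` and `_process_translations_dict`) are easiest to see on a toy input; a self-contained sketch with made-up module names (not part of the diff):

import torch
from fastfold.utils.import_weights import Param, stacked, _process_translations_dict

layer = {"weights": Param(torch.zeros(2, 2)), "bias": Param(torch.zeros(2))}
blocks = [{"linear": dict(layer)} for _ in range(3)]   # three "blocks" with identical structure
per_layer_stacked = stacked(blocks)                    # each Param now holds a list of 3 tensors
flat = _process_translations_dict({"evoformer": {"linear": layer}})
print(list(flat.keys()))
# ['alphafold/alphafold_iteration/evoformer/linear//weights',
#  'alphafold/alphafold_iteration/evoformer/linear//bias']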

fastfold/utils/inject_fastnn.py (new file, mode 100644)

# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from fastfold.model.fastnn import EvoformerStack, ExtraMSAStack
from fastfold.model.fastnn.embedders import TemplateEmbedder
from fastfold.model.fastnn.embedders_multimer import TemplateEmbedderMultimer
from fastfold.model.fastnn.ops import RecyclingEmbedder, InputEmbedder
from fastfold.model.nn.triangular_multiplicative_update import is_fused_triangle_multiplication


def copy_layernorm(model_fast, model_ori):
    model_fast.weight.copy_(model_ori.weight)
    model_fast.bias.copy_(model_ori.bias)


def copy_linear(model_fast, model_ori):
    model_fast.weight.copy_(model_ori.weight)
    if model_fast.use_bias:
        model_fast.bias.copy_(model_ori.bias)


def copy_native_linear(model_fast, model_ori):
    model_fast.weight.copy_(model_ori.weight)
    try:
        model_fast.bias.copy_(model_ori.bias)
    except:
        pass


def copy_kv_linear(model_fast, ori_k, ori_v):
    model_fast.weight.copy_(torch.cat((ori_k.weight, ori_v.weight), dim=0))


def copy_qkv_linear(model_fast, ori_q, ori_k, ori_v):
    model_fast.weight.copy_(torch.cat((ori_q.weight, ori_k.weight, ori_v.weight), dim=0))


def copy_attention(model_fast, model_ori):
    copy_qkv_linear(model_fast.to_qkv, model_ori.linear_q, model_ori.linear_k, model_ori.linear_v)
    copy_linear(model_fast.gating_linear, model_ori.linear_g)
    copy_linear(model_fast.o_linear, model_ori.linear_o)
    try:
        model_fast.gating_bias.copy_(model_ori.linear_g.bias)
    except:
        print("no gating_bias need copy")


def copy_left_right(model_fast, ori_left, ori_right):
    model_fast.weight.copy_(torch.cat((ori_left.weight, ori_right.weight), dim=0))
    model_fast.bias.copy_(torch.cat((ori_left.bias, ori_right.bias), dim=0))


def copy_transition(model_fast, model_ori):
    copy_layernorm(model_fast.norm, model_ori.layer_norm)
    copy_linear(model_fast.linear1, model_ori.linear_1)
    copy_linear(model_fast.linear2, model_ori.linear_2)


def copy_triangle(model_fast, model_ori):
    copy_layernorm(model_fast.layernorm1, model_ori.layer_norm_in)
    copy_layernorm(model_fast.layernorm2, model_ori.layer_norm_out)
    copy_linear(model_fast.output_projection, model_ori.linear_z)
    model_fast.output_bias.copy_(model_ori.linear_z.bias)
    if is_fused_triangle_multiplication():
        copy_linear(model_fast.output_gate, model_ori.linear_gate)
        copy_linear(model_fast.left_right_projection, model_ori.linear_p)
        copy_linear(model_fast.left_right_gate, model_ori.linear_g)
    else:
        copy_linear(model_fast.output_gate, model_ori.linear_g)
        copy_left_right(model_fast.left_right_projection, model_ori.linear_a_p, model_ori.linear_b_p)
        copy_left_right(model_fast.left_right_gate, model_ori.linear_a_g, model_ori.linear_b_g)


def copy_triangle_att(model_fast, model_ori):
    copy_layernorm(model_fast.layernorm1, model_ori.layer_norm)
    copy_linear(model_fast.linear_b, model_ori.linear)
    copy_attention(model_fast.attention, model_ori.mha)
    model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)


def copy_native_att(model_fast, model_ori):
    copy_native_linear(model_fast.linear_q, model_ori.linear_q)
    copy_native_linear(model_fast.linear_k, model_ori.linear_k)
    copy_native_linear(model_fast.linear_v, model_ori.linear_v)
    copy_native_linear(model_fast.linear_o, model_ori.linear_o)
    if model_ori.gating:
        copy_native_linear(model_fast.linear_g, model_ori.linear_g)


def copy_evoformer_para(block_fast, block_ori):
    # msa_stack
    # MSARowAttentionWithPairBias
    copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormM,
                   block_ori.msa_att_row.layer_norm_m)
    copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormZ,
                   block_ori.msa_att_row.layer_norm_z)
    copy_attention(block_fast.msa.MSARowAttentionWithPairBias.attention,
                   block_ori.msa_att_row.mha)
    block_fast.msa.MSARowAttentionWithPairBias.linear_b_weights.copy_(
        block_ori.msa_att_row.linear_z.weight)
    block_fast.msa.MSARowAttentionWithPairBias.out_bias.copy_(
        block_ori.msa_att_row.mha.linear_o.bias)

    # MSAColumnAttention
    copy_layernorm(block_fast.msa.MSAColumnAttention.layernormM,
                   block_ori.msa_att_col._msa_att.layer_norm_m)
    copy_attention(block_fast.msa.MSAColumnAttention.attention,
                   block_ori.msa_att_col._msa_att.mha)

    # MSATransition
    copy_transition(block_fast.msa.MSATransition, block_ori.core.msa_transition)

    # communication
    copy_layernorm(block_fast.communication.layernormM,
                   block_ori.core.outer_product_mean.layer_norm)
    copy_linear(block_fast.communication.linear_a, block_ori.core.outer_product_mean.linear_1)
    copy_linear(block_fast.communication.linear_b, block_ori.core.outer_product_mean.linear_2)
    copy_linear(block_fast.communication.o_linear, block_ori.core.outer_product_mean.linear_out)

    # pair_stack
    # TriangleMultiplicationOutgoing
    copy_triangle(block_fast.pair.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
    # TriangleMultiplicationIncoming
    copy_triangle(block_fast.pair.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
    # TriangleAttentionStartingNode
    copy_triangle_att(block_fast.pair.TriangleAttentionStartingNode, block_ori.core.tri_att_start)
    copy_triangle_att(block_fast.pair.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
    copy_transition(block_fast.pair.PairTransition, block_ori.core.pair_transition)


def copy_global_attention(model_fast, model_ori):
    copy_linear(model_fast.to_q, model_ori.linear_q)
    copy_kv_linear(model_fast.to_kv, model_ori.linear_k, model_ori.linear_v)
    copy_linear(model_fast.gating_linear, model_ori.linear_g)
    copy_linear(model_fast.o_linear, model_ori.linear_o)
    try:
        model_fast.gating_bias.copy_(model_ori.linear_g.bias)
    except:
        print("no gating_bias need copy")


def copy_extra_msa_para(block_fast, block_ori):
    # msa_stack
    # MSARowAttentionWithPairBias
    copy_layernorm(
        block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,
        block_ori.msa_att_row.layer_norm_m,
    )
    copy_layernorm(
        block_fast.msa_stack.MSARowAttentionWithPairBias.layernormZ,
        block_ori.msa_att_row.layer_norm_z,
    )
    copy_attention(
        block_fast.msa_stack.MSARowAttentionWithPairBias.attention,
        block_ori.msa_att_row.mha,
    )
    block_fast.msa_stack.MSARowAttentionWithPairBias.linear_b_weights.copy_(
        block_ori.msa_att_row.linear_z.weight)
    block_fast.msa_stack.MSARowAttentionWithPairBias.out_bias.copy_(
        block_ori.msa_att_row.mha.linear_o.bias)

    # MSAColumnAttention
    copy_layernorm(
        block_fast.msa_stack.MSAColumnAttention.layernormM,
        block_ori.msa_att_col.layer_norm_m,
    )
    copy_global_attention(
        block_fast.msa_stack.MSAColumnAttention.global_attention,
        block_ori.msa_att_col.global_attention,
    )

    # MSATransition
    copy_transition(block_fast.msa_stack.MSATransition, block_ori.core.msa_transition)

    # communication
    comm_model = (
        block_ori.core.outer_product_mean
        # if not block_ori.is_multimer else block_ori.outer_product_mean
    )
    copy_layernorm(block_fast.communication.layernormM, comm_model.layer_norm)
    copy_linear(block_fast.communication.linear_a, comm_model.linear_1)
    copy_linear(block_fast.communication.linear_b, comm_model.linear_2)
    copy_linear(block_fast.communication.o_linear, comm_model.linear_out)

    # pair_stack
    # TriangleMultiplicationOutgoing
    copy_triangle(block_fast.pair_stack.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
    # TriangleMultiplicationIncoming
    copy_triangle(block_fast.pair_stack.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
    # TriangleAttentionStartingNode
    copy_triangle_att(
        block_fast.pair_stack.TriangleAttentionStartingNode,
        block_ori.core.tri_att_start,
    )
    copy_triangle_att(block_fast.pair_stack.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
    copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition)


def copy_template_pair_stack_para(block_fast, block_ori):
    # TriangleMultiplicationOutgoing
    copy_triangle(block_fast.TriangleMultiplicationOutgoing, block_ori.tri_mul_out)
    # TriangleMultiplicationIncoming
    copy_triangle(block_fast.TriangleMultiplicationIncoming, block_ori.tri_mul_in)
    # TriangleAttentionStartingNode
    copy_triangle_att(block_fast.TriangleAttentionStartingNode, block_ori.tri_att_start)
    copy_triangle_att(block_fast.TriangleAttentionEndingNode, block_ori.tri_att_end)
    copy_transition(block_fast.PairTransition, block_ori.pair_transition)


def copy_template_pair_block_para(fast_module, target_module):
    with torch.no_grad():
        for ori_block, fast_block in zip(target_module.blocks, fast_module.blocks):
            copy_template_pair_stack_para

[diff truncated here in the displayed view]
(
fast_block
,
ori_block
)
if
ori_block
.
training
==
False
:
fast_block
.
eval
()
def
copy_template_para
(
block_fast
,
block_ori
):
# TemplateAngleEmbedder
copy_linear
(
block_fast
.
template_angle_embedder
.
linear_1
,
block_ori
.
template_angle_embedder
.
linear_1
)
copy_linear
(
block_fast
.
template_angle_embedder
.
linear_2
,
block_ori
.
template_angle_embedder
.
linear_2
)
# TemplatePairEmbedder
copy_linear
(
block_fast
.
template_pair_embedder
.
linear
,
block_ori
.
template_pair_embedder
.
linear
)
# TemplatePairStack
copy_template_pair_block_para
(
block_fast
.
template_pair_stack
,
block_ori
.
template_pair_stack
)
copy_layernorm
(
block_fast
.
template_pair_stack
.
layer_norm
,
block_ori
.
template_pair_stack
.
layer_norm
)
# TemplatePointwiseAttention
copy_native_att
(
block_fast
.
template_pointwise_att
.
mha
,
block_ori
.
template_pointwise_att
.
mha
)
def
copy_template_multimer_para
(
block_fast
,
block_ori
):
# TemplatePairEmbedderMultimer
copy_linear
(
block_fast
.
template_pair_embedder
.
dgram_linear
,
block_ori
.
template_pair_embedder
.
dgram_linear
)
copy_linear
(
block_fast
.
template_pair_embedder
.
aatype_linear_1
,
block_ori
.
template_pair_embedder
.
aatype_linear_1
)
copy_linear
(
block_fast
.
template_pair_embedder
.
aatype_linear_2
,
block_ori
.
template_pair_embedder
.
aatype_linear_2
)
copy_layernorm
(
block_fast
.
template_pair_embedder
.
query_embedding_layer_norm
,
block_ori
.
template_pair_embedder
.
query_embedding_layer_norm
)
copy_linear
(
block_fast
.
template_pair_embedder
.
query_embedding_linear
,
block_ori
.
template_pair_embedder
.
query_embedding_linear
)
copy_linear
(
block_fast
.
template_pair_embedder
.
pseudo_beta_mask_linear
,
block_ori
.
template_pair_embedder
.
pseudo_beta_mask_linear
)
copy_linear
(
block_fast
.
template_pair_embedder
.
x_linear
,
block_ori
.
template_pair_embedder
.
x_linear
)
copy_linear
(
block_fast
.
template_pair_embedder
.
y_linear
,
block_ori
.
template_pair_embedder
.
y_linear
)
copy_linear
(
block_fast
.
template_pair_embedder
.
z_linear
,
block_ori
.
template_pair_embedder
.
z_linear
)
copy_linear
(
block_fast
.
template_pair_embedder
.
backbone_mask_linear
,
block_ori
.
template_pair_embedder
.
backbone_mask_linear
)
# TemplateSingleEmbedderMultimer
copy_linear
(
block_fast
.
template_single_embedder
.
template_single_embedder
,
block_ori
.
template_single_embedder
.
template_single_embedder
)
copy_linear
(
block_fast
.
template_single_embedder
.
template_projector
,
block_ori
.
template_single_embedder
.
template_projector
)
# TemplatePairStack
copy_template_pair_block_para
(
block_fast
.
template_pair_stack
,
block_ori
.
template_pair_stack
)
copy_layernorm
(
block_fast
.
template_pair_stack
.
layer_norm
,
block_ori
.
template_pair_stack
.
layer_norm
)
# linear_t
copy_linear
(
block_fast
.
linear_t
,
block_ori
.
linear_t
)
def
inject_evoformer
(
model
):
with
torch
.
no_grad
():
target_module
=
model
.
evoformer
fast_module
=
EvoformerStack
(
c_m
=
target_module
.
blocks
[
0
].
msa_att_row
.
c_in
,
c_z
=
target_module
.
blocks
[
0
].
msa_att_row
.
c_z
,
c_s
=
target_module
.
linear
.
out_features
,
no_blocks
=
len
(
target_module
.
blocks
),
blocks_per_ckpt
=
target_module
.
blocks_per_ckpt
,
clear_cache_between_blocks
=
target_module
.
clear_cache_between_blocks
,
is_multimer
=
target_module
.
blocks
[
0
].
is_multimer
,
)
for
target_block
,
fast_block
in
zip
(
target_module
.
blocks
,
fast_module
.
blocks
):
copy_evoformer_para
(
fast_block
,
target_block
)
if
target_block
.
training
==
False
:
fast_block
.
eval
()
copy_linear
(
fast_module
.
linear
,
target_module
.
linear
)
model
.
evoformer
=
fast_module
def
inject_extramsa
(
model
):
with
torch
.
no_grad
():
target_module
=
model
.
extra_msa_stack
fast_module
=
ExtraMSAStack
(
c_m
=
target_module
.
blocks
[
0
].
msa_att_row
.
c_in
,
c_z
=
target_module
.
blocks
[
0
].
msa_att_row
.
c_z
,
no_blocks
=
len
(
target_module
.
blocks
),
clear_cache_between_blocks
=
target_module
.
clear_cache_between_blocks
,
is_multimer
=
target_module
.
blocks
[
0
].
is_multimer
,
)
for
target_block
,
fast_block
in
zip
(
target_module
.
blocks
,
fast_module
.
blocks
):
copy_extra_msa_para
(
fast_block
,
target_block
)
if
target_block
.
training
==
False
:
fast_block
.
eval
()
model
.
extra_msa_stack
=
fast_module
def
inject_template
(
model
):
with
torch
.
no_grad
():
if
model
.
evoformer
.
blocks
[
0
].
is_multimer
:
target_module
=
model
.
template_embedder
fast_module
=
TemplateEmbedderMultimer
(
config
=
model
.
template_embedder
.
config
)
copy_template_multimer_para
(
fast_module
,
target_module
)
if
target_module
.
training
==
False
:
fast_module
.
eval
()
model
.
template_embedder
=
fast_module
else
:
target_module
=
model
.
template_embedder
fast_module
=
TemplateEmbedder
(
config
=
model
.
template_embedder
.
config
)
copy_template_para
(
fast_module
,
target_module
)
if
target_module
.
training
==
False
:
fast_module
.
eval
()
model
.
template_embedder
=
fast_module
def
inject_embedder
(
model
):
if
model
.
evoformer
.
blocks
[
0
].
is_multimer
:
return
# recycle embedder
with
torch
.
no_grad
():
target_module
=
model
.
recycling_embedder
fast_module
=
RecyclingEmbedder
(
c_m
=
target_module
.
c_m
,
c_z
=
target_module
.
c_z
,
min_bin
=
target_module
.
min_bin
,
max_bin
=
target_module
.
max_bin
,
no_bins
=
target_module
.
no_bins
,
inf
=
target_module
.
inf
)
copy_native_linear
(
fast_module
.
linear
,
target_module
.
linear
)
copy_layernorm
(
fast_module
.
layer_norm_m
,
target_module
.
layer_norm_m
)
copy_layernorm
(
fast_module
.
layer_norm_z
,
target_module
.
layer_norm_z
)
if
target_module
.
training
==
False
:
fast_module
.
eval
()
model
.
recycling_embedder
=
fast_module
# input embedder
with
torch
.
no_grad
():
target_module
=
model
.
input_embedder
fast_module
=
InputEmbedder
(
tf_dim
=
target_module
.
tf_dim
,
msa_dim
=
target_module
.
msa_dim
,
c_z
=
target_module
.
c_z
,
c_m
=
target_module
.
c_m
,
relpos_k
=
target_module
.
relpos_k
,
)
copy_linear
(
fast_module
.
linear_tf_z_i
,
target_module
.
linear_tf_z_i
)
copy_linear
(
fast_module
.
linear_tf_z_j
,
target_module
.
linear_tf_z_j
)
copy_linear
(
fast_module
.
linear_tf_m
,
target_module
.
linear_tf_m
)
copy_linear
(
fast_module
.
linear_msa_m
,
target_module
.
linear_msa_m
)
copy_linear
(
fast_module
.
linear_relpos
,
target_module
.
linear_relpos
)
if
target_module
.
training
==
False
:
fast_module
.
eval
()
model
.
input_embedder
=
fast_module
def
inject_fastnn
(
model
):
inject_evoformer
(
model
)
inject_extramsa
(
model
)
inject_template
(
model
)
inject_embedder
(
model
)
return
model
\ No newline at end of file
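A minimal usage sketch for the entry point above: build an OpenFold-style model, swap its stacks for the FastFold modules with inject_fastnn, and run inference. The import paths for AlphaFold and model_config are assumptions based on this repository's layout and may differ; only the inject_fastnn import comes directly from the file shown here.

# Hypothetical usage sketch, not part of the committed file.
import torch
from fastfold.config import model_config          # assumed helper
from fastfold.model.hub import AlphaFold           # assumed model class
from fastfold.utils.inject_fastnn import inject_fastnn

config = model_config("model_1")
model = AlphaFold(config)
model.eval()

# Replace evoformer / extra-MSA / template / embedder modules in place.
model = inject_fastnn(model)

with torch.no_grad():
    pass  # out = model(batch) on a prepared feature dict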
fastfold/utils/rigid_utils.py
0 → 100644
View file @
b14e47f4
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Tuple, Any, Sequence, Callable, Optional

import numpy as np
import torch

import fastfold.habana as habana


def rot_matmul(
    a: torch.Tensor,
    b: torch.Tensor
) -> torch.Tensor:
    """
    Performs matrix multiplication of two rotation matrix tensors. Written
    out by hand to avoid AMP downcasting.

    Args:
        a: [*, 3, 3] left multiplicand
        b: [*, 3, 3] right multiplicand
    Returns:
        The product ab
    """
    if habana.is_habana():
        if len(a.shape) == 4 and a.shape[1] == 1:
            aa = a.permute(0, 1, 3, 2)
            bb = b.permute(0, 1, 3, 2)
            cc = bb @ aa
            cc = cc.permute(0, 1, 3, 2)
            return cc
        elif len(a.shape) == 4 and a.shape[1] != 1:
            pass
        else:
            cc = a @ b
            return cc

    row_1 = torch.stack(
        [
            a[..., 0, 0] * b[..., 0, 0]
            + a[..., 0, 1] * b[..., 1, 0]
            + a[..., 0, 2] * b[..., 2, 0],
            a[..., 0, 0] * b[..., 0, 1]
            + a[..., 0, 1] * b[..., 1, 1]
            + a[..., 0, 2] * b[..., 2, 1],
            a[..., 0, 0] * b[..., 0, 2]
            + a[..., 0, 1] * b[..., 1, 2]
            + a[..., 0, 2] * b[..., 2, 2],
        ],
        dim=-1,
    )
    row_2 = torch.stack(
        [
            a[..., 1, 0] * b[..., 0, 0]
            + a[..., 1, 1] * b[..., 1, 0]
            + a[..., 1, 2] * b[..., 2, 0],
            a[..., 1, 0] * b[..., 0, 1]
            + a[..., 1, 1] * b[..., 1, 1]
            + a[..., 1, 2] * b[..., 2, 1],
            a[..., 1, 0] * b[..., 0, 2]
            + a[..., 1, 1] * b[..., 1, 2]
            + a[..., 1, 2] * b[..., 2, 2],
        ],
        dim=-1,
    )
    row_3 = torch.stack(
        [
            a[..., 2, 0] * b[..., 0, 0]
            + a[..., 2, 1] * b[..., 1, 0]
            + a[..., 2, 2] * b[..., 2, 0],
            a[..., 2, 0] * b[..., 0, 1]
            + a[..., 2, 1] * b[..., 1, 1]
            + a[..., 2, 2] * b[..., 2, 1],
            a[..., 2, 0] * b[..., 0, 2]
            + a[..., 2, 1] * b[..., 1, 2]
            + a[..., 2, 2] * b[..., 2, 2],
        ],
        dim=-1,
    )

    return torch.stack([row_1, row_2, row_3], dim=-2)
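An illustrative sanity check, not part of the committed file: on ordinary CPU tensors the hand-unrolled product above should agree with the batched matmul operator. It assumes the module is importable as fastfold.utils.rigid_utils.

import torch
from fastfold.utils.rigid_utils import rot_matmul

a = torch.eye(3).expand(5, 3, 3)   # [5, 3, 3] identity rotations
b = torch.randn(5, 3, 3)           # arbitrary 3x3 matrices
# With identity left factors, the product is just b.
assert torch.allclose(rot_matmul(a, b), a @ b, atol=1e-6)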
def rot_vec_mul(
    r: torch.Tensor,
    t: torch.Tensor
) -> torch.Tensor:
    """
    Applies a rotation to a vector. Written out by hand to avoid AMP
    downcasting.

    Args:
        r: [*, 3, 3] rotation matrices
        t: [*, 3] coordinate tensors
    Returns:
        [*, 3] rotated coordinates
    """
    if habana.is_habana():
        cont = True
        if len(t.shape) == 4 and t.shape[1] == 1:
            cont = False
        elif len(t.shape) == 3 and t.shape[0] != r.shape[0] and t.shape[0] == 1:
            cont = False
        if cont:
            tt = t.unsqueeze(-2)
            rr = r.transpose(-2, -1)
            cc = tt @ rr
            cc = cc.squeeze(-2)
            return cc

    x = t[..., 0]
    y = t[..., 1]
    z = t[..., 2]
    return torch.stack(
        [
            r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
            r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
            r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
        ],
        dim=-1,
    )
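A small worked example, not part of the committed file: a 90-degree rotation about the z-axis sends (1, 0, 0) to (0, 1, 0). It assumes the module is importable as fastfold.utils.rigid_utils.

import torch
from fastfold.utils.rigid_utils import rot_vec_mul

r = torch.tensor([[0., -1., 0.],
                  [1.,  0., 0.],
                  [0.,  0., 1.]])   # 90-degree rotation about z
t = torch.tensor([1., 0., 0.])
print(rot_vec_mul(r, t))            # tensor([0., 1., 0.])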
def identity_rot_mats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    rots = torch.eye(
        3, dtype=dtype, device=device, requires_grad=requires_grad
    )
    rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
    rots = rots.expand(*batch_dims, -1, -1)

    return rots


def identity_trans(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    trans = torch.zeros(
        (*batch_dims, 3),
        dtype=dtype,
        device=device,
        requires_grad=requires_grad
    )
    return trans


def identity_quats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    quat = torch.zeros(
        (*batch_dims, 4),
        dtype=dtype,
        device=device,
        requires_grad=requires_grad
    )

    with torch.no_grad():
        quat[..., 0] = 1

    return quat


_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}


def _to_mat(pairs):
    mat = np.zeros((4, 4))
    for pair in pairs:
        key, value = pair
        ind = _qtr_ind_dict[key]
        mat[ind // 4][ind % 4] = value

    return mat


_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])


def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
    Converts a quaternion to a rotation matrix.

    Args:
        quat: [*, 4] quaternions
    Returns:
        [*, 3, 3] rotation matrices
    """
    # [*, 4, 4]
    quat = quat[..., None] * quat[..., None, :]

    # [4, 4, 3, 3]
    mat = quat.new_tensor(_QTR_MAT, requires_grad=False)

    # [*, 4, 4, 3, 3]
    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
    quat = quat[..., None, None] * shaped_qtr_mat

    # [*, 3, 3]
    return torch.sum(quat, dim=(-3, -4))
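An illustrative check, not part of the committed file: the identity quaternion (1, 0, 0, 0) maps to the 3x3 identity matrix under the lookup-table construction above. It assumes the module is importable as fastfold.utils.rigid_utils.

import torch
from fastfold.utils.rigid_utils import quat_to_rot

q = torch.tensor([1., 0., 0., 0.])   # identity quaternion
assert torch.allclose(quat_to_rot(q), torch.eye(3))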
def rot_to_quat(
    rot: torch.Tensor,
):
    if(rot.shape[-2:] != (3, 3)):
        raise ValueError("Input rotation is incorrectly shaped")

    rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
    [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot

    k = [
        [xx + yy + zz, zy - yz, xz - zx, yx - xy,],
        [zy - yz, xx - yy - zz, xy + yx, xz + zx,],
        [xz - zx, xy + yx, yy - xx - zz, yz + zy,],
        [yx - xy, xz + zx, yz + zy, zz - xx - yy,]
    ]

    k = (1. / 3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)

    _, vectors = torch.linalg.eigh(k)
    return vectors[..., -1]
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[ 1,  0,  0,  0],
                           [ 0, -1,  0,  0],
                           [ 0,  0, -1,  0],
                           [ 0,  0,  0, -1]]

_QUAT_MULTIPLY[:, :, 1] = [[ 0,  1,  0,  0],
                           [ 1,  0,  0,  0],
                           [ 0,  0,  0,  1],
                           [ 0,  0, -1,  0]]

_QUAT_MULTIPLY[:, :, 2] = [[ 0,  0,  1,  0],
                           [ 0,  0,  0, -1],
                           [ 1,  0,  0,  0],
                           [ 0,  1,  0,  0]]

_QUAT_MULTIPLY[:, :, 3] = [[ 0,  0,  0,  1],
                           [ 0,  0,  1,  0],
                           [ 0, -1,  0,  0],
                           [ 1,  0,  0,  0]]

_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]


def quat_multiply(quat1, quat2):
    """Multiply a quaternion by another quaternion."""
    mat = quat1.new_tensor(_QUAT_MULTIPLY)
    reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
    return torch.sum(
        reshaped_mat *
        quat1[..., :, None, None] *
        quat2[..., None, :, None],
        dim=(-3, -2)
    )


def quat_multiply_by_vec(quat, vec):
    """Multiply a quaternion by a pure-vector quaternion."""
    mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
    reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
    return torch.sum(
        reshaped_mat *
        quat[..., :, None, None] *
        vec[..., None, :, None],
        dim=(-3, -2)
    )


def invert_rot_mat(rot_mat: torch.Tensor):
    return rot_mat.transpose(-1, -2)


def invert_quat(quat: torch.Tensor):
    quat_prime = quat.clone()
    quat_prime[..., 1:] *= -1
    inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
    return inv
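An illustrative check, not part of the committed file: for a unit quaternion the inverse above reduces to the conjugate, and multiplying a quaternion by its inverse yields the identity quaternion. It assumes the module is importable as fastfold.utils.rigid_utils.

import torch
from fastfold.utils.rigid_utils import invert_quat, quat_multiply

q = torch.tensor([0.5, 0.5, 0.5, 0.5])   # a unit quaternion
q_inv = invert_quat(q)                    # equals the conjugate here
ident = quat_multiply(q, q_inv)
assert torch.allclose(ident, torch.tensor([1., 0., 0., 0.]), atol=1e-6)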
class
Rotation
:
"""
A 3D rotation. Depending on how the object is initialized, the
rotation is represented by either a rotation matrix or a
quaternion, though both formats are made available by helper functions.
To simplify gradient computation, the underlying format of the
rotation cannot be changed in-place. Like Rigid, the class is designed
to mimic the behavior of a torch Tensor, almost as if each Rotation
object were a tensor of rotations, in one format or another.
"""
def
__init__
(
self
,
rot_mats
:
Optional
[
torch
.
Tensor
]
=
None
,
quats
:
Optional
[
torch
.
Tensor
]
=
None
,
normalize_quats
:
bool
=
True
,
):
"""
Args:
rot_mats:
A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
quats
quats:
A [*, 4] quaternion. Mutually exclusive with rot_mats. If
normalize_quats is not True, must be a unit quaternion
normalize_quats:
If quats is specified, whether to normalize quats
"""
if
((
rot_mats
is
None
and
quats
is
None
)
or
(
rot_mats
is
not
None
and
quats
is
not
None
)):
raise
ValueError
(
"Exactly one input argument must be specified"
)
if
((
rot_mats
is
not
None
and
rot_mats
.
shape
[
-
2
:]
!=
(
3
,
3
))
or
(
quats
is
not
None
and
quats
.
shape
[
-
1
]
!=
4
)):
raise
ValueError
(
"Incorrectly shaped rotation matrix or quaternion"
)
# Force full-precision
if
(
quats
is
not
None
):
quats
=
quats
.
to
(
dtype
=
torch
.
float32
)
if
(
rot_mats
is
not
None
):
rot_mats
=
rot_mats
.
to
(
dtype
=
torch
.
float32
)
if
(
quats
is
not
None
and
normalize_quats
):
quats
=
quats
/
torch
.
linalg
.
norm
(
quats
,
dim
=-
1
,
keepdim
=
True
)
self
.
_rot_mats
=
rot_mats
self
.
_quats
=
quats
@
staticmethod
def
identity
(
shape
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
True
,
fmt
:
str
=
"quat"
,
)
->
Rotation
:
"""
Returns an identity Rotation.
Args:
shape:
The "shape" of the resulting Rotation object. See documentation
for the shape property
dtype:
The torch dtype for the rotation
device:
The torch device for the new rotation
requires_grad:
Whether the underlying tensors in the new rotation object
should require gradient computation
fmt:
One of "quat" or "rot_mat". Determines the underlying format
of the new object's rotation
Returns:
A new identity rotation
"""
if
(
fmt
==
"rot_mat"
):
rot_mats
=
identity_rot_mats
(
shape
,
dtype
,
device
,
requires_grad
,
)
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
elif
(
fmt
==
"quat"
):
quats
=
identity_quats
(
shape
,
dtype
,
device
,
requires_grad
)
return
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
f
"Invalid format: f
{
fmt
}
"
)
# Magic methods
def
__getitem__
(
self
,
index
:
Any
)
->
Rotation
:
"""
Allows torch-style indexing over the virtual shape of the rotation
object. See documentation for the shape property.
Args:
index:
A torch index. E.g. (1, 3, 2), or (slice(None,))
Returns:
The indexed rotation
"""
if
type
(
index
)
!=
tuple
:
index
=
(
index
,)
if
(
self
.
_rot_mats
is
not
None
):
rot_mats
=
self
.
_rot_mats
[
index
+
(
slice
(
None
),
slice
(
None
))]
return
Rotation
(
rot_mats
=
rot_mats
)
elif
(
self
.
_quats
is
not
None
):
quats
=
self
.
_quats
[
index
+
(
slice
(
None
),)]
return
Rotation
(
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
def
__mul__
(
self
,
right
:
torch
.
Tensor
,
)
->
Rotation
:
"""
Pointwise left multiplication of the rotation with a tensor. Can be
used to e.g. mask the Rotation.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if
not
(
isinstance
(
right
,
torch
.
Tensor
)):
raise
TypeError
(
"The other multiplicand must be a Tensor"
)
if
(
self
.
_rot_mats
is
not
None
):
rot_mats
=
self
.
_rot_mats
*
right
[...,
None
,
None
]
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
quats
=
self
.
_quats
*
right
[...,
None
]
return
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
def
__rmul__
(
self
,
left
:
torch
.
Tensor
,
)
->
Rotation
:
"""
Reverse pointwise multiplication of the rotation with a tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return
self
.
__mul__
(
left
)
# Properties
@
property
def
shape
(
self
)
->
torch
.
Size
:
"""
Returns the virtual shape of the rotation object. This shape is
defined as the batch dimensions of the underlying rotation matrix
or quaternion. If the Rotation was initialized with a [10, 3, 3]
rotation matrix tensor, for example, the resulting shape would be
[10].
Returns:
The virtual shape of the rotation object
"""
s
=
None
if
(
self
.
_quats
is
not
None
):
s
=
self
.
_quats
.
shape
[:
-
1
]
else
:
s
=
self
.
_rot_mats
.
shape
[:
-
2
]
return
s
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
"""
Returns the dtype of the underlying rotation.
Returns:
The dtype of the underlying rotation
"""
if
(
self
.
_rot_mats
is
not
None
):
return
self
.
_rot_mats
.
dtype
elif
(
self
.
_quats
is
not
None
):
return
self
.
_quats
.
dtype
else
:
raise
ValueError
(
"Both rotations are None"
)
@
property
def
device
(
self
)
->
torch
.
device
:
"""
The device of the underlying rotation
Returns:
The device of the underlying rotation
"""
if
(
self
.
_rot_mats
is
not
None
):
return
self
.
_rot_mats
.
device
elif
(
self
.
_quats
is
not
None
):
return
self
.
_quats
.
device
else
:
raise
ValueError
(
"Both rotations are None"
)
@
property
def
requires_grad
(
self
)
->
bool
:
"""
Returns the requires_grad property of the underlying rotation
Returns:
The requires_grad property of the underlying tensor
"""
if
(
self
.
_rot_mats
is
not
None
):
return
self
.
_rot_mats
.
requires_grad
elif
(
self
.
_quats
is
not
None
):
return
self
.
_quats
.
requires_grad
else
:
raise
ValueError
(
"Both rotations are None"
)
def
get_rot_mats
(
self
)
->
torch
.
Tensor
:
"""
Returns the underlying rotation as a rotation matrix tensor.
Returns:
The rotation as a rotation matrix tensor
"""
rot_mats
=
self
.
_rot_mats
if
(
rot_mats
is
None
):
if
(
self
.
_quats
is
None
):
raise
ValueError
(
"Both rotations are None"
)
else
:
rot_mats
=
quat_to_rot
(
self
.
_quats
)
return
rot_mats
def
get_quats
(
self
)
->
torch
.
Tensor
:
"""
Returns the underlying rotation as a quaternion tensor.
Depending on whether the Rotation was initialized with a
quaternion, this function may call torch.linalg.eigh.
Returns:
The rotation as a quaternion tensor.
"""
quats
=
self
.
_quats
if
(
quats
is
None
):
if
(
self
.
_rot_mats
is
None
):
raise
ValueError
(
"Both rotations are None"
)
else
:
quats
=
rot_to_quat
(
self
.
_rot_mats
)
return
quats
def
get_cur_rot
(
self
)
->
torch
.
Tensor
:
"""
Return the underlying rotation in its current form
Returns:
The stored rotation
"""
if
(
self
.
_rot_mats
is
not
None
):
return
self
.
_rot_mats
elif
(
self
.
_quats
is
not
None
):
return
self
.
_quats
else
:
raise
ValueError
(
"Both rotations are None"
)
# Rotation functions
def
compose_q_update_vec
(
self
,
q_update_vec
:
torch
.
Tensor
,
normalize_quats
:
bool
=
True
)
->
Rotation
:
"""
Returns a new quaternion Rotation after updating the current
object's underlying rotation with a quaternion update, formatted
as a [*, 3] tensor whose final three columns represent x, y, z such
that (1, x, y, z) is the desired (not necessarily unit) quaternion
update.
Args:
q_update_vec:
A [*, 3] quaternion update tensor
normalize_quats:
Whether to normalize the output quaternion
Returns:
An updated Rotation
"""
quats
=
self
.
get_quats
()
new_quats
=
quats
+
quat_multiply_by_vec
(
quats
,
q_update_vec
)
return
Rotation
(
rot_mats
=
None
,
quats
=
new_quats
,
normalize_quats
=
normalize_quats
,
)
def
compose_r
(
self
,
r
:
Rotation
)
->
Rotation
:
"""
Compose the rotation matrices of the current Rotation object with
those of another.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
r1
=
self
.
get_rot_mats
()
r2
=
r
.
get_rot_mats
()
new_rot_mats
=
rot_matmul
(
r1
,
r2
)
return
Rotation
(
rot_mats
=
new_rot_mats
,
quats
=
None
)
def
compose_q
(
self
,
r
:
Rotation
,
normalize_quats
:
bool
=
True
)
->
Rotation
:
"""
Compose the quaternions of the current Rotation object with those
of another.
Depending on whether either Rotation was initialized with
quaternions, this function may call torch.linalg.eigh.
Args:
r:
An update rotation object
Returns:
An updated rotation object
"""
q1
=
self
.
get_quats
()
q2
=
r
.
get_quats
()
new_quats
=
quat_multiply
(
q1
,
q2
)
return
Rotation
(
rot_mats
=
None
,
quats
=
new_quats
,
normalize_quats
=
normalize_quats
)
def
apply
(
self
,
pts
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Apply the current Rotation as a rotation matrix to a set of 3D
coordinates.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] rotated points
"""
rot_mats
=
self
.
get_rot_mats
()
return
rot_vec_mul
(
rot_mats
,
pts
)
def
invert_apply
(
self
,
pts
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
The inverse of the apply() method.
Args:
pts:
A [*, 3] set of points
Returns:
[*, 3] inverse-rotated points
"""
rot_mats
=
self
.
get_rot_mats
()
inv_rot_mats
=
invert_rot_mat
(
rot_mats
)
return
rot_vec_mul
(
inv_rot_mats
,
pts
)
def
invert
(
self
)
->
Rotation
:
"""
Returns the inverse of the current Rotation.
Returns:
The inverse of the current Rotation
"""
if
(
self
.
_rot_mats
is
not
None
):
return
Rotation
(
rot_mats
=
invert_rot_mat
(
self
.
_rot_mats
),
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
return
Rotation
(
rot_mats
=
None
,
quats
=
invert_quat
(
self
.
_quats
),
normalize_quats
=
False
,
)
else
:
raise
ValueError
(
"Both rotations are None"
)
# "Tensor" stuff
def
unsqueeze
(
self
,
dim
:
int
,
)
->
Rigid
:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shape of the Rotation object.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed Rotation.
"""
if
dim
>=
len
(
self
.
shape
):
raise
ValueError
(
"Invalid dimension"
)
if
(
self
.
_rot_mats
is
not
None
):
rot_mats
=
self
.
_rot_mats
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
2
)
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
quats
=
self
.
_quats
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
1
)
return
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
@
staticmethod
def
cat
(
rs
:
Sequence
[
Rotation
],
dim
:
int
,
)
->
Rigid
:
"""
Concatenates rotations along one of the batch dimensions. Analogous
to torch.cat().
Note that the output of this operation is always a rotation matrix,
regardless of the format of input rotations.
Args:
rs:
A list of rotation objects
dim:
The dimension along which the rotations should be
concatenated
Returns:
A concatenated Rotation object in rotation matrix format
"""
rot_mats
=
[
r
.
get_rot_mats
()
for
r
in
rs
]
rot_mats
=
torch
.
cat
(
rot_mats
,
dim
=
dim
if
dim
>=
0
else
dim
-
2
)
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
def
map_tensor_fn
(
self
,
fn
:
Callable
[
torch
.
Tensor
,
torch
.
Tensor
]
)
->
Rotation
:
"""
Apply a Tensor -> Tensor function to underlying rotation tensors,
mapping over the rotation dimension(s). Can be used e.g. to sum out
a one-hot batch dimension.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rotation
Returns:
The transformed Rotation object
"""
if
(
self
.
_rot_mats
is
not
None
):
rot_mats
=
self
.
_rot_mats
.
view
(
self
.
_rot_mats
.
shape
[:
-
2
]
+
(
9
,))
rot_mats
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
rot_mats
,
dim
=-
1
))),
dim
=-
1
)
rot_mats
=
rot_mats
.
view
(
rot_mats
.
shape
[:
-
1
]
+
(
3
,
3
))
return
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
quats
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
self
.
_quats
,
dim
=-
1
))),
dim
=-
1
)
return
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
def
cuda
(
self
)
->
Rotation
:
"""
Analogous to the cuda() method of torch Tensors
Returns:
A copy of the Rotation in CUDA memory
"""
if
(
self
.
_rot_mats
is
not
None
):
return
Rotation
(
rot_mats
=
self
.
_rot_mats
.
cuda
(),
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
return
Rotation
(
rot_mats
=
None
,
quats
=
self
.
_quats
.
cuda
(),
normalize_quats
=
False
)
else
:
raise
ValueError
(
"Both rotations are None"
)
def
to
(
self
,
device
:
Optional
[
torch
.
device
],
dtype
:
Optional
[
torch
.
dtype
]
)
->
Rotation
:
"""
Analogous to the to() method of torch Tensors
Args:
device:
A torch device
dtype:
A torch dtype
Returns:
A copy of the Rotation using the new device and dtype
"""
if
(
self
.
_rot_mats
is
not
None
):
return
Rotation
(
rot_mats
=
self
.
_rot_mats
.
to
(
device
=
device
,
dtype
=
dtype
),
quats
=
None
,
)
elif
(
self
.
_quats
is
not
None
):
return
Rotation
(
rot_mats
=
None
,
quats
=
self
.
_quats
.
to
(
device
=
device
,
dtype
=
dtype
),
normalize_quats
=
False
,
)
else
:
raise
ValueError
(
"Both rotations are None"
)
def
detach
(
self
)
->
Rotation
:
"""
Returns a copy of the Rotation whose underlying Tensor has been
detached from its torch graph.
Returns:
A copy of the Rotation whose underlying Tensor has been detached
from its torch graph
"""
if
(
self
.
_rot_mats
is
not
None
):
return
Rotation
(
rot_mats
=
self
.
_rot_mats
.
detach
(),
quats
=
None
)
elif
(
self
.
_quats
is
not
None
):
return
Rotation
(
rot_mats
=
None
,
quats
=
self
.
_quats
.
detach
(),
normalize_quats
=
False
,
)
else
:
raise
ValueError
(
"Both rotations are None"
)
class
Rigid
:
"""
A class representing a rigid transformation. Little more than a wrapper
around two objects: a Rotation object and a [*, 3] translation
Designed to behave approximately like a single torch tensor with the
shape of the shared batch dimensions of its component parts.
"""
def
__init__
(
self
,
rots
:
Optional
[
Rotation
],
trans
:
Optional
[
torch
.
Tensor
],
):
"""
Args:
rots: A [*, 3, 3] rotation tensor
trans: A corresponding [*, 3] translation tensor
"""
# (we need device, dtype, etc. from at least one input)
batch_dims
,
dtype
,
device
,
requires_grad
=
None
,
None
,
None
,
None
if
(
trans
is
not
None
):
batch_dims
=
trans
.
shape
[:
-
1
]
dtype
=
trans
.
dtype
device
=
trans
.
device
requires_grad
=
trans
.
requires_grad
elif
(
rots
is
not
None
):
batch_dims
=
rots
.
shape
dtype
=
rots
.
dtype
device
=
rots
.
device
requires_grad
=
rots
.
requires_grad
else
:
raise
ValueError
(
"At least one input argument must be specified"
)
if
(
rots
is
None
):
rots
=
Rotation
.
identity
(
batch_dims
,
dtype
,
device
,
requires_grad
,
)
elif
(
trans
is
None
):
trans
=
identity_trans
(
batch_dims
,
dtype
,
device
,
requires_grad
,
)
if
((
rots
.
shape
!=
trans
.
shape
[:
-
1
])
or
(
rots
.
device
!=
trans
.
device
)):
raise
ValueError
(
"Rots and trans incompatible"
)
# Force full precision. Happens to the rotations automatically.
trans
=
trans
.
to
(
dtype
=
torch
.
float32
)
self
.
_rots
=
rots
self
.
_trans
=
trans
@
staticmethod
def
identity
(
shape
:
Tuple
[
int
],
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
True
,
fmt
:
str
=
"quat"
,
)
->
Rigid
:
"""
Constructs an identity transformation.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
"""
return
Rigid
(
Rotation
.
identity
(
shape
,
dtype
,
device
,
requires_grad
,
fmt
=
fmt
),
identity_trans
(
shape
,
dtype
,
device
,
requires_grad
),
)
def
__getitem__
(
self
,
index
:
Any
,
)
->
Rigid
:
"""
Indexes the affine transformation with PyTorch-style indices.
The index is applied to the shared dimensions of both the rotation
and the translation.
E.g.::
r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
t = Rigid(r, torch.rand(10, 10, 3))
indexed = t[3, 4:6]
assert(indexed.shape == (2,))
assert(indexed.get_rots().shape == (2,))
assert(indexed.get_trans().shape == (2, 3))
Args:
index: A standard torch tensor index. E.g. 8, (10, None, 3),
or (3, slice(0, 1, None))
Returns:
The indexed tensor
"""
if
type
(
index
)
!=
tuple
:
index
=
(
index
,)
return
Rigid
(
self
.
_rots
[
index
],
self
.
_trans
[
index
+
(
slice
(
None
),)],
)
def
__mul__
(
self
,
right
:
torch
.
Tensor
,
)
->
Rigid
:
"""
Pointwise left multiplication of the transformation with a tensor.
Can be used to e.g. mask the Rigid.
Args:
right:
The tensor multiplicand
Returns:
The product
"""
if
not
(
isinstance
(
right
,
torch
.
Tensor
)):
raise
TypeError
(
"The other multiplicand must be a Tensor"
)
new_rots
=
self
.
_rots
*
right
new_trans
=
self
.
_trans
*
right
[...,
None
]
return
Rigid
(
new_rots
,
new_trans
)
def
__rmul__
(
self
,
left
:
torch
.
Tensor
,
)
->
Rigid
:
"""
Reverse pointwise multiplication of the transformation with a
tensor.
Args:
left:
The left multiplicand
Returns:
The product
"""
return
self
.
__mul__
(
left
)
@
property
def
shape
(
self
)
->
torch
.
Size
:
"""
Returns the shape of the shared dimensions of the rotation and
the translation.
Returns:
The shape of the transformation
"""
s
=
self
.
_trans
.
shape
[:
-
1
]
return
s
@
property
def
device
(
self
)
->
torch
.
device
:
"""
Returns the device on which the Rigid's tensors are located.
Returns:
The device on which the Rigid's tensors are located
"""
return
self
.
_trans
.
device
def
get_rots
(
self
)
->
Rotation
:
"""
Getter for the rotation.
Returns:
The rotation object
"""
return
self
.
_rots
def
get_trans
(
self
)
->
torch
.
Tensor
:
"""
Getter for the translation.
Returns:
The stored translation
"""
return
self
.
_trans
def
compose_q_update_vec
(
self
,
q_update_vec
:
torch
.
Tensor
,
)
->
Rigid
:
"""
Composes the transformation with a quaternion update vector of
shape [*, 6], where the final 6 columns represent the x, y, and
z values of a quaternion of form (1, x, y, z) followed by a 3D
translation.
Args:
q_vec: The quaternion update vector.
Returns:
The composed transformation.
"""
q_vec
,
t_vec
=
q_update_vec
[...,
:
3
],
q_update_vec
[...,
3
:]
new_rots
=
self
.
_rots
.
compose_q_update_vec
(
q_vec
)
trans_update
=
self
.
_rots
.
apply
(
t_vec
)
new_translation
=
self
.
_trans
+
trans_update
return
Rigid
(
new_rots
,
new_translation
)
def
compose
(
self
,
r
:
Rigid
,
)
->
Rigid
:
"""
Composes the current rigid object with another.
Args:
r:
Another Rigid object
Returns:
The composition of the two transformations
"""
new_rot
=
self
.
_rots
.
compose_r
(
r
.
_rots
)
new_trans
=
self
.
_rots
.
apply
(
r
.
_trans
)
+
self
.
_trans
return
Rigid
(
new_rot
,
new_trans
)
def
apply
(
self
,
pts
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Applies the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor.
Returns:
The transformed points.
"""
rotated
=
self
.
_rots
.
apply
(
pts
)
return
rotated
+
self
.
_trans
def
invert_apply
(
self
,
pts
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Applies the inverse of the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor
Returns:
The transformed points.
"""
pts
=
pts
-
self
.
_trans
return
self
.
_rots
.
invert_apply
(
pts
)
def
invert
(
self
)
->
Rigid
:
"""
Inverts the transformation.
Returns:
The inverse transformation.
"""
rot_inv
=
self
.
_rots
.
invert
()
trn_inv
=
rot_inv
.
apply
(
self
.
_trans
)
return
Rigid
(
rot_inv
,
-
1
*
trn_inv
)
def
map_tensor_fn
(
self
,
fn
:
Callable
[
torch
.
Tensor
,
torch
.
Tensor
]
)
->
Rigid
:
"""
Apply a Tensor -> Tensor function to underlying translation and
rotation tensors, mapping over the translation/rotation dimensions
respectively.
Args:
fn:
A Tensor -> Tensor function to be mapped over the Rigid
Returns:
The transformed Rigid object
"""
new_rots
=
self
.
_rots
.
map_tensor_fn
(
fn
)
new_trans
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
self
.
_trans
,
dim
=-
1
))),
dim
=-
1
)
return
Rigid
(
new_rots
,
new_trans
)
def
to_tensor_4x4
(
self
)
->
torch
.
Tensor
:
"""
Converts a transformation to a homogenous transformation tensor.
Returns:
A [*, 4, 4] homogenous transformation tensor
"""
tensor
=
self
.
_trans
.
new_zeros
((
*
self
.
shape
,
4
,
4
))
tensor
[...,
:
3
,
:
3
]
=
self
.
_rots
.
get_rot_mats
()
tensor
[...,
:
3
,
3
]
=
self
.
_trans
tensor
[...,
3
,
3
]
=
1
return
tensor
@
staticmethod
def
from_tensor_4x4
(
t
:
torch
.
Tensor
)
->
Rigid
:
"""
Constructs a transformation from a homogenous transformation
tensor.
Args:
t: [*, 4, 4] homogenous transformation tensor
Returns:
T object with shape [*]
"""
if
(
t
.
shape
[
-
2
:]
!=
(
4
,
4
)):
raise
ValueError
(
"Incorrectly shaped input tensor"
)
rots
=
Rotation
(
rot_mats
=
t
[...,
:
3
,
:
3
],
quats
=
None
)
trans
=
t
[...,
:
3
,
3
]
return
Rigid
(
rots
,
trans
)
def
to_tensor_7
(
self
)
->
torch
.
Tensor
:
"""
Converts a transformation to a tensor with 7 final columns, four
for the quaternion followed by three for the translation.
Returns:
A [*, 7] tensor representation of the transformation
"""
tensor
=
self
.
_trans
.
new_zeros
((
*
self
.
shape
,
7
))
tensor
[...,
:
4
]
=
self
.
_rots
.
get_quats
()
tensor
[...,
4
:]
=
self
.
_trans
return
tensor
@
staticmethod
def
from_tensor_7
(
t
:
torch
.
Tensor
,
normalize_quats
:
bool
=
False
,
)
->
Rigid
:
if
(
t
.
shape
[
-
1
]
!=
7
):
raise
ValueError
(
"Incorrectly shaped input tensor"
)
quats
,
trans
=
t
[...,
:
4
],
t
[...,
4
:]
rots
=
Rotation
(
rot_mats
=
None
,
quats
=
quats
,
normalize_quats
=
normalize_quats
)
return
Rigid
(
rots
,
trans
)
@
staticmethod
def
from_3_points
(
p_neg_x_axis
:
torch
.
Tensor
,
origin
:
torch
.
Tensor
,
p_xy_plane
:
torch
.
Tensor
,
eps
:
float
=
1e-8
)
->
Rigid
:
"""
Implements algorithm 21. Constructs transformations from sets of 3
points using the Gram-Schmidt algorithm.
Args:
p_neg_x_axis: [*, 3] coordinates
origin: [*, 3] coordinates used as frame origins
p_xy_plane: [*, 3] coordinates
eps: Small epsilon value
Returns:
A transformation object of shape [*]
"""
p_neg_x_axis
=
torch
.
unbind
(
p_neg_x_axis
,
dim
=-
1
)
origin
=
torch
.
unbind
(
origin
,
dim
=-
1
)
p_xy_plane
=
torch
.
unbind
(
p_xy_plane
,
dim
=-
1
)
e0
=
[
c1
-
c2
for
c1
,
c2
in
zip
(
origin
,
p_neg_x_axis
)]
e1
=
[
c1
-
c2
for
c1
,
c2
in
zip
(
p_xy_plane
,
origin
)]
denom
=
torch
.
sqrt
(
sum
((
c
*
c
for
c
in
e0
))
+
eps
)
e0
=
[
c
/
denom
for
c
in
e0
]
dot
=
sum
((
c1
*
c2
for
c1
,
c2
in
zip
(
e0
,
e1
)))
e1
=
[
c2
-
c1
*
dot
for
c1
,
c2
in
zip
(
e0
,
e1
)]
denom
=
torch
.
sqrt
(
sum
((
c
*
c
for
c
in
e1
))
+
eps
)
e1
=
[
c
/
denom
for
c
in
e1
]
e2
=
[
e0
[
1
]
*
e1
[
2
]
-
e0
[
2
]
*
e1
[
1
],
e0
[
2
]
*
e1
[
0
]
-
e0
[
0
]
*
e1
[
2
],
e0
[
0
]
*
e1
[
1
]
-
e0
[
1
]
*
e1
[
0
],
]
rots
=
torch
.
stack
([
c
for
tup
in
zip
(
e0
,
e1
,
e2
)
for
c
in
tup
],
dim
=-
1
)
rots
=
rots
.
reshape
(
rots
.
shape
[:
-
1
]
+
(
3
,
3
))
rot_obj
=
Rotation
(
rot_mats
=
rots
,
quats
=
None
)
return
Rigid
(
rot_obj
,
torch
.
stack
(
origin
,
dim
=-
1
))
def
unsqueeze
(
self
,
dim
:
int
,
)
->
Rigid
:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if
dim
>=
len
(
self
.
shape
):
raise
ValueError
(
"Invalid dimension"
)
rots
=
self
.
_rots
.
unsqueeze
(
dim
)
trans
=
self
.
_trans
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
1
)
return
Rigid
(
rots
,
trans
)
@
staticmethod
def
cat
(
ts
:
Sequence
[
Rigid
],
dim
:
int
,
)
->
Rigid
:
"""
Concatenates transformations along a new dimension.
Args:
ts:
A list of T objects
dim:
The dimension along which the transformations should be
concatenated
Returns:
A concatenated transformation object
"""
rots
=
Rotation
.
cat
([
t
.
_rots
for
t
in
ts
],
dim
)
trans
=
torch
.
cat
(
[
t
.
_trans
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
1
)
return
Rigid
(
rots
,
trans
)
def
apply_rot_fn
(
self
,
fn
:
Callable
[
Rotation
,
Rotation
])
->
Rigid
:
"""
Applies a Rotation -> Rotation function to the stored rotation
object.
Args:
fn: A function of type Rotation -> Rotation
Returns:
A transformation object with a transformed rotation.
"""
return
Rigid
(
fn
(
self
.
_rots
),
self
.
_trans
)
def
apply_trans_fn
(
self
,
fn
:
Callable
[
torch
.
Tensor
,
torch
.
Tensor
])
->
Rigid
:
"""
Applies a Tensor -> Tensor function to the stored translation.
Args:
fn:
A function of type Tensor -> Tensor to be applied to the
translation
Returns:
A transformation object with a transformed translation.
"""
return
Rigid
(
self
.
_rots
,
fn
(
self
.
_trans
))
def
scale_translation
(
self
,
trans_scale_factor
:
float
)
->
Rigid
:
"""
Scales the translation by a constant factor.
Args:
trans_scale_factor:
The constant factor
Returns:
A transformation object with a scaled translation.
"""
fn
=
lambda
t
:
t
*
trans_scale_factor
return
self
.
apply_trans_fn
(
fn
)
def
stop_rot_gradient
(
self
)
->
Rigid
:
"""
Detaches the underlying rotation object
Returns:
A transformation object with detached rotations
"""
fn
=
lambda
r
:
r
.
detach
()
return
self
.
apply_rot_fn
(
fn
)
@
staticmethod
def
make_transform_from_reference
(
n_xyz
,
ca_xyz
,
c_xyz
,
eps
=
1e-20
):
"""
Returns a transformation object from reference coordinates.
Note that this method does not take care of symmetries. If you
provide the atom positions in the non-standard way, the N atom will
end up not at [-0.527250, 1.359329, 0.0] but instead at
[-0.527250, -1.359329, 0.0]. You need to take care of such cases in
your code.
Args:
n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
c_xyz: A [*, 3] tensor of carbon xyz coordinates.
Returns:
A transformation object. After applying the translation and
rotation to the reference backbone, the coordinates will
approximately equal to the input coordinates.
"""
translation
=
-
1
*
ca_xyz
n_xyz
=
n_xyz
+
translation
c_xyz
=
c_xyz
+
translation
c_x
,
c_y
,
c_z
=
[
c_xyz
[...,
i
]
for
i
in
range
(
3
)]
norm
=
torch
.
sqrt
(
eps
+
c_x
**
2
+
c_y
**
2
)
sin_c1
=
-
c_y
/
norm
cos_c1
=
c_x
/
norm
zeros
=
sin_c1
.
new_zeros
(
sin_c1
.
shape
)
ones
=
sin_c1
.
new_ones
(
sin_c1
.
shape
)
c1_rots
=
sin_c1
.
new_zeros
((
*
sin_c1
.
shape
,
3
,
3
))
c1_rots
[...,
0
,
0
]
=
cos_c1
c1_rots
[...,
0
,
1
]
=
-
1
*
sin_c1
c1_rots
[...,
1
,
0
]
=
sin_c1
c1_rots
[...,
1
,
1
]
=
cos_c1
c1_rots
[...,
2
,
2
]
=
1
norm
=
torch
.
sqrt
(
eps
+
c_x
**
2
+
c_y
**
2
+
c_z
**
2
)
sin_c2
=
c_z
/
norm
cos_c2
=
torch
.
sqrt
(
c_x
**
2
+
c_y
**
2
)
/
norm
c2_rots
=
sin_c2
.
new_zeros
((
*
sin_c2
.
shape
,
3
,
3
))
c2_rots
[...,
0
,
0
]
=
cos_c2
c2_rots
[...,
0
,
2
]
=
sin_c2
c2_rots
[...,
1
,
1
]
=
1
c2_rots
[...,
2
,
0
]
=
-
1
*
sin_c2
c2_rots
[...,
2
,
2
]
=
cos_c2
c_rots
=
rot_matmul
(
c2_rots
,
c1_rots
)
n_xyz
=
rot_vec_mul
(
c_rots
,
n_xyz
)
_
,
n_y
,
n_z
=
[
n_xyz
[...,
i
]
for
i
in
range
(
3
)]
norm
=
torch
.
sqrt
(
eps
+
n_y
**
2
+
n_z
**
2
)
sin_n
=
-
n_z
/
norm
cos_n
=
n_y
/
norm
n_rots
=
sin_c2
.
new_zeros
((
*
sin_c2
.
shape
,
3
,
3
))
n_rots
[...,
0
,
0
]
=
1
n_rots
[...,
1
,
1
]
=
cos_n
n_rots
[...,
1
,
2
]
=
-
1
*
sin_n
n_rots
[...,
2
,
1
]
=
sin_n
n_rots
[...,
2
,
2
]
=
cos_n
rots
=
rot_matmul
(
n_rots
,
c_rots
)
rots
=
rots
.
transpose
(
-
1
,
-
2
)
translation
=
-
1
*
translation
rot_obj
=
Rotation
(
rot_mats
=
rots
,
quats
=
None
)
return
Rigid
(
rot_obj
,
translation
)
def
cuda
(
self
)
->
Rigid
:
"""
Moves the transformation object to GPU memory
Returns:
A version of the transformation on GPU
"""
return
Rigid
(
self
.
_rots
.
cuda
(),
self
.
_trans
.
cuda
())
fastfold/utils/superimposition.py
0 → 100644
View file @
b14e47f4
# Copyright 2023 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from Bio.SVDSuperimposer import SVDSuperimposer

import torch


def _superimpose_np(reference, coords):
    """
    Superimposes coordinates onto a reference by minimizing RMSD using SVD.

    Args:
        reference:
            [N, 3] reference array
        coords:
            [N, 3] array
    Returns:
        A tuple of [N, 3] superimposed coords and the final RMSD.
    """
    sup = SVDSuperimposer()
    sup.set(reference, coords)
    sup.run()
    return sup.get_transformed(), sup.get_rms()


def _superimpose_single(reference, coords):
    reference_np = reference.detach().cpu().numpy()
    coords_np = coords.detach().cpu().numpy()
    superimposed, rmsd = _superimpose_np(reference_np, coords_np)
    return coords.new_tensor(superimposed), coords.new_tensor(rmsd)


def superimpose(reference, coords, mask):
    """
    Superimposes coordinates onto a reference by minimizing RMSD using SVD.

    Args:
        reference:
            [*, N, 3] reference tensor
        coords:
            [*, N, 3] tensor
        mask:
            [*, N] tensor
    Returns:
        A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
    """
    def select_unmasked_coords(coords, mask):
        return torch.masked_select(
            coords,
            (mask > 0.)[..., None],
        ).reshape(-1, 3)

    batch_dims = reference.shape[:-2]
    flat_reference = reference.reshape((-1,) + reference.shape[-2:])
    flat_coords = coords.reshape((-1,) + reference.shape[-2:])
    flat_mask = mask.reshape((-1,) + mask.shape[-1:])
    superimposed_list = []
    rmsds = []
    for r, c, m in zip(flat_reference, flat_coords, flat_mask):
        r_unmasked_coords = select_unmasked_coords(r, m)
        c_unmasked_coords = select_unmasked_coords(c, m)
        superimposed, rmsd = _superimpose_single(
            r_unmasked_coords, c_unmasked_coords
        )

        # This is very inelegant, but idk how else to invert the masking
        # procedure.
        count = 0
        superimposed_full_size = torch.zeros_like(r)
        for i, unmasked in enumerate(m):
            if(unmasked):
                superimposed_full_size[i] = superimposed[count]
                count += 1

        superimposed_list.append(superimposed_full_size)
        rmsds.append(rmsd)

    superimposed_stacked = torch.stack(superimposed_list, dim=0)
    rmsds_stacked = torch.stack(rmsds, dim=0)

    superimposed_reshaped = superimposed_stacked.reshape(
        batch_dims + coords.shape[-2:]
    )
    rmsds_reshaped = rmsds_stacked.reshape(batch_dims)

    return superimposed_reshaped, rmsds_reshaped
\ No newline at end of file
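An illustrative usage of superimpose() above, not part of the committed file: align a noisy copy of a structure onto its reference and report per-batch RMSDs over unmasked positions. Requires Biopython for SVDSuperimposer; the shapes and values are made up.

import torch
from fastfold.utils.superimposition import superimpose

reference = torch.randn(2, 50, 3)                      # [*, N, 3] reference coordinates
coords = reference + 0.1 * torch.randn_like(reference)  # perturbed copy to align
mask = torch.ones(2, 50)                                # use every residue

aligned, rmsd = superimpose(reference, coords, mask)
print(aligned.shape, rmsd.shape)   # torch.Size([2, 50, 3]) torch.Size([2])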
fastfold/utils/tensor_utils.py
0 → 100644
View file @
b14e47f4
# Copyright 2022 HPC-AI Tech Inc
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional

import fastfold.habana as habana


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


def flatten_final_dims(t: torch.Tensor, no_dims: int):
    return t.reshape(t.shape[:-no_dims] + (-1,))


def masked_mean(mask, value, dim, eps=1e-4):
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
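An illustrative check of masked_mean(), not part of the committed file: only positions where the mask is nonzero contribute to the average, and eps keeps fully-masked slices finite.

import torch
from fastfold.utils.tensor_utils import masked_mean

value = torch.tensor([[1., 2., 3., 4.]])
mask = torch.tensor([[1., 1., 0., 0.]])
# Sum of unmasked values is 3; denominator is 2 + eps, so the result is ~1.5.
print(masked_mean(mask, value, dim=-1))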
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    boundaries = torch.linspace(
        min_bin, max_bin, no_bins - 1, device=pts.device
    )
    dists = torch.sqrt(
        torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
    )
    return torch.bucketize(dists, boundaries)


def dict_multimap(fn, dicts):
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            # when bs = 1, returns [...] rather than [1, ...]
            new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0]

    return new_dict


def one_hot(x, v_bins):
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    return data[ranges]
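An illustrative check of batched_gather(), not part of the committed file: gather a per-batch set of indices along the dimension that follows the batch dimensions.

import torch
from fastfold.utils.tensor_utils import batched_gather

data = torch.arange(12).reshape(2, 3, 4)    # [batch=2, n=3, c=4]
inds = torch.tensor([[0, 2], [1, 1]])       # per-batch index sets, shape [2, 2]
out = batched_gather(data, inds, dim=1, no_batch_dims=1)
print(out.shape)                             # torch.Size([2, 2, 4])
print(out[0, 0], out[1, 0])                  # data[0, 0] and data[1, 1]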
# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


def tree_map(fn, tree, leaf_type):
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple([tree_map(fn, x, leaf_type) for x in tree])
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        print(type(tree))
        raise ValueError("Not supported")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
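An illustrative check of tensor_tree_map(), not part of the committed file: apply a function to every tensor leaf of a nested dict/list structure while preserving the nesting.

import torch
from fastfold.utils.tensor_utils import tensor_tree_map

batch = {"msa": torch.zeros(4, 8), "pair": [torch.zeros(8, 8), torch.zeros(8, 8)]}
shifted = tensor_tree_map(lambda t: t + 1, batch)
print(shifted["msa"].mean())   # tensor(1.)
print(type(shifted["pair"]))   # list; the nesting is preserved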
def
_fetch_dims
(
tree
):
shapes
=
[]
tree_type
=
type
(
tree
)
if
tree_type
is
dict
:
for
v
in
tree
.
values
():
shapes
.
extend
(
_fetch_dims
(
v
))
elif
tree_type
is
list
or
tree_type
is
tuple
:
for
t
in
tree
:
shapes
.
extend
(
_fetch_dims
(
t
))
elif
tree_type
is
torch
.
Tensor
:
shapes
.
append
(
tree
.
shape
)
else
:
raise
ValueError
(
"Not supported"
)
return
shapes
@
torch
.
jit
.
ignore
def
_flat_idx_to_idx
(
flat_idx
:
int
,
dims
:
Tuple
[
int
],
)
->
Tuple
[
int
]:
idx
=
[]
for
d
in
reversed
(
dims
):
idx
.
append
(
flat_idx
%
d
)
flat_idx
=
flat_idx // d

    return tuple(reversed(idx))


@torch.jit.ignore
def _get_minimal_slice_set(
    start: Sequence[int],
    end: Sequence[int],
    dims: int,
    start_edges: Optional[Sequence[bool]] = None,
    end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
    """
    Produces an ordered sequence of tensor slices that, when used in
    sequence on a tensor with shape dims, yields tensors that contain every
    leaf in the contiguous range [start, end]. Care is taken to yield a
    short sequence of slices, and perhaps even the shortest possible (I'm
    pretty sure it's the latter).

    end is INCLUSIVE.
    """
    # start_edges and end_edges both indicate whether, starting from any given
    # dimension, the start/end index is at the top/bottom edge of the
    # corresponding tensor, modeled as a tree
    def reduce_edge_list(l):
        tally = 1
        for i in range(len(l)):
            reversed_idx = -1 * (i + 1)
            l[reversed_idx] *= tally
            tally = l[reversed_idx]

    if (start_edges is None):
        start_edges = [s == 0 for s in start]
        reduce_edge_list(start_edges)
    if (end_edges is None):
        end_edges = [e == (d - 1) for e, d in zip(end, dims)]
        reduce_edge_list(end_edges)

    # Base cases. Either start/end are empty and we're done, or the final,
    # one-dimensional tensor can be simply sliced
    if (len(start) == 0):
        return [tuple()]
    elif (len(start) == 1):
        return [(slice(start[0], end[0] + 1),)]

    slices = []
    path = []

    # Dimensions common to start and end can be selected directly
    for s, e in zip(start, end):
        if (s == e):
            path.append(slice(s, s + 1))
        else:
            break

    path = tuple(path)
    divergence_idx = len(path)

    # start == end, and we're done
    if (divergence_idx == len(dims)):
        return [tuple(path)]

    def upper():
        sdi = start[divergence_idx]
        return [
            path + (slice(sdi, sdi + 1),) + s
            for s in _get_minimal_slice_set(
                start[divergence_idx + 1:],
                [d - 1 for d in dims[divergence_idx + 1:]],
                dims[divergence_idx + 1:],
                start_edges=start_edges[divergence_idx + 1:],
                end_edges=[1 for _ in end_edges[divergence_idx + 1:]],
            )
        ]

    def lower():
        edi = end[divergence_idx]
        return [
            path + (slice(edi, edi + 1),) + s
            for s in _get_minimal_slice_set(
                [0 for _ in start[divergence_idx + 1:]],
                end[divergence_idx + 1:],
                dims[divergence_idx + 1:],
                start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
                end_edges=end_edges[divergence_idx + 1:],
            )
        ]

    # If both start and end are at the edges of the subtree rooted at
    # divergence_idx, we can just select the whole subtree at once
    if (start_edges[divergence_idx] and end_edges[divergence_idx]):
        slices.append(
            path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
        )
    # If just start is at the edge, we can grab almost all of the subtree,
    # treating only the ragged bottom edge as an edge case
    elif (start_edges[divergence_idx]):
        slices.append(
            path + (slice(start[divergence_idx], end[divergence_idx]),)
        )
        slices.extend(lower())
    # Analogous to the previous case, but the top is ragged this time
    elif (end_edges[divergence_idx]):
        slices.extend(upper())
        slices.append(
            path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
        )
    # If both sides of the range are ragged, we need to handle both sides
    # separately. If there's contiguous meat in between them, we can index it
    # in one big chunk
    else:
        slices.extend(upper())
        middle_ground = end[divergence_idx] - start[divergence_idx]
        if (middle_ground > 1):
            slices.append(
                path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
            )
        slices.extend(lower())

    return [tuple(s) for s in slices]
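To make the return value concrete, here is a small hand-checked sketch (illustrative only, not part of the commit): the flat range [2, 10] of a tensor with batch shape (3, 4) decomposes into three rectangular slices.

# Illustrative only: flat indices 2..10 (inclusive) of a (3, 4) batch.
# _flat_idx_to_idx(2, (3, 4)) == (0, 2) and _flat_idx_to_idx(10, (3, 4)) == (2, 2).
slices = _get_minimal_slice_set(start=[0, 2], end=[2, 2], dims=[3, 4])
t = torch.arange(12).view(3, 4)
covered = torch.cat([t[s].reshape(-1) for s in slices])
# covered == tensor([2, 3, 4, 5, 6, 7, 8, 9, 10]), i.e. exactly the flat range.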
@torch.jit.ignore
def _chunk_slice(
    t: torch.Tensor,
    flat_start: int,
    flat_end: int,
    no_batch_dims: int,
) -> torch.Tensor:
    """
    Equivalent to

        t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

    but without the need for the initial reshape call, which can be
    memory-intensive in certain situations. The only reshape operations
    in this function are performed on sub-tensors that scale with
    (flat_end - flat_start), the chunk size.
    """
    batch_dims = t.shape[:no_batch_dims]
    start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
    # _get_minimal_slice_set is inclusive
    end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))

    # Get an ordered list of slices to perform
    slices = _get_minimal_slice_set(
        start_idx,
        end_idx,
        batch_dims,
    )

    sliced_tensors = [t[s] for s in slices]

    return torch.cat(
        [s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
    )
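A quick equivalence check, illustrative rather than part of the commit, showing that the function reproduces the reshape-then-slice result described in its docstring:

# Illustrative only: _chunk_slice matches flattening the batch dims and slicing.
t = torch.randn(3, 4, 5)
chunk = _chunk_slice(t, flat_start=2, flat_end=10, no_batch_dims=2)
reference = t.reshape((-1,) + t.shape[2:])[2:10]
assert torch.equal(chunk, reference)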
def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
    low_mem: bool = False,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are assumed to be simple "pytrees,"
    consisting only of (arbitrarily nested) lists, tuples, and dicts with
    torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must
            be tensors and must share the same batch dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch
            dimensions are specified, a "sub-batch" is defined as a single
            indexing of all batch dimensions simultaneously (s.t. the
            number of sub-batches is the product of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can
            be considered batch dimensions.
        low_mem:
            Avoids flattening potentially large input tensors. Unnecessary
            in most cases, and is ever so slightly slower than the default
            setting.
    Returns:
        The reassembled output of the layer on the inputs.
    """
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")

    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    def _prep_inputs(t):
        # TODO: make this more memory efficient. This sucks
        if (not low_mem):
            if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
                t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
            t = t.reshape(-1, *t.shape[no_batch_dims:])
        else:
            t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
        return t

    prepped_inputs = tensor_tree_map(_prep_inputs, inputs)

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)

    i = 0
    out = None
    for _ in range(no_chunks):
        # Chunk the input
        if (not low_mem):
            select_chunk = (
                lambda t: t[i: i + chunk_size] if t.shape[0] != 1 else t
            )
        else:
            select_chunk = (
                partial(
                    _chunk_slice,
                    flat_start=i,
                    flat_end=min(flat_batch_dim, i + chunk_size),
                    no_batch_dims=len(orig_batch_dims),
                )
            )

        chunks = tensor_tree_map(select_chunk, prepped_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        out_type = type(output_chunk)
        if out_type is dict:
            def assign(d1, d2):
                for k, v in d1.items():
                    if type(v) is dict:
                        assign(v, d2[k])
                    else:
                        v[i: i + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif out_type is tuple:
            for x1, x2 in zip(out, output_chunk):
                x1[i: i + chunk_size] = x2
        elif out_type is torch.Tensor:
            out[i: i + chunk_size] = output_chunk
        else:
            raise ValueError("Not supported")

        i += chunk_size

    reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
    out = tensor_tree_map(reshape, out)

    if habana.is_habana():
        import habana_frameworks.torch.core as htcore
        htcore.mark_step()

    return out
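For orientation, a minimal hypothetical use of chunk_layer (illustrative only, not part of the commit): a toy element-wise layer applied over two batch dimensions in chunks of four sub-batches, which should reproduce the unchunked result.

# Illustrative only: the chunked result equals running the layer in one shot.
toy_layer = lambda x, y: {"sum": x + y}
x = torch.randn(6, 8, 16)
y = torch.randn(6, 8, 16)
out = chunk_layer(toy_layer, {"x": x, "y": y}, chunk_size=4, no_batch_dims=2)
assert torch.allclose(out["sum"], x + y)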
fastfold/utils/test_utils.py 0 → 100644

import os
import random

import torch
import numpy as np


def get_param_path():
    # develop
    if os.path.exists('/data/scratch/alphafold/alphafold/params/params_model_1.npz'):
        return '/data/scratch/alphafold/alphafold/params/params_model_1.npz'
    # test
    return '/data/scratch/fastfold/weight.npz'


def get_data_path():
    # develop
    if os.path.exists('/home/lczxl/data2/fastfold/example_input/mono_batch.pkl'):
        return '/home/lczxl/data2/fastfold/example_input/mono_batch.pkl'
    # test
    return '/data/scratch/fastfold/mono_batch.pkl'


def get_train_data_path():
    return '/data/scratch/fastfold/std_train_batch.pkl'


def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
\ No newline at end of file
fastfold/utils/validation_utils.py 0 → 100644
# Copyright 2023 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from fastfold.model.hub.loss import lddt_ca
from fastfold.common import residue_constants
from fastfold.utils.superimposition import superimpose


def drmsd(structure_1, structure_2, mask=None):
    def prep_d(structure):
        d = structure[..., :, None, :] - structure[..., None, :, :]
        d = d ** 2
        d = torch.sqrt(torch.sum(d, dim=-1))
        return d

    d1 = prep_d(structure_1)
    d2 = prep_d(structure_2)

    drmsd = d1 - d2
    drmsd = drmsd ** 2
    if (mask is not None):
        drmsd = drmsd * (mask[..., None] * mask[..., None, :])
    drmsd = torch.sum(drmsd, dim=(-1, -2))
    n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
    drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
    drmsd = torch.sqrt(drmsd)

    return drmsd


def drmsd_np(structure_1, structure_2, mask=None):
    structure_1 = torch.tensor(structure_1)
    structure_2 = torch.tensor(structure_2)
    if (mask is not None):
        mask = torch.tensor(mask)

    return drmsd(structure_1, structure_2, mask)
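As a quick property check (illustrative only, not part of the commit), the dRMSD of a structure with itself is exactly zero:

# Illustrative only: identical structures have zero dRMSD.
coords = torch.randn(16, 3)
mask = torch.ones(16)
assert float(drmsd(coords, coords, mask=mask)) == 0.0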
def gdt(p1, p2, mask, cutoffs):
    n = torch.sum(mask, dim=-1)

    p1 = p1.float()
    p2 = p2.float()
    distances = torch.sqrt(torch.sum((p1 - p2) ** 2, dim=-1))
    scores = []
    for c in cutoffs:
        score = torch.sum((distances <= c) * mask, dim=-1) / n
        score = torch.mean(score)
        scores.append(score)

    return sum(scores) / len(scores)


def gdt_ts(p1, p2, mask):
    return gdt(p1, p2, mask, [1., 2., 4., 8.])


def gdt_ha(p1, p2, mask):
    return gdt(p1, p2, mask, [0.5, 1., 2., 4.])
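A small sanity sketch, illustrative rather than part of the commit: perfectly superimposed CA coordinates hit every distance cutoff and score 1.0 under both GDT variants.

# Illustrative only: identical coordinates score 1.0 for GDT_TS and GDT_HA.
ca = torch.randn(32, 3)
ca_mask = torch.ones(32)
assert float(gdt_ts(ca, ca, ca_mask)) == 1.0
assert float(gdt_ha(ca, ca, ca_mask)) == 1.0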
def compute_validation_metrics(
    batch,
    outputs,
    superimposition_metrics=False,
):
    metrics = {}

    gt_coords = batch["all_atom_positions"]
    pred_coords = outputs["final_atom_positions"]
    all_atom_mask = batch["all_atom_mask"]

    # This is super janky for superimposition. Fix later
    gt_coords_masked = gt_coords * all_atom_mask[..., None]
    pred_coords_masked = pred_coords * all_atom_mask[..., None]
    ca_pos = residue_constants.atom_order["CA"]
    gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
    pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
    all_atom_mask_ca = all_atom_mask[..., ca_pos]

    lddt_ca_score = lddt_ca(
        pred_coords,
        gt_coords,
        all_atom_mask,
        eps=1e-8,
        per_residue=False,
    )

    metrics["lddt_ca"] = lddt_ca_score

    drmsd_ca_score = drmsd(
        pred_coords_masked_ca,
        gt_coords_masked_ca,
        mask=all_atom_mask_ca,  # still required here to compute n
    )

    metrics["drmsd_ca"] = drmsd_ca_score

    if (superimposition_metrics):
        superimposed_pred, alignment_rmsd = superimpose(
            gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
        )
        gdt_ts_score = gdt_ts(superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca)
        gdt_ha_score = gdt_ha(superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca)

        metrics["alignment_rmsd"] = alignment_rmsd
        metrics["gdt_ts"] = gdt_ts_score
        metrics["gdt_ha"] = gdt_ha_score

    return metrics
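A hypothetical call for context (the key names follow the function body above; the batch shapes, including the 37-entry atom dimension, are assumptions made for illustration, not part of the commit):

# Illustrative only: random coordinates, one chain of 64 residues, 37 atom types.
batch = {
    "all_atom_positions": torch.randn(1, 64, 37, 3),
    "all_atom_mask": torch.ones(1, 64, 37),
}
outputs = {"final_atom_positions": torch.randn(1, 64, 37, 3)}
metrics = compute_validation_metrics(batch, outputs, superimposition_metrics=False)
# metrics now holds "lddt_ca" and "drmsd_ca" tensors.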
fastfold/workflow/__init__.py 0 → 100644

from .workflow_run import batch_run
\ No newline at end of file
fastfold/workflow/factory/__init__.py 0 → 100644

from .task_factory import TaskFactory
from .hhblits import HHBlitsFactory
from .hhsearch import HHSearchFactory
from .jackhmmer import JackHmmerFactory
from .hhfilter import HHfilterFactory
from .hmmsearch import HmmSearchFactory
\ No newline at end of file
fastfold/workflow/factory/hhblits.py 0 → 100644

from typing import List

import ray
from ray.dag.function_node import FunctionNode

from fastfold.workflow.factory import TaskFactory
import fastfold.data.tools.hhblits as ffHHBlits


class HHBlitsFactory(TaskFactory):

    keywords = ['binary_path', 'databases', 'n_cpu']

    def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode] = None) -> FunctionNode:

        self.isReady()

        # setup runner
        runner = ffHHBlits.HHBlits(**self.config)

        # generate function node
        @ray.remote
        def hhblits_node_func(after: List[FunctionNode]) -> None:
            result = runner.query(fasta_path)
            with open(output_path, 'w') as f:
                f.write(result['a3m'])

        return hhblits_node_func.bind(after)
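A hypothetical wiring example, not part of the commit: it assumes TaskFactory accepts its settings as a config dict keyed by the keywords listed above, and the binary and database paths are placeholders.

# Illustrative only: build a Ray DAG node that runs HHBlits on one FASTA file.
factory = HHBlitsFactory(config={
    "binary_path": "/usr/bin/hhblits",                   # assumed install location
    "databases": ["/data/uniclust30/UniRef30_2021_03"],  # assumed database prefix
    "n_cpu": 4,
})
node = factory.gen_node("query.fasta", "query_uniref30.a3m", after=[])
# Executing the DAG (e.g. node.execute()) would write the resulting A3M to disk.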
fastfold/workflow/factory/hhfilter.py 0 → 100644

import subprocess
import logging
from typing import List

import ray
from ray.dag.function_node import FunctionNode

from fastfold.workflow.factory import TaskFactory


class HHfilterFactory(TaskFactory):

    keywords = ['binary_path']

    def gen_node(self, fasta_path: str, output_path: str, after: List[FunctionNode] = None) -> FunctionNode:

        self.isReady()

        # generate function node
        @ray.remote
        def hhfilter_node_func(after: List[FunctionNode]) -> None:
            cmd = [
                self.config.get('binary_path'),
            ]
            if 'id' in self.config:
                cmd += ['-id', str(self.config.get('id'))]
            if 'cov' in self.config:
                cmd += ['-cov', str(self.config.get('cov'))]
            cmd += ['-i', fasta_path, '-o', output_path]
            # cmd is an argument list, so pass it to subprocess.run directly;
            # running a list through shell=True would drop every element after
            # the binary path.
            subprocess.run(cmd)

        return hhfilter_node_func.bind(after)
\ No newline at end of file
fastfold/workflow/factory/hhsearch.py 0 → 100644

from typing import List
import inspect

import ray
from ray.dag.function_node import FunctionNode

import fastfold.data.tools.hhsearch as ffHHSearch
from fastfold.workflow.factory import TaskFactory


class HHSearchFactory(TaskFactory):

    keywords = ['binary_path', 'databases', 'n_cpu']

    def gen_node(self, a3m_path: str, output_path: str, atab_path: str = None, after: List[FunctionNode] = None) -> FunctionNode:

        self.isReady()

        params = {
            k: self.config.get(k)
            for k in inspect.getfullargspec(ffHHSearch.HHSearch.__init__).kwonlyargs
            if self.config.get(k)
        }
        # setup runner with a filtered config dict
        runner = ffHHSearch.HHSearch(**params)

        # generate function node
        @ray.remote
        def hhsearch_node_func(after: List[FunctionNode]) -> None:
            with open(a3m_path, "r") as f:
                a3m = f.read()
            if atab_path:
                hhsearch_result, atab = runner.query(a3m, gen_atab=True)
            else:
                hhsearch_result = runner.query(a3m)
            with open(output_path, "w") as f:
                f.write(hhsearch_result)
            if atab_path:
                with open(atab_path, "w") as f:
                    f.write(atab)

        return hhsearch_node_func.bind(after)
fastfold/workflow/factory/hmmsearch.py 0 → 100644

from typing import List
import inspect

import ray
from ray.dag.function_node import FunctionNode

from fastfold.data.tools import hmmsearch, hmmbuild
from fastfold.data import parsers
from fastfold.workflow.factory import TaskFactory
from typing import Optional


class HmmSearchFactory(TaskFactory):

    keywords = ['binary_path', 'hmmbuild_binary_path', 'database_path', 'n_cpu']

    def gen_node(self, msa_sto_path: str, output_dir: Optional[str] = None, after: List[FunctionNode] = None) -> FunctionNode:

        self.isReady()

        params = {
            k: self.config.get(k)
            for k in inspect.getfullargspec(hmmsearch.Hmmsearch.__init__).kwonlyargs
            if self.config.get(k)
        }
        # setup runner with a filtered config dict
        runner = hmmsearch.Hmmsearch(**params)

        # generate function node
        @ray.remote
        def hmmsearch_node_func(after: List[FunctionNode]) -> None:
            with open(msa_sto_path, "r") as f:
                msa_sto = f.read()
            msa_sto = parsers.deduplicate_stockholm_msa(msa_sto)
            msa_sto = parsers.remove_empty_columns_from_stockholm_msa(msa_sto)
            hmmsearch_result = runner.query(msa_sto, output_dir=output_dir)

        return hmmsearch_node_func.bind(after)