Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
6e68d6b0
"...deps/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "a0e1da037c17c4c5a7f990ab09f54d0a8f446994"
Commit
6e68d6b0
authored
Apr 29, 2022
by
Gustaf Ahdritz
Browse files
Add geometry functions to multimer
parent
dba44612
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
986 additions
and
0 deletions
+986
-0
openfold/utils/geometry/__init__.py
openfold/utils/geometry/__init__.py
+31
-0
openfold/utils/geometry/quat_rigid.py
openfold/utils/geometry/quat_rigid.py
+40
-0
openfold/utils/geometry/rigid_matrix_vector.py
openfold/utils/geometry/rigid_matrix_vector.py
+128
-0
openfold/utils/geometry/rotation_matrix.py
openfold/utils/geometry/rotation_matrix.py
+191
-0
openfold/utils/geometry/struct_of_array.py
openfold/utils/geometry/struct_of_array.py
+220
-0
openfold/utils/geometry/test_utils.py
openfold/utils/geometry/test_utils.py
+98
-0
openfold/utils/geometry/utils.py
openfold/utils/geometry/utils.py
+22
-0
openfold/utils/geometry/vector.py
openfold/utils/geometry/vector.py
+256
-0
No files found.
openfold/utils/geometry/__init__.py
0 → 100644
View file @
6e68d6b0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from
openfold.utils.geometry
import
rigid_matrix_vector
from
openfold.utils.geometry
import
rotation_matrix
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
vector
Rot3Array
=
rotation_matrix
.
Rot3Array
Rigid3Array
=
rigid_matrix_vector
.
Rigid3Array
StructOfArray
=
struct_of_array
.
StructOfArray
Vec3Array
=
vector
.
Vec3Array
square_euclidean_distance
=
vector
.
square_euclidean_distance
euclidean_distance
=
vector
.
euclidean_distance
dihedral_angle
=
vector
.
dihedral_angle
dot
=
vector
.
dot
cross
=
vector
.
cross
openfold/utils/geometry/quat_rigid.py
0 → 100644
View file @
6e68d6b0
import
torch
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
class
QuatRigid
(
nn
.
Module
):
def
__init__
(
self
,
c_hidden
,
full_quat
):
super
().
__init__
()
self
.
full_quat
=
full_quat
if
self
.
full_quat
:
rigid_dim
=
7
else
:
rigid_dim
=
6
self
.
linear
=
Linear
(
c_hidden
,
rigid_dim
)
def
forward
(
self
,
activations
:
torch
.
Tensor
)
->
Rigid3Array
:
# NOTE: During training, this needs to be run in higher precision
rigid_flat
=
self
.
linear
(
activations
.
to
(
torch
.
float32
))
print
(
rigid_flat
.
shape
)
rigid_flat
=
torch
.
unbind
(
rigid_flat
,
dim
=-
1
)
if
(
self
.
full_quat
):
qw
,
qx
,
qy
,
qz
=
rigid_flat
[:
4
]
translation
=
rigid_flat
[
4
:]
else
:
qx
,
qy
,
qz
=
rigid_flat
[:
3
]
qw
=
torch
.
ones_like
(
qx
)
translation
=
rigid_flat
[
3
:]
rotation
=
Rot3Array
.
from_quaternion
(
qw
,
qx
,
qy
,
qz
,
normalize
=
True
,
)
translation
=
Vec3Array
(
*
translation
)
return
Rigid3Array
(
rotation
,
translation
)
openfold/utils/geometry/rigid_matrix_vector.py
0 → 100644
View file @
6e68d6b0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from
__future__
import
annotations
import
dataclasses
from
typing
import
Union
,
List
import
torch
from
openfold.utils.geometry
import
rotation_matrix
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
vector
Float
=
Union
[
float
,
torch
.
Tensor
]
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
Rigid3Array
:
"""Rigid Transformation, i.e. element of special euclidean group."""
rotation
:
rotation_matrix
.
Rot3Array
translation
:
vector
.
Vec3Array
def
__matmul__
(
self
,
other
:
Rigid3Array
)
->
Rigid3Array
:
new_rotation
=
self
.
rotation
@
other
.
rotation
# __matmul__
new_translation
=
self
.
apply_to_point
(
other
.
translation
)
return
Rigid3Array
(
new_rotation
,
new_translation
)
def
__getitem__
(
self
,
index
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
[
index
],
self
.
translation
[
index
],
)
def
__mul__
(
self
,
other
:
torch
.
Tensor
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
*
other
,
self
.
translation
*
other
,
)
def
map_tensor_fn
(
self
,
fn
)
->
Rigid3Array
:
return
Rigid3Array
(
self
.
rotation
.
map_tensor_fn
(
fn
),
self
.
translation
.
map_tensor_fn
(
fn
),
)
def
inverse
(
self
)
->
Rigid3Array
:
"""Return Rigid3Array corresponding to inverse transform."""
inv_rotation
=
self
.
rotation
.
inverse
()
inv_translation
=
inv_rotation
.
apply_to_point
(
-
self
.
translation
)
return
Rigid3Array
(
inv_rotation
,
inv_translation
)
def
apply_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply Rigid3Array transform to point."""
return
self
.
rotation
.
apply_to_point
(
point
)
+
self
.
translation
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Apply inverse Rigid3Array transform to point."""
new_point
=
point
-
self
.
translation
return
self
.
rotation
.
apply_inverse_to_point
(
new_point
)
def
compose_rotation
(
self
,
other_rotation
):
rot
=
self
.
rotation
@
other_rotation
return
Rigid3Array
(
rot
,
trans
.
clone
())
@
classmethod
def
identity
(
cls
,
shape
,
device
)
->
Rigid3Array
:
"""Return identity Rigid3Array of given shape."""
return
cls
(
rotation_matrix
.
Rot3Array
.
identity
(
shape
,
device
),
vector
.
Vec3Array
.
zeros
(
shape
,
device
)
)
@
classmethod
def
cat
(
cls
,
rigids
:
List
[
Rigid3Array
],
dim
:
int
)
->
Rigid3Array
:
return
cls
(
Rot3Array
.
cat
([
r
.
rotation
for
r
in
rigids
],
dim
=
dim
),
Vec3Array
.
cat
([
r
.
translation
for
r
in
rigids
],
dim
=
dim
),
)
def
scale_translation
(
self
,
factor
:
Float
)
->
Rigid3Array
:
"""Scale translation in Rigid3Array by 'factor'."""
return
Rigid3Array
(
self
.
rotation
,
self
.
translation
*
factor
)
def
to_array
(
self
):
rot_array
=
self
.
rotation
.
to_array
()
vec_array
=
self
.
translation
.
to_array
()
return
torch
.
cat
([
rot_array
,
vec_array
[...,
None
]],
dim
=-
1
)
def
reshape
(
self
,
new_shape
)
->
Rigid3Array
:
rots
=
self
.
rotation
.
reshape
(
new_shape
)
trans
=
self
.
translation
.
reshape
(
new_shape
)
return
Rigid3Aray
(
rots
,
trans
)
@
classmethod
def
from_array
(
cls
,
array
):
rot
=
rotation_matrix
.
Rot3Array
.
from_array
(
array
[...,
:
3
])
vec
=
vector
.
Vec3Array
.
from_array
(
array
[...,
-
1
])
return
cls
(
rot
,
vec
)
@
classmethod
def
from_tensor_4x4
(
cls
,
array
):
return
cls
.
from_array
(
array
)
@
classmethod
def
from_array4x4
(
cls
,
array
:
torch
.
tensor
)
->
Rigid3Array
:
"""Construct Rigid3Array from homogeneous 4x4 array."""
rotation
=
rotation_matrix
.
Rot3Array
(
array
[...,
0
,
0
],
array
[...,
0
,
1
],
array
[...,
0
,
2
],
array
[...,
1
,
0
],
array
[...,
1
,
1
],
array
[...,
1
,
2
],
array
[...,
2
,
0
],
array
[...,
2
,
1
],
array
[...,
2
,
2
]
)
translation
=
vector
.
Vec3Array
(
array
[...,
0
,
3
],
array
[...,
1
,
3
],
array
[...,
2
,
3
])
return
cls
(
rotation
,
translation
)
openfold/utils/geometry/rotation_matrix.py
0 → 100644
View file @
6e68d6b0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rot3Array Matrix Class."""
from
__future__
import
annotations
import
dataclasses
import
torch
import
numpy
as
np
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
utils
from
openfold.utils.geometry
import
vector
COMPONENTS
=
[
'xx'
,
'xy'
,
'xz'
,
'yx'
,
'yy'
,
'yz'
,
'zx'
,
'zy'
,
'zz'
]
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
Rot3Array
:
"""Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""
xx
:
torch
.
Tensor
=
dataclasses
.
field
(
metadata
=
{
'dtype'
:
torch
.
float32
})
xy
:
torch
.
Tensor
xz
:
torch
.
Tensor
yx
:
torch
.
Tensor
yy
:
torch
.
Tensor
yz
:
torch
.
Tensor
zx
:
torch
.
Tensor
zy
:
torch
.
Tensor
zz
:
torch
.
Tensor
__array_ufunc__
=
None
def
__getitem__
(
self
,
index
):
field_names
=
utils
.
get_field_names
(
Rot3Array
)
return
Rot3Array
(
**
{
name
:
getattr
(
self
,
name
)[
index
]
for
name
in
field_names
}
)
def
__mul__
(
self
,
other
:
torch
.
Tensor
):
field_names
=
utils
.
get_field_names
(
Rot3Array
)
return
Rot3Array
(
**
{
name
:
getattr
(
self
,
name
)
*
other
for
name
in
field_names
}
)
def
map_tensor_fn
(
self
,
fn
)
->
Rot3Array
:
field_names
=
utils
.
get_field_names
(
Rot3Array
)
return
Rot3Array
(
**
{
name
:
fn
(
getattr
(
self
,
name
))
for
name
in
field_names
}
)
def
inverse
(
self
)
->
Rot3Array
:
"""Returns inverse of Rot3Array."""
return
Rot3Array
(
self
.
xx
,
self
.
yx
,
self
.
zx
,
self
.
xy
,
self
.
yy
,
self
.
zy
,
self
.
xz
,
self
.
yz
,
self
.
zz
)
def
apply_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Applies Rot3Array to point."""
return
vector
.
Vec3Array
(
self
.
xx
*
point
.
x
+
self
.
xy
*
point
.
y
+
self
.
xz
*
point
.
z
,
self
.
yx
*
point
.
x
+
self
.
yy
*
point
.
y
+
self
.
yz
*
point
.
z
,
self
.
zx
*
point
.
x
+
self
.
zy
*
point
.
y
+
self
.
zz
*
point
.
z
)
def
apply_inverse_to_point
(
self
,
point
:
vector
.
Vec3Array
)
->
vector
.
Vec3Array
:
"""Applies inverse Rot3Array to point."""
return
self
.
inverse
().
apply_to_point
(
point
)
def
__matmul__
(
self
,
other
:
Rot3Array
)
->
Rot3Array
:
"""Composes two Rot3Arrays."""
c0
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xx
,
other
.
yx
,
other
.
zx
))
c1
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xy
,
other
.
yy
,
other
.
zy
))
c2
=
self
.
apply_to_point
(
vector
.
Vec3Array
(
other
.
xz
,
other
.
yz
,
other
.
zz
))
return
Rot3Array
(
c0
.
x
,
c1
.
x
,
c2
.
x
,
c0
.
y
,
c1
.
y
,
c2
.
y
,
c0
.
z
,
c1
.
z
,
c2
.
z
)
@
classmethod
def
identity
(
cls
,
shape
,
device
)
->
Rot3Array
:
"""Returns identity of given shape."""
ones
=
torch
.
ones
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
)
zeros
=
torch
.
zeros
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
)
return
cls
(
ones
,
zeros
,
zeros
,
zeros
,
ones
,
zeros
,
zeros
,
zeros
,
ones
)
@
classmethod
def
from_two_vectors
(
cls
,
e0
:
vector
.
Vec3Array
,
e1
:
vector
.
Vec3Array
)
->
Rot3Array
:
"""Construct Rot3Array from two Vectors.
Rot3Array is constructed such that in the corresponding frame 'e0' lies on
the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.
Args:
e0: Vector
e1: Vector
Returns:
Rot3Array
"""
# Normalize the unit vector for the x-axis, e0.
e0
=
e0
.
normalized
()
# make e1 perpendicular to e0.
c
=
e1
.
dot
(
e0
)
e1
=
(
e1
-
c
*
e0
).
normalized
()
# Compute e2 as cross product of e0 and e1.
e2
=
e0
.
cross
(
e1
)
return
cls
(
e0
.
x
,
e1
.
x
,
e2
.
x
,
e0
.
y
,
e1
.
y
,
e2
.
y
,
e0
.
z
,
e1
.
z
,
e2
.
z
)
@
classmethod
def
from_array
(
cls
,
array
:
torch
.
Tensor
)
->
Rot3Array
:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
return
cls
(
torch
.
unbind
(
array
,
dim
=-
2
))
def
to_array
(
self
)
->
torch
.
Tensor
:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return
torch
.
stack
(
[
torch
.
stack
([
self
.
xx
,
self
.
xy
,
self
.
xz
],
dim
=-
1
),
torch
.
stack
([
self
.
yx
,
self
.
yy
,
self
.
yz
],
dim
=-
1
),
torch
.
stack
([
self
.
zx
,
self
.
zy
,
self
.
zz
],
dim
=-
1
)
],
dim
=-
2
)
@
classmethod
def
from_quaternion
(
cls
,
w
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
normalize
:
bool
=
True
,
eps
:
float
=
1e-6
)
->
Rot3Array
:
"""Construct Rot3Array from components of quaternion."""
if
normalize
:
inv_norm
=
torch
.
rsqrt
(
eps
+
w
**
2
+
x
**
2
+
y
**
2
+
z
**
2
)
w
*=
inv_norm
x
*=
inv_norm
y
*=
inv_norm
z
*=
inv_norm
xx
=
1
-
2
*
(
y
**
2
+
z
**
2
)
xy
=
2
*
(
x
*
y
-
w
*
z
)
xz
=
2
*
(
x
*
z
+
w
*
y
)
yx
=
2
*
(
x
*
y
+
w
*
z
)
yy
=
1
-
2
*
(
x
**
2
+
z
**
2
)
yz
=
2
*
(
y
*
z
-
w
*
x
)
zx
=
2
*
(
x
*
z
-
w
*
y
)
zy
=
2
*
(
y
*
z
+
w
*
x
)
zz
=
1
-
2
*
(
x
**
2
+
y
**
2
)
return
cls
(
xx
,
xy
,
xz
,
yx
,
yy
,
yz
,
zx
,
zy
,
zz
)
def
reshape
(
self
,
new_shape
):
field_names
=
utils
.
get_field_names
(
Rot3Array
)
reshape_fn
=
lambda
t
:
t
.
reshape
(
new_shape
)
return
Rot3Array
(
**
{
name
:
reshape_fn
(
getattr
(
self
,
name
))
for
name
in
field_names
}
)
@
classmethod
def
cat
(
cls
,
rots
:
List
[
Rot3Array
],
dim
:
int
)
->
Rot3Array
:
field_names
=
utils
.
get_field_names
(
Rot3Array
)
cat_fn
=
lambda
l
:
torch
.
cat
(
l
,
dim
=
dim
)
return
cls
(
**
{
name
:
cat_fn
([
getattr
(
r
,
name
)
for
r
in
rots
])
for
name
in
field_names
}
)
openfold/utils/geometry/struct_of_array.py
0 → 100644
View file @
6e68d6b0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class decorator to represent (nested) struct of arrays."""
import
dataclasses
import
jax
def
get_item
(
instance
,
key
):
sliced
=
{}
for
field
in
get_array_fields
(
instance
):
num_trailing_dims
=
field
.
metadata
.
get
(
'num_trailing_dims'
,
0
)
this_key
=
key
if
isinstance
(
key
,
tuple
)
and
Ellipsis
in
this_key
:
this_key
+=
(
slice
(
None
),)
*
num_trailing_dims
sliced
[
field
.
name
]
=
getattr
(
instance
,
field
.
name
)[
this_key
]
return
dataclasses
.
replace
(
instance
,
**
sliced
)
@
property
def
get_shape
(
instance
):
"""Returns Shape for given instance of dataclass."""
first_field
=
dataclasses
.
fields
(
instance
)[
0
]
num_trailing_dims
=
first_field
.
metadata
.
get
(
'num_trailing_dims'
,
None
)
value
=
getattr
(
instance
,
first_field
.
name
)
if
num_trailing_dims
:
return
value
.
shape
[:
-
num_trailing_dims
]
else
:
return
value
.
shape
def
get_len
(
instance
):
"""Returns length for given instance of dataclass."""
shape
=
instance
.
shape
if
shape
:
return
shape
[
0
]
else
:
raise
TypeError
(
'len() of unsized object'
)
# Match jax.numpy behavior.
@
property
def
get_dtype
(
instance
):
"""Returns Dtype for given instance of dataclass."""
fields
=
dataclasses
.
fields
(
instance
)
sets_dtype
=
[
field
.
name
for
field
in
fields
if
field
.
metadata
.
get
(
'sets_dtype'
,
False
)
]
if
sets_dtype
:
assert
len
(
sets_dtype
)
==
1
,
'at most field can set dtype'
field_value
=
getattr
(
instance
,
sets_dtype
[
0
])
elif
instance
.
same_dtype
:
field_value
=
getattr
(
instance
,
fields
[
0
].
name
)
else
:
# Should this be Value Error?
raise
AttributeError
(
'Trying to access Dtype on Struct of Array without'
'either "same_dtype" or field setting dtype'
)
if
hasattr
(
field_value
,
'dtype'
):
return
field_value
.
dtype
else
:
# Should this be Value Error?
raise
AttributeError
(
f
'field_value
{
field_value
}
does not have dtype'
)
def
replace
(
instance
,
**
kwargs
):
return
dataclasses
.
replace
(
instance
,
**
kwargs
)
def
post_init
(
instance
):
"""Validate instance has same shapes & dtypes."""
array_fields
=
get_array_fields
(
instance
)
arrays
=
list
(
get_array_fields
(
instance
,
return_values
=
True
).
values
())
first_field
=
array_fields
[
0
]
# These slightly weird constructions about checking whether the leaves are
# actual arrays is since e.g. vmap internally relies on being able to
# construct pytree's with object() as leaves, this would break the checking
# as such we are only validating the object when the entries in the dataclass
# Are arrays or other dataclasses of arrays.
try
:
dtype
=
instance
.
dtype
except
AttributeError
:
dtype
=
None
if
dtype
is
not
None
:
first_shape
=
instance
.
shape
for
array
,
field
in
zip
(
arrays
,
array_fields
):
field_shape
=
array
.
shape
num_trailing_dims
=
field
.
metadata
.
get
(
'num_trailing_dims'
,
None
)
if
num_trailing_dims
:
array_shape
=
array
.
shape
field_shape
=
array_shape
[:
-
num_trailing_dims
]
msg
=
(
f
'field
{
field
}
should have number of trailing dims'
' {num_trailing_dims}'
)
assert
len
(
array_shape
)
==
len
(
first_shape
)
+
num_trailing_dims
,
msg
else
:
field_shape
=
array
.
shape
shape_msg
=
(
f
"Stripped Shape
{
field_shape
}
of field
{
field
}
doesn't "
f
"match shape
{
first_shape
}
of field
{
first_field
}
"
)
assert
field_shape
==
first_shape
,
shape_msg
field_dtype
=
array
.
dtype
allowed_metadata_dtypes
=
field
.
metadata
.
get
(
'allowed_dtypes'
,
[])
if
allowed_metadata_dtypes
:
msg
=
f
'Dtype is
{
field_dtype
}
but must be in
{
allowed_metadata_dtypes
}
'
assert
field_dtype
in
allowed_metadata_dtypes
,
msg
if
'dtype'
in
field
.
metadata
:
target_dtype
=
field
.
metadata
[
'dtype'
]
else
:
target_dtype
=
dtype
msg
=
f
'Dtype is
{
field_dtype
}
but must be
{
target_dtype
}
'
assert
field_dtype
==
target_dtype
,
msg
def
flatten
(
instance
):
"""Flatten Struct of Array instance."""
array_likes
=
list
(
get_array_fields
(
instance
,
return_values
=
True
).
values
())
flat_array_likes
=
[]
inner_treedefs
=
[]
num_arrays
=
[]
for
array_like
in
array_likes
:
flat_array_like
,
inner_treedef
=
jax
.
tree_flatten
(
array_like
)
inner_treedefs
.
append
(
inner_treedef
)
flat_array_likes
+=
flat_array_like
num_arrays
.
append
(
len
(
flat_array_like
))
metadata
=
get_metadata_fields
(
instance
,
return_values
=
True
)
metadata
=
type
(
instance
).
metadata_cls
(
**
metadata
)
return
flat_array_likes
,
(
inner_treedefs
,
metadata
,
num_arrays
)
def
make_metadata_class
(
cls
):
metadata_fields
=
get_fields
(
cls
,
lambda
x
:
x
.
metadata
.
get
(
'is_metadata'
,
False
))
metadata_cls
=
dataclasses
.
make_dataclass
(
cls_name
=
'Meta'
+
cls
.
__name__
,
fields
=
[(
field
.
name
,
field
.
type
,
field
)
for
field
in
metadata_fields
],
frozen
=
True
,
eq
=
True
)
return
metadata_cls
def
get_fields
(
cls_or_instance
,
filterfn
,
return_values
=
False
):
fields
=
dataclasses
.
fields
(
cls_or_instance
)
fields
=
[
field
for
field
in
fields
if
filterfn
(
field
)]
if
return_values
:
return
{
field
.
name
:
getattr
(
cls_or_instance
,
field
.
name
)
for
field
in
fields
}
else
:
return
fields
def
get_array_fields
(
cls
,
return_values
=
False
):
return
get_fields
(
cls
,
lambda
x
:
not
x
.
metadata
.
get
(
'is_metadata'
,
False
),
return_values
=
return_values
)
def
get_metadata_fields
(
cls
,
return_values
=
False
):
return
get_fields
(
cls
,
lambda
x
:
x
.
metadata
.
get
(
'is_metadata'
,
False
),
return_values
=
return_values
)
class
StructOfArray
:
"""Class Decorator for Struct Of Arrays."""
def
__init__
(
self
,
same_dtype
=
True
):
self
.
same_dtype
=
same_dtype
def
__call__
(
self
,
cls
):
cls
.
__array_ufunc__
=
None
cls
.
replace
=
replace
cls
.
same_dtype
=
self
.
same_dtype
cls
.
dtype
=
get_dtype
cls
.
shape
=
get_shape
cls
.
__len__
=
get_len
cls
.
__getitem__
=
get_item
cls
.
__post_init__
=
post_init
new_cls
=
dataclasses
.
dataclass
(
cls
,
frozen
=
True
,
eq
=
False
)
# pytype: disable=wrong-keyword-args
# pytree claims to require metadata to be hashable, not sure why,
# But making derived dataclass that can just hold metadata
new_cls
.
metadata_cls
=
make_metadata_class
(
new_cls
)
def
unflatten
(
aux
,
data
):
inner_treedefs
,
metadata
,
num_arrays
=
aux
array_fields
=
[
field
.
name
for
field
in
get_array_fields
(
new_cls
)]
value_dict
=
{}
array_start
=
0
for
num_array
,
inner_treedef
,
array_field
in
zip
(
num_arrays
,
inner_treedefs
,
array_fields
):
value_dict
[
array_field
]
=
jax
.
tree_unflatten
(
inner_treedef
,
data
[
array_start
:
array_start
+
num_array
])
array_start
+=
num_array
metadata_fields
=
get_metadata_fields
(
new_cls
)
for
field
in
metadata_fields
:
value_dict
[
field
.
name
]
=
getattr
(
metadata
,
field
.
name
)
return
new_cls
(
**
value_dict
)
jax
.
tree_util
.
register_pytree_node
(
nodetype
=
new_cls
,
flatten_func
=
flatten
,
unflatten_func
=
unflatten
)
return
new_cls
openfold/utils/geometry/test_utils.py
0 → 100644
View file @
6e68d6b0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utils for tests."""
import
dataclasses
from
alphafold.model.geometry
import
rigid_matrix_vector
from
alphafold.model.geometry
import
rotation_matrix
from
alphafold.model.geometry
import
vector
import
jax.numpy
as
jnp
import
numpy
as
np
def
assert_rotation_matrix_equal
(
matrix1
:
rotation_matrix
.
Rot3Array
,
matrix2
:
rotation_matrix
.
Rot3Array
):
for
field
in
dataclasses
.
fields
(
rotation_matrix
.
Rot3Array
):
field
=
field
.
name
np
.
testing
.
assert_array_equal
(
getattr
(
matrix1
,
field
),
getattr
(
matrix2
,
field
))
def
assert_rotation_matrix_close
(
mat1
:
rotation_matrix
.
Rot3Array
,
mat2
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
mat1
.
to_array
(),
mat2
.
to_array
(),
6
)
def
assert_array_equal_to_rotation_matrix
(
array
:
jnp
.
ndarray
,
matrix
:
rotation_matrix
.
Rot3Array
):
"""Check that array and Matrix match."""
np
.
testing
.
assert_array_equal
(
matrix
.
xx
,
array
[...,
0
,
0
])
np
.
testing
.
assert_array_equal
(
matrix
.
xy
,
array
[...,
0
,
1
])
np
.
testing
.
assert_array_equal
(
matrix
.
xz
,
array
[...,
0
,
2
])
np
.
testing
.
assert_array_equal
(
matrix
.
yx
,
array
[...,
1
,
0
])
np
.
testing
.
assert_array_equal
(
matrix
.
yy
,
array
[...,
1
,
1
])
np
.
testing
.
assert_array_equal
(
matrix
.
yz
,
array
[...,
1
,
2
])
np
.
testing
.
assert_array_equal
(
matrix
.
zx
,
array
[...,
2
,
0
])
np
.
testing
.
assert_array_equal
(
matrix
.
zy
,
array
[...,
2
,
1
])
np
.
testing
.
assert_array_equal
(
matrix
.
zz
,
array
[...,
2
,
2
])
def
assert_array_close_to_rotation_matrix
(
array
:
jnp
.
ndarray
,
matrix
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
matrix
.
to_array
(),
array
,
6
)
def
assert_vectors_equal
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_equal
(
vec1
.
x
,
vec2
.
x
)
np
.
testing
.
assert_array_equal
(
vec1
.
y
,
vec2
.
y
)
np
.
testing
.
assert_array_equal
(
vec1
.
z
,
vec2
.
z
)
def
assert_vectors_close
(
vec1
:
vector
.
Vec3Array
,
vec2
:
vector
.
Vec3Array
):
np
.
testing
.
assert_allclose
(
vec1
.
x
,
vec2
.
x
,
atol
=
1e-6
,
rtol
=
0.
)
np
.
testing
.
assert_allclose
(
vec1
.
y
,
vec2
.
y
,
atol
=
1e-6
,
rtol
=
0.
)
np
.
testing
.
assert_allclose
(
vec1
.
z
,
vec2
.
z
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_close_to_vector
(
array
:
jnp
.
ndarray
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_allclose
(
vec
.
to_array
(),
array
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_equal_to_vector
(
array
:
jnp
.
ndarray
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_equal
(
vec
.
to_array
(),
array
)
def
assert_rigid_equal_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
rigid2
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rot_trans_equal_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
def
assert_rigid_close_to_rigid
(
rigid1
:
rigid_matrix_vector
.
Rigid3Array
,
rigid2
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rot_trans_close_to_rigid
(
rigid1
.
rotation
,
rigid1
.
translation
,
rigid2
)
def
assert_rot_trans_equal_to_rigid
(
rot
:
rotation_matrix
.
Rot3Array
,
trans
:
vector
.
Vec3Array
,
rigid
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rotation_matrix_equal
(
rot
,
rigid
.
rotation
)
assert_vectors_equal
(
trans
,
rigid
.
translation
)
def
assert_rot_trans_close_to_rigid
(
rot
:
rotation_matrix
.
Rot3Array
,
trans
:
vector
.
Vec3Array
,
rigid
:
rigid_matrix_vector
.
Rigid3Array
):
assert_rotation_matrix_close
(
rot
,
rigid
.
rotation
)
assert_vectors_close
(
trans
,
rigid
.
translation
)
openfold/utils/geometry/utils.py
0 → 100644
View file @
6e68d6b0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for geometry library."""
import
dataclasses
def
get_field_names
(
cls
):
fields
=
dataclasses
.
fields
(
cls
)
field_names
=
[
f
.
name
for
f
in
fields
]
return
field_names
openfold/utils/geometry/vector.py
0 → 100644
View file @
6e68d6b0
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vec3Array Class."""
from
__future__
import
annotations
import
dataclasses
from
typing
import
Union
,
List
import
torch
from
openfold.utils.geometry
import
utils
Float
=
Union
[
float
,
torch
.
Tensor
]
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
Vec3Array
:
x
:
torch
.
Tensor
=
dataclasses
.
field
(
metadata
=
{
'dtype'
:
torch
.
float32
})
y
:
torch
.
Tensor
z
:
torch
.
Tensor
def
__post_init__
(
self
):
if
hasattr
(
self
.
x
,
'dtype'
):
assert
self
.
x
.
dtype
==
self
.
y
.
dtype
assert
self
.
x
.
dtype
==
self
.
z
.
dtype
assert
all
([
x
==
y
for
x
,
y
in
zip
(
self
.
x
.
shape
,
self
.
y
.
shape
)])
assert
all
([
x
==
z
for
x
,
z
in
zip
(
self
.
x
.
shape
,
self
.
z
.
shape
)])
def
__add__
(
self
,
other
:
Vec3Array
)
->
Vec3Array
:
return
Vec3Array
(
self
.
x
+
other
.
x
,
self
.
y
+
other
.
y
,
self
.
z
+
other
.
z
,
)
def
__sub__
(
self
,
other
:
Vec3Array
)
->
Vec3Array
:
return
Vec3Array
(
self
.
x
-
other
.
x
,
self
.
y
-
other
.
y
,
self
.
z
-
other
.
z
,
)
def
__mul__
(
self
,
other
:
Float
)
->
Vec3Array
:
return
Vec3Array
(
self
.
x
*
other
,
self
.
y
*
other
,
self
.
z
*
other
,
)
def
__rmul__
(
self
,
other
:
Float
)
->
Vec3Array
:
return
self
*
other
def
__truediv__
(
self
,
other
:
Float
)
->
Vec3Array
:
return
Vec3Array
(
self
.
x
/
other
,
self
.
y
/
other
,
self
.
z
/
other
,
)
def
__neg__
(
self
)
->
Vec3Array
:
return
self
*
-
1
def
__pos__
(
self
)
->
Vec3Array
:
return
self
*
1
def
__getitem__
(
self
,
index
)
->
Vec3Array
:
return
Vec3Array
(
self
.
x
[
index
],
self
.
y
[
index
],
self
.
z
[
index
],
)
def
__iter__
(
self
):
return
iter
((
self
.
x
,
self
.
y
,
self
.
z
))
@
property
def
shape
(
self
):
return
self
.
x
.
shape
def
map_tensor_fn
(
self
,
fn
)
->
Vec3Array
:
return
Vec3Array
(
fn
(
self
.
x
),
fn
(
self
.
y
),
fn
(
self
.
z
),
)
def
cross
(
self
,
other
:
Vec3Array
)
->
Vec3Array
:
"""Compute cross product between 'self' and 'other'."""
new_x
=
self
.
y
*
other
.
z
-
self
.
z
*
other
.
y
new_y
=
self
.
z
*
other
.
x
-
self
.
x
*
other
.
z
new_z
=
self
.
x
*
other
.
y
-
self
.
y
*
other
.
x
return
Vec3Array
(
new_x
,
new_y
,
new_z
)
def
dot
(
self
,
other
:
Vec3Array
)
->
Float
:
"""Compute dot product between 'self' and 'other'."""
return
self
.
x
*
other
.
x
+
self
.
y
*
other
.
y
+
self
.
z
*
other
.
z
def
norm
(
self
,
epsilon
:
float
=
1e-6
)
->
Float
:
"""Compute Norm of Vec3Array, clipped to epsilon."""
# To avoid NaN on the backward pass, we must use maximum before the sqrt
norm2
=
self
.
dot
(
self
)
if
epsilon
:
norm2
=
torch
.
clamp
(
norm2
,
max
=
epsilon
**
2
)
return
torch
.
sqrt
(
norm2
)
def
norm2
(
self
):
return
self
.
dot
(
self
)
def
normalized
(
self
,
epsilon
:
float
=
1e-6
)
->
Vec3Array
:
"""Return unit vector with optional clipping."""
return
self
/
self
.
norm
(
epsilon
)
def
clone
(
self
)
->
Vec3Array
:
return
Vec3Array
(
self
.
x
.
clone
(),
self
.
y
.
clone
(),
self
.
z
.
clone
(),
)
def
reshape
(
self
,
new_shape
)
->
Vec3Array
:
x
=
self
.
x
.
reshape
(
new_shape
)
y
=
self
.
y
.
reshape
(
new_shape
)
z
=
self
.
z
.
reshape
(
new_shape
)
return
Vec3Array
(
x
,
y
,
z
)
def
sum
(
self
,
dim
)
->
Vec3Array
:
return
Vec3Array
(
torch
.
sum
(
self
.
x
,
dim
=
dim
),
torch
.
sum
(
self
.
y
,
dim
=
dim
),
torch
.
sum
(
self
.
z
,
dim
=
dim
),
)
@
classmethod
def
zeros
(
cls
,
shape
,
device
=
"cpu"
):
"""Return Vec3Array corresponding to zeros of given shape."""
return
cls
(
torch
.
zeros
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
),
torch
.
zeros
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
),
torch
.
zeros
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
)
)
def
to_array
(
self
)
->
torch
.
Tensor
:
return
torch
.
stack
([
self
.
x
,
self
.
y
,
self
.
z
],
dim
=-
1
)
@
classmethod
def
from_tensor
(
cls
,
tensor
):
return
cls
(
*
torch
.
unbind
(
tensor
,
dim
=-
1
))
@
classmethod
def
cat
(
cls
,
vecs
:
List
[
Vec3Array
],
dim
:
int
)
->
Vec3Array
:
return
cls
(
torch
.
cat
([
v
.
x
for
v
in
vecs
],
dim
=
dim
),
torch
.
cat
([
v
.
y
for
v
in
vecs
],
dim
=
dim
),
torch
.
cat
([
v
.
z
for
v
in
vecs
],
dim
=
dim
),
)
def
square_euclidean_distance
(
vec1
:
Vec3Array
,
vec2
:
Vec3Array
,
epsilon
:
float
=
1e-6
)
->
Float
:
"""Computes square of euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute distance to
vec2: Vec3Array to compute distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of square euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
difference
=
vec1
-
vec2
distance
=
difference
.
dot
(
difference
)
if
epsilon
:
distance
=
torch
.
maximum
(
distance
,
epsilon
)
return
distance
def
dot
(
vector1
:
Vec3Array
,
vector2
:
Vec3Array
)
->
Float
:
return
vector1
.
dot
(
vector2
)
def
cross
(
vector1
:
Vec3Array
,
vector2
:
Vec3Array
)
->
Float
:
return
vector1
.
cross
(
vector2
)
def
norm
(
vector
:
Vec3Array
,
epsilon
:
float
=
1e-6
)
->
Float
:
return
vector
.
norm
(
epsilon
)
def
normalized
(
vector
:
Vec3Array
,
epsilon
:
float
=
1e-6
)
->
Vec3Array
:
return
vector
.
normalized
(
epsilon
)
def
euclidean_distance
(
vec1
:
Vec3Array
,
vec2
:
Vec3Array
,
epsilon
:
float
=
1e-6
)
->
Float
:
"""Computes euclidean distance between 'vec1' and 'vec2'.
Args:
vec1: Vec3Array to compute euclidean distance to
vec2: Vec3Array to compute euclidean distance from, should be
broadcast compatible with 'vec1'
epsilon: distance is clipped from below to be at least epsilon
Returns:
Array of euclidean distances;
shape will be result of broadcasting 'vec1' and 'vec2'
"""
distance_sq
=
square_euclidean_distance
(
vec1
,
vec2
,
epsilon
**
2
)
distance
=
torch
.
sqrt
(
distance_sq
)
return
distance
def
dihedral_angle
(
a
:
Vec3Array
,
b
:
Vec3Array
,
c
:
Vec3Array
,
d
:
Vec3Array
)
->
Float
:
"""Computes torsion angle for a quadruple of points.
For points (a, b, c, d), this is the angle between the planes defined by
points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
Arguments:
a: A Vec3Array of coordinates.
b: A Vec3Array of coordinates.
c: A Vec3Array of coordinates.
d: A Vec3Array of coordinates.
Returns:
A tensor of angles in radians: [-pi, pi].
"""
v1
=
a
-
b
v2
=
b
-
c
v3
=
d
-
c
c1
=
v1
.
cross
(
v2
)
c2
=
v3
.
cross
(
v2
)
c3
=
c2
.
cross
(
c1
)
v2_mag
=
v2
.
norm
()
return
torch
.
atan2
(
c3
.
dot
(
v2
),
v2_mag
*
c1
.
dot
(
c2
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment