Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d3acabd1
Commit
d3acabd1
authored
Nov 17, 2021
by
Gustaf Ahdritz
Browse files
Add documentation to affine_utils
parent
99e35628
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
355 additions
and
44 deletions
+355
-44
openfold/utils/affine_utils.py
openfold/utils/affine_utils.py
+355
-44
No files found.
openfold/utils/affine_utils.py
View file @
d3acabd1
...
@@ -13,11 +13,27 @@
...
@@ -13,11 +13,27 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
__future__
import
annotations
from
typing
import
Tuple
,
Any
,
Sequence
,
Callable
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
def
rot_matmul
(
a
,
b
):
def
rot_matmul
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Performs matrix multiplication of two rotation matrix tensors. Written
out by hand to avoid transfer to low-precision tensor cores.
Args:
a: [*, 3, 3] left multiplicand
b: [*, 3, 3] right multiplicand
Returns:
The product ab
"""
row_1
=
torch
.
stack
(
row_1
=
torch
.
stack
(
[
[
a
[...,
0
,
0
]
*
b
[...,
0
,
0
]
a
[...,
0
,
0
]
*
b
[...,
0
,
0
]
...
@@ -64,7 +80,20 @@ def rot_matmul(a, b):
...
@@ -64,7 +80,20 @@ def rot_matmul(a, b):
return
torch
.
stack
([
row_1
,
row_2
,
row_3
],
dim
=-
2
)
return
torch
.
stack
([
row_1
,
row_2
,
row_3
],
dim
=-
2
)
def
rot_vec_mul
(
r
,
t
):
def
rot_vec_mul
(
r
:
torch
.
Tensor
,
t
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Applies a rotation to a vector. Written out by hand to avoid transfer
to low-precision tensor cores.
Args:
r: [*, 3, 3] rotation matrices
t: [*, 3] coordinate tensors
Returns:
[*, 3] rotated coordinates
"""
x
=
t
[...,
0
]
x
=
t
[...,
0
]
y
=
t
[...,
1
]
y
=
t
[...,
1
]
z
=
t
[...,
2
]
z
=
t
[...,
2
]
...
@@ -79,21 +108,35 @@ def rot_vec_mul(r, t):
...
@@ -79,21 +108,35 @@ def rot_vec_mul(r, t):
class
T
:
class
T
:
def
__init__
(
self
,
rots
,
trans
):
"""
A class representing an affine transformation. Essentially a wrapper
around two torch tensors: a [*, 3, 3] rotation and a [*, 3]
translation. Designed to behave approximately like a single torch
tensor with the shape of the shared dimensions of its component parts.
"""
def
__init__
(
self
,
rots
:
torch
.
Tensor
,
trans
:
torch
.
Tensor
):
"""
Args:
rots: A [*, 3, 3] rotation tensor
trans: A corresponding [*, 3] translation tensor
"""
self
.
rots
=
rots
self
.
rots
=
rots
self
.
trans
=
trans
self
.
trans
=
trans
if
self
.
rots
is
None
and
self
.
trans
is
None
:
if
self
.
rots
is
None
and
self
.
trans
is
None
:
raise
ValueError
(
"Only one of rots and trans can be None"
)
raise
ValueError
(
"Only one of rots and trans can be None"
)
elif
self
.
rots
is
None
:
elif
self
.
rots
is
None
:
self
.
rots
=
T
.
identity_rot
(
self
.
rots
=
T
.
_
identity_rot
(
self
.
trans
.
shape
[:
-
1
],
self
.
trans
.
shape
[:
-
1
],
self
.
trans
.
dtype
,
self
.
trans
.
dtype
,
self
.
trans
.
device
,
self
.
trans
.
device
,
self
.
trans
.
requires_grad
,
self
.
trans
.
requires_grad
,
)
)
elif
self
.
trans
is
None
:
elif
self
.
trans
is
None
:
self
.
trans
=
T
.
identity_trans
(
self
.
trans
=
T
.
_
identity_trans
(
self
.
rots
.
shape
[:
-
2
],
self
.
rots
.
shape
[:
-
2
],
self
.
rots
.
dtype
,
self
.
rots
.
dtype
,
self
.
rots
.
device
,
self
.
rots
.
device
,
...
@@ -107,7 +150,28 @@ class T:
...
@@ -107,7 +150,28 @@ class T:
):
):
raise
ValueError
(
"Incorrectly shaped input"
)
raise
ValueError
(
"Incorrectly shaped input"
)
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
:
Any
,
)
->
T
:
"""
Indexes the affine transformation with PyTorch-style indices.
The index is applied to the shared dimensions of both the rotation
and the translation.
E.g.::
t = T(torch.rand(10, 10, 3, 3), torch.rand(10, 10, 3))
indexed = t[3, 4:6]
assert(indexed.shape == (2,))
assert(indexed.rots.shape == (2, 3, 3))
assert(indexed.trans.shape == (2, 3))
Args:
index: A standard torch tensor index. E.g. 8, (10, None, 3),
or (3, slice(0, 1, None))
Returns:
The indexed tensor
"""
if
type
(
index
)
!=
tuple
:
if
type
(
index
)
!=
tuple
:
index
=
(
index
,)
index
=
(
index
,)
return
T
(
return
T
(
...
@@ -115,32 +179,93 @@ class T:
...
@@ -115,32 +179,93 @@ class T:
self
.
trans
[
index
+
(
slice
(
None
),)],
self
.
trans
[
index
+
(
slice
(
None
),)],
)
)
def
__eq__
(
self
,
obj
):
def
__eq__
(
self
,
return
torch
.
all
(
self
.
rots
==
obj
.
rots
)
and
torch
.
all
(
obj
:
T
,
self
.
trans
==
obj
.
trans
)
->
bool
:
"""
Compares two affine transformations. Returns true iff the
transformations are pointwise identical. Does not account for
floating point imprecision.
"""
return
bool
(
torch
.
all
(
self
.
rots
==
obj
.
rots
)
and
torch
.
all
(
self
.
trans
==
obj
.
trans
)
)
)
def
__mul__
(
self
,
right
):
def
__mul__
(
self
,
right
:
torch
.
Tensor
,
)
->
T
:
"""
Pointwise right multiplication of the affine transformation with a
tensor. Multiplication is broadcast over the rotation/translation
dimensions.
Args:
right: The right multiplicand
Returns:
The product transformation
"""
rots
=
self
.
rots
*
right
[...,
None
,
None
]
rots
=
self
.
rots
*
right
[...,
None
,
None
]
trans
=
self
.
trans
*
right
[...,
None
]
trans
=
self
.
trans
*
right
[...,
None
]
return
T
(
rots
,
trans
)
return
T
(
rots
,
trans
)
def
__rmul__
(
self
,
left
):
def
__rmul__
(
self
,
left
:
torch
.
Tensor
,
)
->
T
:
"""
Pointwise left multiplication of the affine transformation with a
tensor. Multiplication is broadcast over the rotation/translation
dimensions.
Args:
left: The left multiplicand
Returns:
The product transformation
"""
return
self
.
__mul__
(
left
)
return
self
.
__mul__
(
left
)
@
property
@
property
def
shape
(
self
):
def
shape
(
self
)
->
torch
.
Size
:
"""
Returns the shape of the shared dimensions of the rotation and
the translation.
Returns:
The shape of the transformation
"""
s
=
self
.
rots
.
shape
[:
-
2
]
s
=
self
.
rots
.
shape
[:
-
2
]
return
s
if
len
(
s
)
>
0
else
torch
.
Size
([
1
])
return
s
if
len
(
s
)
>
0
else
torch
.
Size
([
1
])
def
get_trans
(
self
):
return
self
.
trans
def
get_rots
(
self
):
def
get_rots
(
self
):
"""
Getter for the rotation.
Returns:
The stored rotation.
"""
return
self
.
rots
return
self
.
rots
def
compose
(
self
,
t
):
def
get_trans
(
self
)
->
torch
.
Tensor
:
"""
Getter for the translation.
Returns:
The stored translation.
"""
return
self
.
trans
def
compose
(
self
,
t
:
T
,
)
->
T
:
"""
Composes the transformation with another.
Args:
t: The inner transformation.
Returns:
The composed transformation.
"""
rot_1
,
trn_1
=
self
.
rots
,
self
.
trans
rot_1
,
trn_1
=
self
.
rots
,
self
.
trans
rot_2
,
trn_2
=
t
.
rots
,
t
.
trans
rot_2
,
trn_2
=
t
.
rots
,
t
.
trans
...
@@ -149,23 +274,60 @@ class T:
...
@@ -149,23 +274,60 @@ class T:
return
T
(
rot
,
trn
)
return
T
(
rot
,
trn
)
def
apply
(
self
,
pts
):
def
apply
(
self
,
pts
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Applies the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor.
Returns:
The transformed points.
"""
r
,
t
=
self
.
rots
,
self
.
trans
r
,
t
=
self
.
rots
,
self
.
trans
rotated
=
rot_vec_mul
(
r
,
pts
)
rotated
=
rot_vec_mul
(
r
,
pts
)
return
rotated
+
t
return
rotated
+
t
def
invert_apply
(
self
,
pts
):
def
invert_apply
(
self
,
pts
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Applies the inverse of the transformation to a coordinate tensor.
Args:
pts: A [*, 3] coordinate tensor
Returns:
The transformed points.
"""
r
,
t
=
self
.
rots
,
self
.
trans
r
,
t
=
self
.
rots
,
self
.
trans
pts
=
pts
-
t
pts
=
pts
-
t
return
rot_vec_mul
(
r
.
transpose
(
-
1
,
-
2
),
pts
)
return
rot_vec_mul
(
r
.
transpose
(
-
1
,
-
2
),
pts
)
def
invert
(
self
):
def
invert
(
self
)
->
T
:
"""
Inverts the transformation.
Returns:
The inverse transformation.
"""
rot_inv
=
self
.
rots
.
transpose
(
-
1
,
-
2
)
rot_inv
=
self
.
rots
.
transpose
(
-
1
,
-
2
)
trn_inv
=
rot_vec_mul
(
rot_inv
,
self
.
trans
)
trn_inv
=
rot_vec_mul
(
rot_inv
,
self
.
trans
)
return
T
(
rot_inv
,
-
1
*
trn_inv
)
return
T
(
rot_inv
,
-
1
*
trn_inv
)
def
unsqueeze
(
self
,
dim
):
def
unsqueeze
(
self
,
dim
:
int
,
)
->
T
:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if
dim
>=
len
(
self
.
shape
):
if
dim
>=
len
(
self
.
shape
):
raise
ValueError
(
"Invalid dimension"
)
raise
ValueError
(
"Invalid dimension"
)
rots
=
self
.
rots
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
2
)
rots
=
self
.
rots
.
unsqueeze
(
dim
if
dim
>=
0
else
dim
-
2
)
...
@@ -174,7 +336,12 @@ class T:
...
@@ -174,7 +336,12 @@ class T:
return
T
(
rots
,
trans
)
return
T
(
rots
,
trans
)
@
staticmethod
@
staticmethod
def
identity_rot
(
shape
,
dtype
,
device
,
requires_grad
):
def
_identity_rot
(
shape
:
Tuple
[
int
],
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
requires_grad
:
bool
,
)
->
torch
.
Tensor
:
rots
=
torch
.
eye
(
rots
=
torch
.
eye
(
3
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
3
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
)
...
@@ -184,26 +351,68 @@ class T:
...
@@ -184,26 +351,68 @@ class T:
return
rots
return
rots
@
staticmethod
@
staticmethod
def
identity_trans
(
shape
,
dtype
,
device
,
requires_grad
):
def
_identity_trans
(
shape
:
Tuple
[
int
],
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
requires_grad
:
bool
)
->
torch
.
Tensor
:
trans
=
torch
.
zeros
(
trans
=
torch
.
zeros
(
(
*
shape
,
3
),
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
(
*
shape
,
3
),
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
)
return
trans
return
trans
@
staticmethod
@
staticmethod
def
identity
(
shape
,
dtype
,
device
,
requires_grad
=
True
):
def
identity
(
shape
:
Tuple
[
int
],
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
requires_grad
:
bool
=
True
)
->
T
:
"""
Constructs an identity transformation.
Args:
shape:
The desired shape
dtype:
The dtype of both internal tensors
device:
The device of both internal tensors
requires_grad:
Whether grad should be enabled for the internal tensors
Returns:
The identity transformation
"""
return
T
(
return
T
(
T
.
identity_rot
(
shape
,
dtype
,
device
,
requires_grad
),
T
.
_
identity_rot
(
shape
,
dtype
,
device
,
requires_grad
),
T
.
identity_trans
(
shape
,
dtype
,
device
,
requires_grad
),
T
.
_
identity_trans
(
shape
,
dtype
,
device
,
requires_grad
),
)
)
@
staticmethod
@
staticmethod
def
from_4x4
(
t
):
def
from_4x4
(
t
:
torch
.
Tensor
)
->
T
:
"""
Constructs a transformation from a homogenous transformation
tensor.
Args:
t: [*, 4, 4] homogenous transformation tensor
Returns:
T object with shape [*]
"""
rots
=
t
[...,
:
3
,
:
3
]
rots
=
t
[...,
:
3
,
:
3
]
trans
=
t
[...,
:
3
,
3
]
trans
=
t
[...,
:
3
,
3
]
return
T
(
rots
,
trans
)
return
T
(
rots
,
trans
)
def
to_4x4
(
self
):
def
to_4x4
(
self
)
->
torch
.
Tensor
:
"""
Converts a transformation to a homogenous transformation tensor.
Returns:
A [*, 4, 4] homogenous transformation tensor
"""
tensor
=
self
.
rots
.
new_zeros
((
*
self
.
shape
,
4
,
4
))
tensor
=
self
.
rots
.
new_zeros
((
*
self
.
shape
,
4
,
4
))
tensor
[...,
:
3
,
:
3
]
=
self
.
rots
tensor
[...,
:
3
,
:
3
]
=
self
.
rots
tensor
[...,
:
3
,
3
]
=
self
.
trans
tensor
[...,
:
3
,
3
]
=
self
.
trans
...
@@ -211,11 +420,37 @@ class T:
...
@@ -211,11 +420,37 @@ class T:
return
tensor
return
tensor
@
staticmethod
@
staticmethod
def
from_tensor
(
t
):
def
from_tensor
(
t
:
torch
.
Tensor
)
->
T
:
"""
Constructs a transformation from a homogenous transformation
tensor.
Args:
t: A [*, 4, 4] homogenous transformation tensor
Returns:
A transformation object with shape [*]
"""
return
T
.
from_4x4
(
t
)
return
T
.
from_4x4
(
t
)
@
staticmethod
@
staticmethod
def
from_3_points
(
p_neg_x_axis
,
origin
,
p_xy_plane
,
eps
=
1e-8
):
def
from_3_points
(
p_neg_x_axis
:
torch
.
Tensor
,
origin
:
torch
.
Tensor
,
p_xy_plane
:
torch
.
Tensor
,
eps
:
float
=
1e-8
)
->
T
:
"""
Implements algorithm 21. Constructs transformations from sets of 3
points using the Gram-Schmidt algorithm.
Args:
p_neg_x_axis: [*, 3] coordinates
origin: [*, 3] coordinates used as frame origins
p_xy_plane: [*, 3] coordinates
eps: Small epsilon value
Returns:
A transformation object of shape [*]
"""
p_neg_x_axis
=
torch
.
unbind
(
p_neg_x_axis
,
dim
=-
1
)
p_neg_x_axis
=
torch
.
unbind
(
p_neg_x_axis
,
dim
=-
1
)
origin
=
torch
.
unbind
(
origin
,
dim
=-
1
)
origin
=
torch
.
unbind
(
origin
,
dim
=-
1
)
p_xy_plane
=
torch
.
unbind
(
p_xy_plane
,
dim
=-
1
)
p_xy_plane
=
torch
.
unbind
(
p_xy_plane
,
dim
=-
1
)
...
@@ -241,7 +476,22 @@ class T:
...
@@ -241,7 +476,22 @@ class T:
return
T
(
rots
,
torch
.
stack
(
origin
,
dim
=-
1
))
return
T
(
rots
,
torch
.
stack
(
origin
,
dim
=-
1
))
@
staticmethod
@
staticmethod
def
concat
(
ts
,
dim
):
def
concat
(
ts
:
Sequence
[
T
],
dim
:
int
,
)
->
T
:
"""
Concatenates transformations along a new dimension.
Args:
ts:
A list of T objects
dim:
The dimension along which the transformations should be
concatenated
Returns:
A concatenated transformation object
"""
rots
=
torch
.
cat
([
t
.
rots
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
2
)
rots
=
torch
.
cat
([
t
.
rots
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
2
)
trans
=
torch
.
cat
(
trans
=
torch
.
cat
(
[
t
.
trans
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
1
[
t
.
trans
for
t
in
ts
],
dim
=
dim
if
dim
>=
0
else
dim
-
1
...
@@ -249,19 +499,24 @@ class T:
...
@@ -249,19 +499,24 @@ class T:
return
T
(
rots
,
trans
)
return
T
(
rots
,
trans
)
def
map_tensor_fn
(
self
,
fn
)
:
def
map_tensor_fn
(
self
,
fn
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
])
->
T
:
"""
"""
Apply a function that takes a tensor as its only argument to the
Apply a function that takes a tensor as its only argument to the
rotations and translations, treating the final two/one
rotations and translations, treating the final two/one
dimension(s), respectively, as batch dimensions.
dimension(s), respectively, as batch dimensions.
E.g.: Given t, an instance of T of shape [N, M], this function can
be used to sum out the second dimension thereof as follows::
E.g.: Given t, an instance of T of shape [N, M], this function can
t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
be used to sum out the second dimension thereof as follows:
t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
The resulting object has rotations of shape [N, 3, 3] and
translations of shape [N, 3]
The resulting object has rotations of shape [N, 3, 3] and
Args:
translations of shape [N, 3]
fn: A function that takes only a tensor as its argument
Returns:
The transformed transformation object.
"""
"""
rots
=
self
.
rots
.
view
(
*
self
.
rots
.
shape
[:
-
2
],
9
)
rots
=
self
.
rots
.
view
(
*
self
.
rots
.
shape
[:
-
2
],
9
)
rots
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
rots
,
-
1
))),
dim
=-
1
)
rots
=
torch
.
stack
(
list
(
map
(
fn
,
torch
.
unbind
(
rots
,
-
1
))),
dim
=-
1
)
...
@@ -271,14 +526,44 @@ class T:
...
@@ -271,14 +526,44 @@ class T:
return
T
(
rots
,
trans
)
return
T
(
rots
,
trans
)
def
stop_rot_gradient
(
self
):
def
stop_rot_gradient
(
self
)
->
T
:
"""
Detaches the contained rotation tensor.
Returns:
A version of the transformation with detached rotations
"""
return
T
(
self
.
rots
.
detach
(),
self
.
trans
)
return
T
(
self
.
rots
.
detach
(),
self
.
trans
)
def
scale_translation
(
self
,
factor
):
def
scale_translation
(
self
,
factor
:
int
)
->
T
:
"""
Scales the contained translation tensor by a constant factor.
Returns:
A version of the transformation with scaled translations
"""
return
T
(
self
.
rots
,
self
.
trans
*
factor
)
return
T
(
self
.
rots
,
self
.
trans
*
factor
)
@
staticmethod
@
staticmethod
def
make_transform_from_reference
(
n_xyz
,
ca_xyz
,
c_xyz
,
eps
=
1e-20
):
def
make_transform_from_reference
(
n_xyz
,
ca_xyz
,
c_xyz
,
eps
=
1e-20
):
"""
Returns a transformation object from reference coordinates.
Note that this method does not take care of symmetries. If you
provide the atom positions in the non-standard way, the N atom will
end up not at [-0.527250, 1.359329, 0.0] but instead at
[-0.527250, -1.359329, 0.0]. You need to take care of such cases in
your code.
Args:
n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
c_xyz: A [*, 3] tensor of carbon xyz coordinates.
Returns:
A transformation object. After applying the translation and
rotation to the reference backbone, the coordinates will
approximately equal to the input coordinates.
"""
translation
=
-
1
*
ca_xyz
translation
=
-
1
*
ca_xyz
n_xyz
=
n_xyz
+
translation
n_xyz
=
n_xyz
+
translation
c_xyz
=
c_xyz
+
translation
c_xyz
=
c_xyz
+
translation
...
@@ -330,7 +615,13 @@ class T:
...
@@ -330,7 +615,13 @@ class T:
return
T
(
rots
,
translation
)
return
T
(
rots
,
translation
)
def
cuda
(
self
):
def
cuda
(
self
)
->
T
:
"""
Moves the transformation object to GPU memory
Returns:
A version of the transformation on GPU
"""
return
T
(
self
.
rots
.
cuda
(),
self
.
trans
.
cuda
())
return
T
(
self
.
rots
.
cuda
(),
self
.
trans
.
cuda
())
...
@@ -361,7 +652,15 @@ _qtr_mat[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
...
@@ -361,7 +652,15 @@ _qtr_mat[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_qtr_mat
[...,
2
,
2
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
-
1
),
(
"cc"
,
-
1
),
(
"dd"
,
1
)])
_qtr_mat
[...,
2
,
2
]
=
_to_mat
([(
"aa"
,
1
),
(
"bb"
,
-
1
),
(
"cc"
,
-
1
),
(
"dd"
,
1
)])
def
quat_to_rot
(
quat
):
# [*, 4]
def
quat_to_rot
(
quat
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Converts a quaternion to a rotation matrix.
Args:
quat: [*, 4] quaternions
Returns:
[*, 3, 3] rotation matrices
"""
# [*, 4, 4]
# [*, 4, 4]
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
quat
=
quat
[...,
None
]
*
quat
[...,
None
,
:]
...
@@ -376,7 +675,19 @@ def quat_to_rot(quat): # [*, 4]
...
@@ -376,7 +675,19 @@ def quat_to_rot(quat): # [*, 4]
return
torch
.
sum
(
quat
,
dim
=
(
-
3
,
-
4
))
return
torch
.
sum
(
quat
,
dim
=
(
-
3
,
-
4
))
def
affine_vector_to_4x4
(
vector
):
def
affine_vector_to_4x4
(
vector
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Transforms a tensor whose final dimension has the form:
[*quaternion, *translation]
into a homogenous transformation tensor.
Args:
vector: [*, 7] input tensor
Returns:
[*, 4, 4] homogenous transformation tensor
"""
quats
=
vector
[...,
:
4
]
quats
=
vector
[...,
:
4
]
trans
=
vector
[...,
4
:]
trans
=
vector
[...,
4
:]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment