so3.py 3.77 KB
Newer Older
yukang's avatar
yukang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""SO(3) group transformations."""

import kornia.geometry.conversions as C
import torch
from torch import Tensor
from math import pi as PI


@torch.jit.script
def quat_to_mat(quat_wxyz: Tensor) -> Tensor:
    """Turn scalar-first (wxyz) quaternions into rotation matrices.

    Delegates to kornia's `quaternion_to_rotation_matrix`, explicitly
    requesting the WXYZ coefficient order.

    Args:
        quat_wxyz: (...,4) Quaternions laid out as (w, x, y, z).

    Returns:
        (...,3,3) The corresponding 3D rotation matrices.
    """
    rot_mat: Tensor = C.quaternion_to_rotation_matrix(
        quat_wxyz, order=C.QuaternionCoeffOrder.WXYZ
    )
    return rot_mat


# @torch.jit.script
def mat_to_quat(mat: Tensor) -> Tensor:
    """Turn rotation matrices into scalar-first (wxyz) quaternions.

    Delegates to kornia's `rotation_matrix_to_quaternion`, explicitly
    requesting the WXYZ coefficient order.

    Args:
        mat: (...,3,3) 3D rotation matrices.

    Returns:
        (...,4) Quaternions laid out as (w, x, y, z).
    """
    quat_wxyz: Tensor = C.rotation_matrix_to_quaternion(
        mat, order=C.QuaternionCoeffOrder.WXYZ
    )
    return quat_wxyz


@torch.jit.script
def quat_to_xyz(
    quat_wxyz: Tensor, singularity_value: float = PI / 2
) -> Tensor:
    """Convert scalar first quaternion to Tait-Bryan angles.

    Reference:
        https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Source_code_2

    Args:
        quat_wxyz: (...,4) Scalar first quaternions.
        singularity_value: Magnitude assigned to the pitch wherever the
            sine of the pitch reaches or exceeds 1 in absolute value
            (gimbal lock). The sign follows the sign of that sine term.

    Returns:
        (...,3) The Tait-Bryan angles --- roll, pitch, and yaw --- in the
        same dtype and on the same device as the input.
    """
    qw = quat_wxyz[..., 0]
    qx = quat_wxyz[..., 1]
    qy = quat_wxyz[..., 2]
    qz = quat_wxyz[..., 3]

    # roll (x-axis rotation)
    sinr_cosp = 2 * (qw * qx + qy * qz)
    cosr_cosp = 1 - 2 * (qx * qx + qy * qy)
    roll = torch.atan2(sinr_cosp, cosr_cosp)

    # pitch (y-axis rotation)
    sinp = 2 * (qw * qy - qz * qx)
    is_out_of_range = torch.abs(sinp) >= 1

    # Materialize the singular value in the input's own dtype/device.
    # A bare `torch.as_tensor(float)` is created in the default dtype
    # (float32), which silently rounds the constant before it is promoted
    # and therefore loses precision for float64 inputs.
    singular_pitch = torch.copysign(
        torch.full_like(sinp, singularity_value), sinp
    )

    # Feed `asin` a harmless argument at the singular entries so that the
    # branch discarded by `where` never yields NaN values or gradients.
    safe_sinp = torch.where(is_out_of_range, torch.zeros_like(sinp), sinp)
    pitch = torch.where(is_out_of_range, singular_pitch, torch.asin(safe_sinp))

    # yaw (z-axis rotation)
    siny_cosp = 2 * (qw * qz + qx * qy)
    cosy_cosp = 1 - 2 * (qy * qy + qz * qz)
    yaw = torch.atan2(siny_cosp, cosy_cosp)

    xyz = torch.stack([roll, pitch, yaw], dim=-1)
    return xyz


@torch.jit.script
def quat_to_yaw(quat_wxyz: Tensor) -> Tensor:
    """Extract the yaw angle from scalar first quaternions.

    The quaternions are first converted to roll, pitch, and yaw with
    `quat_to_xyz`; the last of those three components is returned.

    Reference:
        https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Source_code_2

    Args:
        quat_wxyz: (...,4) Scalar first quaternions.

    Returns:
        (...,) The rotation about the z-axis in radians.
    """
    roll_pitch_yaw = quat_to_xyz(quat_wxyz)
    # Components are stacked as (roll, pitch, yaw), so yaw sits at index 2.
    return roll_pitch_yaw[..., 2]


@torch.jit.script
def xyz_to_quat(xyz_rad: Tensor) -> Tensor:
    """Convert Tait-Bryan angles (roll, pitch, yaw) to scalar first quaternions.

    Args:
        xyz_rad: (...,3) Roll (x), pitch (y), and yaw (z) in radians,
            stored in that order along the last dimension.

    Returns:
        (...,4) Scalar first quaternions (wxyz).
    """
    # Every term of the conversion is expressed in half angles.
    half_roll = xyz_rad[..., 0] * 0.5
    half_pitch = xyz_rad[..., 1] * 0.5
    half_yaw = xyz_rad[..., 2] * 0.5

    cos_r = torch.cos(half_roll)
    sin_r = torch.sin(half_roll)
    cos_p = torch.cos(half_pitch)
    sin_p = torch.sin(half_pitch)
    cos_y = torch.cos(half_yaw)
    sin_y = torch.sin(half_yaw)

    # Roll/pitch products shared between the four quaternion components.
    cr_cp = cos_r * cos_p
    sr_sp = sin_r * sin_p
    sr_cp = sin_r * cos_p
    cr_sp = cos_r * sin_p

    w = cr_cp * cos_y + sr_sp * sin_y
    x = sr_cp * cos_y - cr_sp * sin_y
    y = cr_sp * cos_y + sr_cp * sin_y
    z = cr_cp * sin_y - sr_sp * cos_y
    return torch.stack([w, x, y, z], dim=-1)


@torch.jit.script
def yaw_to_quat(yaw_rad: Tensor) -> Tensor:
    """Build scalar first quaternions from yaw-only rotations.

    Roll and pitch are set to zero, the yaw is placed in the last Euler
    component, and the result is converted with `xyz_to_quat`.

    Args:
        yaw_rad: (...,) Rotations about the z-axis in radians.

    Returns:
        (...,4) Scalar first quaternions (wxyz); a trailing dimension of
        size 4 is appended to the shape of `yaw_rad`.
    """
    no_rotation = torch.zeros_like(yaw_rad)
    # Euler layout is (roll, pitch, yaw): only the yaw slot is non-zero.
    euler_xyz = torch.stack([no_rotation, no_rotation, yaw_rad], dim=-1)
    return xyz_to_quat(euler_xyz)