utils.py 1.94 KB
Newer Older
1
2
3
4
5
import numpy as np
import torch


def limit_period(val, offset=0.5, period=np.pi):
    """Limit the value into a period for periodic function.

    Args:
        val (torch.Tensor | np.ndarray): The value to be converted.
        offset (float, optional): Offset to set the value range.
            Defaults to 0.5.
        period (float, optional): Period of the value. Defaults to np.pi.

    Returns:
        torch.Tensor | np.ndarray: Value of the same kind as ``val`` (tensor
            or array), shifted by an integer multiple of ``period`` into the
            half-open range ``[-offset * period, (1 - offset) * period)``
            (for a positive ``period``, up to floating-point rounding).
    """
    # torch.floor does not accept numpy arrays, so dispatch on the input
    # type and keep the result in the caller's array library.
    if isinstance(val, np.ndarray):
        return val - np.floor(val / period + offset) * period
    return val - torch.floor(val / period + offset) * period


def rotation_3d_in_axis(points, angles, axis=0):
    """Rotate points by angles around the given axis.

    Args:
        points (torch.Tensor): Points of shape (N, M, 3).
        angles (torch.Tensor): Rotation angles in radians of shape (N,),
            one angle per point set in ``points``.
        axis (int, optional): The axis to rotate around; -1 is accepted as
            an alias of 2. Defaults to 0.

    Raises:
        ValueError: If ``axis`` is not one of 0, 1, 2 or -1.

    Returns:
        torch.Tensor: Rotated points of shape (N, M, 3).

    Note:
        Each point row vector ``p`` is right-multiplied by the transposed
        rotation matrix (``p' = p @ R^T``). With ``c = cos(angle)`` and
        ``s = sin(angle)``:

        - axis 0: ``y' = c*y + s*z``, ``z' = -s*y + c*z`` (x unchanged)
        - axis 1: ``x' = c*x + s*z``, ``z' = -s*x + c*z`` (y unchanged)
        - axis 2: ``x' = c*x + s*y``, ``y' = -s*x + c*y`` (z unchanged)
    """
    rot_sin = torch.sin(angles)
    rot_cos = torch.cos(angles)
    ones = torch.ones_like(rot_cos)
    zeros = torch.zeros_like(rot_cos)
    if axis == 1:
        rot_mat_T = torch.stack([
            torch.stack([rot_cos, zeros, -rot_sin]),
            torch.stack([zeros, ones, zeros]),
            torch.stack([rot_sin, zeros, rot_cos])
        ])
    elif axis == 2 or axis == -1:
        rot_mat_T = torch.stack([
            torch.stack([rot_cos, -rot_sin, zeros]),
            torch.stack([rot_sin, rot_cos, zeros]),
            torch.stack([zeros, zeros, ones])
        ])
    elif axis == 0:
        # The fixed axis (x) must map to itself, so the identity row comes
        # first; the (y, z) block mirrors the axis-2 sign pattern. A zero
        # angle therefore yields the identity matrix.
        rot_mat_T = torch.stack([
            torch.stack([ones, zeros, zeros]),
            torch.stack([zeros, rot_cos, -rot_sin]),
            torch.stack([zeros, rot_sin, rot_cos])
        ])
    else:
        raise ValueError(f'axis should in range [0, 1, 2], got {axis}')

    # rot_mat_T has shape (3, 3, N): its last index selects the matrix for
    # the matching point set, so out[a] = points[a] @ rot_mat_T[:, :, a].
    return torch.einsum('aij,jka->aik', points, rot_mat_T)