import numpy as np
import torch


def limit_period(val: torch.Tensor, offset: float = 0.5,
                 period: float = np.pi) -> torch.Tensor:
    """Wrap ``val`` into a single period.

    Computes ``val - floor(val / period + offset) * period``. For a positive
    ``period`` the result lies in
    ``[-offset * period, (1 - offset) * period)``.

    Args:
        val: Tensor of values to wrap.
        offset: Fraction of a period by which the output window is shifted
            below zero. Defaults to 0.5 (window centred on zero).
        period: Length of one period. Defaults to ``np.pi``.

    Returns:
        Tensor with the same shape as ``val`` holding the wrapped values.
    """
    return val - torch.floor(val / period + offset) * period


def rotation_3d_in_axis(points: torch.Tensor, angles: torch.Tensor,
                        axis: int = 0) -> torch.Tensor:
    """Rotate batches of 3D points about one coordinate axis.

    Each batch element ``a`` is transformed as
    ``points[a] @ rot_mat_T[:, :, a]`` (points are row vectors), where
    ``rot_mat_T`` is built from ``sin``/``cos`` of ``angles[a]``. Every
    branch leaves the coordinate along ``axis`` untouched and reduces to
    the identity when the angle is zero.

    Args:
        points: Tensor of shape ``[N, point_size, 3]``.
        angles: Tensor of shape ``[N]``, one angle (radians) per batch item.
        axis: Axis to rotate about: 0, 1, or 2 (``-1`` is an alias for 2).

    Returns:
        Tensor of shape ``[N, point_size, 3]`` with the rotated points.

    Raises:
        ValueError: If ``axis`` is not one of 0, 1, 2, -1.
    """
    rot_sin = torch.sin(angles)
    rot_cos = torch.cos(angles)
    ones = torch.ones_like(rot_cos)
    zeros = torch.zeros_like(rot_cos)

    # Each rot_mat_T has shape [3, 3, N]: (row, column, batch).
    if axis == 1:
        rot_mat_T = torch.stack([
            torch.stack([rot_cos, zeros, -rot_sin]),
            torch.stack([zeros, ones, zeros]),
            torch.stack([rot_sin, zeros, rot_cos]),
        ])
    elif axis == 2 or axis == -1:
        rot_mat_T = torch.stack([
            torch.stack([rot_cos, -rot_sin, zeros]),
            torch.stack([rot_sin, rot_cos, zeros]),
            torch.stack([zeros, zeros, ones]),
        ])
    elif axis == 0:
        # The first coordinate must pass through unchanged, so the unit
        # entry sits at [0, 0] and the 2x2 rotation block (same sign
        # pattern as the axis-2 branch) acts on the last two coordinates.
        # The previous layout had the columns cyclically shifted, which
        # permuted the coordinates even for a zero angle.
        rot_mat_T = torch.stack([
            torch.stack([ones, zeros, zeros]),
            torch.stack([zeros, rot_cos, -rot_sin]),
            torch.stack([zeros, rot_sin, rot_cos]),
        ])
    else:
        raise ValueError('axis should in range')

    # out[a, i, k] = sum_j points[a, i, j] * rot_mat_T[j, k, a]
    return torch.einsum('aij,jka->aik', points, rot_mat_T)