import torch import torch.nn.functional as F from einops import rearrange def generate_planes(): """ Defines planes by the three vectors that form the "axes" of the plane. Should work with arbitrary number of planes and planes of arbitrary orientation. Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 """ return torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 0, 1], [0, 1, 0]], [[0, 0, 1], [0, 1, 0], [1, 0, 0]]]) def project_onto_planes(planes, coordinates): """ Does a projection of a 3D point onto a batch of 2D planes, returning 2D plane coordinates. Takes plane axes of shape n_planes, 3, 3 # Takes coordinates of shape N, M, 3 # returns projections of shape N*n_planes, M, 2 """ N, M, C = coordinates.shape n_planes, _, _ = planes.shape coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) inv_planes = torch.linalg.inv(planes.to(coordinates.dtype)).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) projections = torch.bmm(coordinates, inv_planes) return projections[..., :2] def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): assert padding_mode == 'zeros' N, n_planes, C, H, W = plane_features.shape _, M, _ = coordinates.shape plane_features = rearrange(plane_features, "N_b N_t C H_t W_t -> (N_b N_t) C H_t W_t") coordinates = (2/box_warp) * coordinates projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) output_features = F.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) return output_features