"docs/vscode:/vscode.git/clone" did not exist on "97390468c7104c9b3255050ce22ea382f59fba5e"
utils.py 3.48 KB
Newer Older
Ruilong Li's avatar
Ruilong Li committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright (c) Meta Platforms, Inc. and affiliates.
import collections
import math

import torch
import torch.nn.functional as F

# Per-ray tensors returned by `generate_rays`; `radii`, `near` and `far` may
# be None when they are not computed/provided.
Rays = collections.namedtuple(
    "Rays", "origins directions viewdirs radii near far"
)

# Camera description: intrinsic matrices, extrinsic matrices (inverted to get
# camera-to-world in `generate_rays`), distortion parameters (not read by the
# helpers in this module), and the image size in pixels.
Cameras = collections.namedtuple(
    "Cameras", "intrins extrins distorts width height"
)


def namedtuple_map(fn, tup):
    """Map `fn` over every field of the namedtuple `tup`.

    Fields that are `None` are passed through unchanged instead of being
    handed to `fn`. The result is a new instance of `tup`'s own type.
    """
    mapped_fields = [fn(field) if field is not None else None for field in tup]
    return type(tup)(*mapped_fields)


def homo(points: torch.Tensor) -> torch.Tensor:
    """Convert `points` to homogeneous coordinates.

    A constant 1 is appended along the last dimension, so a ``[..., D]``
    input becomes ``[..., D + 1]``; all leading dimensions are unchanged.
    """
    # F.pad reads pad pairs starting from the last dimension: nothing before,
    # one element after.
    last_dim_pad = (0, 1)
    return F.pad(points, last_dim_pad, mode="constant", value=1)


def transform_cameras(cameras: Cameras, resize_factor: float) -> Cameras:
    """Return cameras rescaled for an image resized by `resize_factor`.

    The first two rows of every intrinsic matrix (the rows holding the
    fx/cx and fy/cy entries read by `generate_rays`) are multiplied by
    `resize_factor`. The image width and height are scaled and rounded to
    the nearest integer (halves round up). Extrinsics and distortion
    parameters are passed through unchanged.

    The input `cameras` is left untouched.

    :param cameras: cameras to rescale.
    :param resize_factor: scale factor applied to the image resolution.
    :returns: a new `Cameras` with rescaled intrinsics and image size.
    """
    # Clone before scaling: writing into `cameras.intrins` directly would
    # mutate the caller's tensor, and repeated calls would compound the scale.
    intrins = cameras.intrins.clone()
    intrins[..., :2, :] = intrins[..., :2, :] * resize_factor
    # int(v + 0.5) rounds non-negative sizes to the nearest integer.
    width = int(cameras.width * resize_factor + 0.5)
    height = int(cameras.height * resize_factor + 0.5)
    return Cameras(
        intrins=intrins,
        extrins=cameras.extrins,
        distorts=cameras.distorts,
        width=width,
        height=height,
    )


def generate_rays(
    cameras: Cameras,
    opencv_format: bool = True,
    near: float = None,
    far: float = None,
    pixels_xy: torch.Tensor = None,
) -> Rays:
    """Generate rays for a single camera or a batch of cameras.

    Each ray passes through a pixel centre (pixel coordinate + 0.5). The
    camera-space direction is computed from the intrinsics and rotated into
    world space with the inverse of `cameras.extrins`, whose translation
    column also gives the ray origin.

    :param cameras: cameras, optionally with a leading ``(n_cams,)`` batch.
    :param opencv_format: if False, the camera y and z axes are flipped so
        the camera looks down -z.
    :param near: optional scalar near bound, broadcast to one value per ray.
    :param far: optional scalar far bound, broadcast to one value per ray.
    :param pixels_xy: optional ``[(n_cams,) num_pixels, 2]`` pixel
        coordinates; when None, a ray is generated for every image pixel.
    :returns: Rays
        [(n_cams,) height, width] if pixels_xy is None
        [(n_cams,) num_pixels] if pixels_xy is given
        `radii` is only computed for full images and is None otherwise.
    """
    if pixels_xy is not None:
        # One ray per requested pixel: add a pixel axis before the matrices.
        K = cameras.intrins[..., None, :, :]
        c2w = cameras.extrins[..., None, :, :].inverse()
        x, y = pixels_xy[..., 0], pixels_xy[..., 1]
    else:
        # One ray per image pixel: add height and width axes.
        K = cameras.intrins[..., None, None, :, :]
        c2w = cameras.extrins[..., None, None, :, :].inverse()
        # Build the grid on the intrinsics' device so the arithmetic below
        # does not mix CPU and GPU tensors.
        x, y = torch.meshgrid(
            torch.arange(cameras.width, dtype=K.dtype, device=K.device),
            torch.arange(cameras.height, dtype=K.dtype, device=K.device),
            indexing="xy",
        )  # [height, width]

    # Camera-space direction through each pixel centre, with z = 1.
    camera_dirs = homo(
        torch.stack(
            [
                (x - K[..., 0, 2] + 0.5) / K[..., 0, 0],
                (y - K[..., 1, 2] + 0.5) / K[..., 1, 1],
            ],
            dim=-1,
        )
    )  # [n_cams, height, width, 3]
    if not opencv_format:
        # Flip y and z so the camera looks down -z.
        camera_dirs[..., [1, 2]] *= -1

    # Rotate into world space (R_c2w @ d, written as a broadcast sum).
    # [n_cams, height, width, 3]
    directions = (camera_dirs[..., None, :] * c2w[..., :3, :3]).sum(dim=-1)
    origins = torch.broadcast_to(c2w[..., :3, -1], directions.shape)
    viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdim=True)

    if pixels_xy is None:
        # Euclidean distance between each (unnormalized) direction vector and
        # its neighbour in the next row, i.e. along the height axis.
        dx = torch.sqrt(
            torch.sum(
                (directions[..., :-1, :, :] - directions[..., 1:, :, :]) ** 2,
                dim=-1,
            )
        )  # [n_cams, height - 1, width]
        # Pad back to `height` rows. With at least two distance rows, this
        # repeats the second-to-last one. With a single row (height == 2), that
        # slice would be empty and leave radii one row short, so repeat the
        # only row instead.
        pad_row = dx[..., -2:-1, :] if dx.shape[-2] >= 2 else dx[..., -1:, :]
        dx = torch.cat([dx, pad_row], dim=-2)
        radii = dx[..., None] * 2 / math.sqrt(12)  # [n_cams, height, width, 1]
    else:
        # Radii are derived from neighbouring rows of a full image, so they
        # are not computed for an arbitrary list of pixels.
        radii = None

    # Broadcast scalar bounds to one value per ray; None stays None.
    if near is not None:
        near = near * torch.ones_like(origins[..., 0:1])
    if far is not None:
        far = far * torch.ones_like(origins[..., 0:1])
    rays = Rays(
        origins=origins,  # [n_cams, height, width, 3]
        directions=directions,  # [n_cams, height, width, 3]
        viewdirs=viewdirs,  # [n_cams, height, width, 3]
        radii=radii,  # [n_cams, height, width, 1]
        # near far is not needed when they are estimated by skeleton.
        near=near,
        far=far,
    )
    return rays