import collections
import json
import os

import imageio.v2 as imageio
import numpy as np
import torch
import torch.nn.functional as F

from .utils import Rays


def _load_renderings(root_fp: str, subject_id: str, split: str):
    """Load images from disk."""
    if not root_fp.startswith("/"):
        # allow relative path. e.g., "./data/nerf_synthetic/"
        root_fp = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "..",
            "..",
            root_fp,
        )

    data_dir = os.path.join(root_fp, subject_id)
Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
25
26
27
    with open(
        os.path.join(data_dir, "transforms_{}.json".format(split)), "r"
    ) as fp:
Ruilong Li's avatar
Ruilong Li committed
28
29
30
31
32
33
34
35
36
37
38
        meta = json.load(fp)
    images = []
    camtoworlds = []

    for i in range(len(meta["frames"])):
        frame = meta["frames"][i]
        fname = os.path.join(data_dir, frame["file_path"] + ".png")
        rgba = imageio.imread(fname)
        camtoworlds.append(frame["transform_matrix"])
        images.append(rgba)

Ruilong Li's avatar
Ruilong Li committed
39
40
    images = np.stack(images, axis=0)
    camtoworlds = np.stack(camtoworlds, axis=0)
Ruilong Li's avatar
Ruilong Li committed
41
42
43
44
45
46
47
48

    h, w = images.shape[1:3]
    camera_angle_x = float(meta["camera_angle_x"])
    focal = 0.5 * w / np.tan(0.5 * camera_angle_x)

    return images, camtoworlds, focal


# Dataset built on top of the renderings returned by _load_renderings().
class SubjectLoader(torch.utils.data.Dataset):
    """Single subject data loader for training and evaluation.

    In training mode (a ``num_rays`` is given and the split is "train" or
    "trainval") each item is a batch of ``num_rays`` randomly sampled pixels
    with their rays. Otherwise each item holds every pixel/ray of the image
    at the requested index.
    """

    SPLITS = ["train", "val", "trainval", "test"]
    SUBJECT_IDS = [
        "chair",
        "drums",
        "ficus",
        "hotdog",
        "lego",
        "materials",
        "mic",
        "ship",
    ]

    # Expected image resolution; checked against the loaded data in __init__.
    WIDTH, HEIGHT = 800, 800
    # Default near / far values used when the caller does not pass any.
    NEAR, FAR = 2.0, 6.0
    # When True the camera-space y and z components are negated in
    # fetch_data().
    OPENGL_CAMERA = True

    def __init__(
        self,
        subject_id: str,
        root_fp: str,
        split: str,
        color_bkgd_aug: str = "white",
        num_rays: int = None,
        near: float = None,
        far: float = None,
        batch_over_images: bool = True,
    ):
        """Load all images and camera poses of one subject into memory.

        Args:
            subject_id: One of ``SUBJECT_IDS``.
            root_fp: Dataset root directory, forwarded to
                ``_load_renderings``.
            split: One of ``SPLITS``. "trainval" concatenates the "train"
                and "val" splits.
            color_bkgd_aug: Background colour used while training: "white",
                "black" or "random".
            num_rays: Number of rays per training batch. ``None`` disables
                training mode.
            near: Stored as ``self.near``; defaults to ``NEAR``.
            far: Stored as ``self.far``; defaults to ``FAR``.
            batch_over_images: In training mode, sample every ray from a
                random image (True) or all rays from the indexed image
                (False).

        Raises:
            AssertionError: If ``split``, ``subject_id`` or
                ``color_bkgd_aug`` is not an accepted value, or if the
                loaded images are not ``HEIGHT`` x ``WIDTH``.
        """
        super().__init__()
        assert split in self.SPLITS, "%s" % split
        assert subject_id in self.SUBJECT_IDS, "%s" % subject_id
        assert color_bkgd_aug in ["white", "black", "random"]
        self.split = split
        self.num_rays = num_rays
        self.near = self.NEAR if near is None else near
        self.far = self.FAR if far is None else far
        self.training = (num_rays is not None) and (
            split in ["train", "trainval"]
        )
        self.color_bkgd_aug = color_bkgd_aug
        self.batch_over_images = batch_over_images

        if split == "trainval":
            images_train, c2w_train, focal_train = _load_renderings(
                root_fp, subject_id, "train"
            )
            images_val, c2w_val, _ = _load_renderings(
                root_fp, subject_id, "val"
            )
            self.images = np.concatenate([images_train, images_val])
            self.camtoworlds = np.concatenate([c2w_train, c2w_val])
            # The focal length of the "train" split is used for both parts.
            self.focal = focal_train
        else:
            self.images, self.camtoworlds, self.focal = _load_renderings(
                root_fp, subject_id, split
            )

        self.images = torch.from_numpy(self.images).to(torch.uint8)
        self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
        # Camera intrinsics with the principal point at the image centre.
        self.K = torch.tensor(
            [
                [self.focal, 0, self.WIDTH / 2.0],
                [0, self.focal, self.HEIGHT / 2.0],
                [0, 0, 1],
            ],
            dtype=torch.float32,
        )  # (3, 3)
        assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.images)

    @torch.no_grad()
    def __getitem__(self, index):
        """Return the preprocessed sample for ``index`` (see preprocess)."""
        data = self.fetch_data(index)
        data = self.preprocess(data)
        return data

    def preprocess(self, data):
        """Process the fetched / cached data with randomness.

        Splits ``data["rgba"]`` into colour and alpha and composites the
        colour over a background colour. While training the background
        follows ``self.color_bkgd_aug``; otherwise it is always white.

        Returns:
            A dict with "pixels", "rays" and "color_bkgd", plus every other
            key of ``data`` passed through unchanged.
        """
        rgba, rays = data["rgba"], data["rays"]
        pixels, alpha = torch.split(rgba, [3, 1], dim=-1)

        device = self.images.device
        if self.training:
            if self.color_bkgd_aug == "random":
                color_bkgd = torch.rand(3, device=device)
            elif self.color_bkgd_aug == "white":
                color_bkgd = torch.ones(3, device=device)
            elif self.color_bkgd_aug == "black":
                color_bkgd = torch.zeros(3, device=device)
        else:
            # just use white during inference
            color_bkgd = torch.ones(3, device=device)

        # Alpha-composite the colour over the background.
        pixels = pixels * alpha + color_bkgd * (1.0 - alpha)
        return {
            "pixels": pixels,  # [n_rays, 3] or [h, w, 3]
            "rays": rays,  # [n_rays,] or [h, w]
            "color_bkgd": color_bkgd,  # [3,]
            **{k: v for k, v in data.items() if k not in ["rgba", "rays"]},
        }

    def update_num_rays(self, num_rays):
        """Set the number of rays sampled per training batch."""
        self.num_rays = num_rays

    def fetch_data(self, index):
        """Fetch the data (it maybe cached for multiple batches).

        Training mode samples ``self.num_rays`` random pixel coordinates,
        taken from random images when ``self.batch_over_images`` is set and
        from image ``index`` otherwise. Outside training mode every pixel of
        image ``index`` is used.

        Returns:
            A dict with "rgba" (pixel values scaled to [0, 1]) and "rays"
            (a ``Rays`` of origins and unit-length view directions).
        """
        num_rays = self.num_rays
        device = self.images.device

        if self.training:
            # Keep the sampling order (image ids, then x, then y) so results
            # stay reproducible for a fixed random seed.
            if self.batch_over_images:
                image_id = torch.randint(
                    0,
                    len(self.images),
                    size=(num_rays,),
                    device=device,
                )
            else:
                image_id = [index]
            x = torch.randint(0, self.WIDTH, size=(num_rays,), device=device)
            y = torch.randint(0, self.HEIGHT, size=(num_rays,), device=device)
        else:
            image_id = [index]
            # With indexing="xy" both grids have shape (HEIGHT, WIDTH), so
            # flattening walks the image row by row.
            x, y = torch.meshgrid(
                torch.arange(self.WIDTH, device=device),
                torch.arange(self.HEIGHT, device=device),
                indexing="xy",
            )
            x = x.flatten()
            y = y.flatten()

        # generate rays
        rgba = self.images[image_id, y, x] / 255.0  # (num_rays, 4)
        # One pose per entry of image_id (a single pose when image_id is
        # [index]); only the top 3 rows are used below.
        c2w = self.camtoworlds[image_id]
        # Sign applied to the camera-space y and z components.
        axis_sign = -1.0 if self.OPENGL_CAMERA else 1.0
        camera_dirs = F.pad(
            torch.stack(
                [
                    # +0.5 shifts the coordinate to the pixel centre.
                    (x - self.K[0, 2] + 0.5) / self.K[0, 0],
                    (y - self.K[1, 2] + 0.5) / self.K[1, 1] * axis_sign,
                ],
                dim=-1,
            ),
            (0, 1),
            value=axis_sign,
        )  # [num_rays, 3]

        # Rotate the camera-space directions with the 3x3 block of each
        # pose; the result is [num_rays, 3].
        directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
        # The last column of the pose is the ray origin.
        origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
        viewdirs = directions / torch.linalg.norm(
            directions, dim=-1, keepdim=True
        )

        if self.training:
            origins = torch.reshape(origins, (num_rays, 3))
            viewdirs = torch.reshape(viewdirs, (num_rays, 3))
            rgba = torch.reshape(rgba, (num_rays, 4))
        else:
            origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
            viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
            rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4))

        rays = Rays(origins=origins, viewdirs=viewdirs)

        return {
            "rgba": rgba,  # [h, w, 4] or [num_rays, 4]
            "rays": rays,  # [h, w, 3] or [num_rays, 3]
        }