# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import json
import warnings
from os import path
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
from PIL import Image
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.renderer import HardPhongShader
from tabulate import tabulate

from .utils import (
    BlenderCamera,
    align_bbox,
    compute_extrinsic_matrix,
    read_binvox_coords,
    voxelize,
)


SYNSET_DICT_DIR = Path(__file__).resolve().parent
MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
VOXEL_SIZE = 128
# Intrinsic matrix extracted from Blender. Taken from meshrcnn codebase:
# https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py
BLENDER_INTRINSIC = torch.tensor(
    [
        [2.1875, 0.0, 0.0, 0.0],
        [0.0, 2.1875, 0.0, 0.0],
        [0.0, 0.0, -1.002002, -0.2002002],
        [0.0, 0.0, -1.0, 0.0],
    ]
)


class R2N2(ShapeNetBase):
    """
    This class loads the R2N2 dataset from a given directory into a Dataset object.
    The R2N2 dataset contains 13 categories that are a subset of the ShapeNetCore v.1
    dataset. The R2N2 dataset also contains its own 24 renderings of each object and
    voxelized models. Most of the models have all 24 views in the same split, but there
    are eight of them that divide their views between train and test splits.

    """

    def __init__(
        self,
        split: str,
        shapenet_dir,
        r2n2_dir,
        splits_file,
        return_all_views: bool = True,
        return_voxels: bool = False,
        views_rel_path: str = "ShapeNetRendering",
        voxels_rel_path: str = "ShapeNetVoxels",
        load_textures: bool = True,
        texture_resolution: int = 4,
    ):
        """
        Store each object's synset id and model id from the given directories.

        Args:
            split (str): One of (train, val, test).
            shapenet_dir (path): Path to ShapeNet core v1.
            r2n2_dir (path): Path to the R2N2 dataset.
            splits_file (path): File containing the train/val/test splits.
            return_all_views (bool): Indicator of whether or not to load all the views in
                the split. If set to False, one of the views in the split will be randomly
                selected and loaded.
            return_voxels (bool): Indicator of whether or not to return voxels as a tensor
                of shape (D, D, D) where D is the number of voxels along each dimension.
            views_rel_path: path to rendered views within the r2n2_dir. If not specified,
                the renderings are assumed to be at os.path.join(r2n2_dir, "ShapeNetRendering").
            voxels_rel_path: path to voxel coordinates within the r2n2_dir. If not specified,
                the voxels are assumed to be at os.path.join(r2n2_dir, "ShapeNetVoxels").
            load_textures: Boolean indicating whether textures should be loaded for the model.
                Textures will be of type TexturesAtlas i.e. a texture map per face.
            texture_resolution: Int specifying the resolution of the texture map per face
                created using the textures in the obj file. A
                (texture_resolution, texture_resolution, 3) map is created per face.

        """
        super().__init__()
        self.shapenet_dir = shapenet_dir
        self.r2n2_dir = r2n2_dir
        self.views_rel_path = views_rel_path
        self.voxels_rel_path = voxels_rel_path
        self.load_textures = load_textures
        self.texture_resolution = texture_resolution
        # Examine if split is valid.
        if split not in ["train", "val", "test"]:
            raise ValueError("split has to be one of (train, val, test).")
        # Synset dictionary mapping synset offsets in R2N2 to corresponding labels.
        with open(
            path.join(SYNSET_DICT_DIR, "r2n2_synset_dict.json"), "r"
        ) as read_dict:
            self.synset_dict = json.load(read_dict)
        # Inverse dictionary mapping synset labels to corresponding offsets.
        self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}

        # Store synset and model ids of objects mentioned in the splits_file.
        with open(splits_file) as splits:
            split_dict = json.load(splits)[split]

        self.return_images = True
        # Check if the folder containing R2N2 renderings is included in r2n2_dir.
        if not path.isdir(path.join(r2n2_dir, views_rel_path)):
            self.return_images = False
            msg = (
                "%s not found in %s. R2N2 renderings will "
                "be skipped when returning models."
            ) % (views_rel_path, r2n2_dir)
            warnings.warn(msg)

        self.return_voxels = return_voxels
        # Check if the folder containing voxel coordinates is included in r2n2_dir.
        if not path.isdir(path.join(r2n2_dir, voxels_rel_path)):
            self.return_voxels = False
            msg = (
                "%s not found in %s. Voxel coordinates will "
                "be skipped when returning models."
            ) % (voxels_rel_path, r2n2_dir)
            warnings.warn(msg)

        synset_set = set()
        # Store lists of views of each model in a list.
        self.views_per_model_list = []
        # Store tuples of synset label and total number of views in each category in a list.
        synset_num_instances = []
        for synset in split_dict.keys():
            # Examine if the given synset is present in the ShapeNetCore dataset
            # and is also part of the standard R2N2 dataset.
            if not (
                path.isdir(path.join(shapenet_dir, synset))
                and synset in self.synset_dict
            ):
                msg = (
                    "Synset category %s from the splits file is either not "
                    "present in %s or not part of the standard R2N2 dataset."
                ) % (synset, shapenet_dir)
                warnings.warn(msg)
                continue

            synset_set.add(synset)
            self.synset_start_idxs[synset] = len(self.synset_ids)
            # Start counting total number of views in the current category.
            synset_view_count = 0
            for model in split_dict[synset]:
                # Examine if the given model is present in the ShapeNetCore path.
                shapenet_path = path.join(shapenet_dir, synset, model)
                if not path.isdir(shapenet_path):
                    msg = "Model %s from category %s is not present in %s." % (
                        model,
                        synset,
                        shapenet_dir,
                    )
                    warnings.warn(msg)
                    continue
                self.synset_ids.append(synset)
                self.model_ids.append(model)

                model_views = split_dict[synset][model]
                # Randomly select a view index if return_all_views is set to False.
                if not return_all_views:
                    rand_idx = torch.randint(len(model_views), (1,))
                    model_views = [model_views[rand_idx]]
                self.views_per_model_list.append(model_views)
                synset_view_count += len(model_views)
            synset_num_instances.append((self.synset_dict[synset], synset_view_count))
            model_count = len(self.synset_ids) - self.synset_start_idxs[synset]
            self.synset_num_models[synset] = model_count
        headers = ["category", "#instances"]
        synset_num_instances.append(("total", sum(n for _, n in synset_num_instances)))
        print(
            tabulate(synset_num_instances, headers, numalign="left", stralign="center")
        )

        # Examine if all the synsets in the standard R2N2 mapping are present.
        # Update self.synset_inv so that it only includes the loaded categories.
        synset_not_present = [
            self.synset_inv.pop(self.synset_dict[synset])
            for synset in self.synset_dict
            if synset not in synset_set
        ]
        if len(synset_not_present) > 0:
            msg = (
                "The following categories are included in R2N2's"
                "official mapping but not found in the dataset location %s: %s"
            ) % (shapenet_dir, ", ".join(synset_not_present))
            warnings.warn(msg)

    def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
        """
        Read a model by the given index.

        Args:
            model_idx: The idx of the model to be retrieved in the dataset.
            view_idxs: List of indices of the views to be returned. Each index needs to be
                contained in the loaded split (always between 0 and 23, inclusive). If
                an invalid index is supplied, view_idxs will be ignored and all the loaded
                views will be returned.

        Returns:
            dictionary with the following keys:
            - verts: FloatTensor of shape (V, 3).
            - faces: faces.verts_idx, LongTensor of shape (F, 3).
            - synset_id (str): synset id.
            - model_id (str): model id.
            - label (str): synset label.
            - images: FloatTensor of shape (V, H, W, C), where V is number of views
                returned. Returns a batch of the renderings of the models from the R2N2 dataset.
            - R: Rotation matrix of shape (V, 3, 3), where V is number of views returned.
            - T: Translation matrix of shape (V, 3), where V is number of views returned.
            - K: Intrinsic matrix of shape (V, 4, 4), where V is number of views returned.
            - voxels: Voxels of shape (D, D, D), where D is the number of voxels along each
                dimension.
        """
        if isinstance(model_idx, tuple):
            model_idx, view_idxs = model_idx
        if view_idxs is not None:
            if isinstance(view_idxs, int):
                view_idxs = [view_idxs]
            if not isinstance(view_idxs, list) and not torch.is_tensor(view_idxs):
                raise TypeError(
                    "view_idxs is of type %s but it needs to be a list."
                    % type(view_idxs)
                )

        model_views = self.views_per_model_list[model_idx]
        if view_idxs is not None and any(
            idx not in self.views_per_model_list[model_idx] for idx in view_idxs
        ):
            msg = """At least one of the indices in view_idxs is not available.
                Specified view of the model needs to be contained in the
                loaded split. If return_all_views is set to False, only one
                random view is loaded. Try accessing the specified view(s)
                after loading the dataset with self.return_all_views set to True.
                Now returning all view(s) in the loaded dataset."""
            warnings.warn(msg)
        elif view_idxs is not None:
            model_views = view_idxs

        model = self._get_item_ids(model_idx)
        model_path = path.join(
            self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
        )

        verts, faces, textures = self._load_mesh(model_path)
        model["verts"] = verts
        model["faces"] = faces
        model["textures"] = textures
        model["label"] = self.synset_dict[model["synset_id"]]

        model["images"] = None
        images, Rs, Ts, voxel_RTs = [], [], [], []
        # Retrieve R2N2's renderings if required.
        if self.return_images:
            rendering_path = path.join(
                self.r2n2_dir,
                self.views_rel_path,
                model["synset_id"],
                model["model_id"],
                "rendering",
            )
            # Read metadata file to obtain params for calibration matrices.
            with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f:
                metadata_lines = f.readlines()
            for i in model_views:
                # Read image.
                image_path = path.join(rendering_path, "%02d.png" % i)
                raw_img = Image.open(image_path)
                image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3]
                images.append(image.to(dtype=torch.float32))

                # Get camera calibration.
                azim, elev, yaw, dist_ratio, fov = [
                    float(v) for v in metadata_lines[i].strip().split(" ")
                ]
                dist = dist_ratio * MAX_CAMERA_DISTANCE
                # Extrinsic matrix before transformation to PyTorch3D world space.
                RT = compute_extrinsic_matrix(azim, elev, dist)
                R, T = self._compute_camera_calibration(RT)
                Rs.append(R)
                Ts.append(T)
                voxel_RTs.append(RT)

            # Intrinsic matrix extracted from Blender with a slight modification to work with
            # PyTorch3D world space. Taken from meshrcnn codebase:
            # https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py
            K = torch.tensor(
                [
                    [2.1875, 0.0, 0.0, 0.0],
                    [0.0, 2.1875, 0.0, 0.0],
                    [0.0, 0.0, -1.002002, -0.2002002],
                    [0.0, 0.0, 1.0, 0.0],
                ]
            )
            model["images"] = torch.stack(images)
            model["R"] = torch.stack(Rs)
            model["T"] = torch.stack(Ts)
            model["K"] = K.expand(len(model_views), 4, 4)

        voxels_list = []

        # Read voxels if required.
        voxel_path = path.join(
            self.r2n2_dir,
            self.voxels_rel_path,
            model["synset_id"],
            model["model_id"],
            "model.binvox",
        )
        if self.return_voxels:
            if not path.isfile(voxel_path):
                msg = "Voxel file not found for model %s from category %s."
                raise FileNotFoundError(msg % (model["model_id"], model["synset_id"]))

            with open(voxel_path, "rb") as f:
                # Read voxel coordinates as a tensor of shape (N, 3).
                voxel_coords = read_binvox_coords(f)
            # Align voxels to the same coordinate system as mesh verts.
            voxel_coords = align_bbox(voxel_coords, model["verts"])
            for RT in voxel_RTs:
                # Compute projection matrix.
                P = BLENDER_INTRINSIC.mm(RT)
                # Convert voxel coordinates of shape (N, 3) to voxels of shape (D, D, D).
                voxels = voxelize(voxel_coords, P, VOXEL_SIZE)
                voxels_list.append(voxels)
            model["voxels"] = torch.stack(voxels_list)

        return model
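
    # Indexing sketch (illustrative only; indices are placeholders): __getitem__
    # also accepts a (model_idx, view_idxs) tuple, so a call such as
    #   model = dataset[0, [0, 5, 10]]
    # returns only views 0, 5 and 10 of model 0, provided those views are part
    # of the loaded split.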

    def _compute_camera_calibration(self, RT):
        """
        Helper function for calculating the rotation and translation matrices of the
        combined ShapeNet-world-to-camera-view and ShapeNet-to-PyTorch3D transformation.

        Args:
            RT: Extrinsic matrix that performs ShapeNet world view to camera view
                transformation.

        Returns:
            R: Rotation matrix of shape (3, 3).
            T: Translation matrix of shape (3).
        """
        # Transform the mesh vertices from shapenet world to pytorch3d world.
        shapenet_to_pytorch3d = torch.tensor(
            [
                [-1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            dtype=torch.float32,
        )
        RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d)  # (4, 4)
        # Extract rotation and translation matrices from RT.
        R = RT[:3, :3]
        T = RT[3, :3]
        return R, T
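
    # Usage note (a sketch, not part of the original file): in PyTorch3D's
    # row-vector convention, the R and T returned above map ShapeNet world points
    # into camera view as
    #   verts_view = verts_world @ R + T
    # which is the form consumed by BlenderCamera in the render method below.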

    def render(
        self,
        model_ids: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        sample_nums: Optional[List[int]] = None,
        idxs: Optional[List[int]] = None,
        view_idxs: Optional[List[int]] = None,
        shader_type=HardPhongShader,
        device="cpu",
        **kwargs
    ) -> torch.Tensor:
        """
        Render models with BlenderCamera by default to achieve the same orientations as the
        R2N2 renderings. Also accepts other types of cameras and any of the args that the
        render function in the ShapeNetBase class accepts.

        Args:
            view_idxs: each model will be rendered with the orientation(s) of the specified
                views. Only render by view_idxs if no camera or args for BlenderCamera is
                supplied.
            Accepts any of the args of the render function in ShapeNetBase:
            model_ids: List[str] of model_ids of models intended to be rendered.
            categories: List[str] of categories intended to be rendered. categories
                and sample_nums must be specified at the same time. categories can be given
                in the form of synset offsets or labels, or a combination of both.
            sample_nums: List[int] of number of models to be randomly sampled from
                each category. Could also contain one single integer, in which case it
                will be broadcasted for every category.
            idxs: List[int] of indices of models to be rendered in the dataset.
            shader_type: Shader to use for rendering. Examples include HardPhongShader
            (default), SoftPhongShader, etc., or any other valid Shader class.
            device: torch.device on which the tensors should be located.
            **kwargs: Accepts any of the kwargs that the renderer supports and any of the
                args that BlenderCamera supports.

        Returns:
            Batch of rendered images of shape (N, H, W, 3).
        """
        idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
        r = torch.cat([self[idxs[i], view_idxs]["R"] for i in range(len(idxs))])
        t = torch.cat([self[idxs[i], view_idxs]["T"] for i in range(len(idxs))])
        k = torch.cat([self[idxs[i], view_idxs]["K"] for i in range(len(idxs))])
        # Initialize default camera using R, T, K from kwargs or R, T, K of the specified views.
        blend_cameras = BlenderCamera(
            R=kwargs.get("R", r),
            T=kwargs.get("T", t),
            K=kwargs.get("K", k),
            device=device,
        )
        cameras = kwargs.get("cameras", blend_cameras).to(device)
        kwargs.pop("cameras", None)
        # pass down all the same inputs
        return super().render(
            idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs
        )
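

# A hypothetical rendering sketch (model indices, view indices and device below
# are placeholders chosen for illustration):
#   images = dataset.render(idxs=[0, 1], view_idxs=[0], device="cuda:0")
#   # -> (N, H, W, 3) images rendered with the BlenderCamera orientations of view 0.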