base_model.py 3.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.renderer.cameras import CamerasBase

from .renderer.base import EvaluationMode


@dataclass
class ImplicitronRender:
    """
    Container for the tensors produced by a rendering pass.

    Every field is optional; a field left as ``None`` means that particular
    output was not produced.
    """

    # Rendered depth map, if produced.
    depth_render: Optional[torch.Tensor] = None
    # Rendered color image, if produced.
    image_render: Optional[torch.Tensor] = None
    # Rendered mask, if produced.
    mask_render: Optional[torch.Tensor] = None
    # Camera-distance tensor, if produced (exact meaning is defined by the
    # renderer that fills it in -- not established in this file).
    camera_distance: Optional[torch.Tensor] = None

    def clone(self) -> "ImplicitronRender":
        """
        Return a new ``ImplicitronRender`` holding detached copies of the tensors.

        Each non-``None`` tensor is detached from the autograd graph and then
        copied. The result therefore does not share storage with ``self`` and
        does not propagate gradients back to it. ``None`` fields stay ``None``.
        The returned object is always a plain ``ImplicitronRender``.
        """

        def _detached_copy(
            tensor: Optional[torch.Tensor],
        ) -> Optional[torch.Tensor]:
            if tensor is None:
                return None
            return tensor.detach().clone()

        # Copy the fields in declaration order.
        field_names = (
            "depth_render",
            "image_render",
            "mask_render",
            "camera_distance",
        )
        copies = {name: _detached_copy(getattr(self, name)) for name in field_names}
        return ImplicitronRender(**copies)


class ImplicitronModelBase(ReplaceableBase):
    """
    Replaceable abstract base for all image generation / rendering models.

    Concrete models implement `forward()`, which produces a render with a
    depth map and returns it inside a dictionary of predictions. This base
    class provides no implementation of `forward()`; calling it raises
    `NotImplementedError`.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        *,  # force keyword-only arguments
        image_rgb: Optional[torch.Tensor],
        camera: CamerasBase,
        fg_probability: Optional[torch.Tensor],
        mask_crop: Optional[torch.Tensor],
        depth_map: Optional[torch.Tensor],
        sequence_name: Optional[List[str]],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Run the model on a batch of views and return its predictions.

        Args:
            image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of RGB images;
                the first `min(B, n_train_target_views)` images are considered targets and
                are used to supervise the renders; the rest correspond to the source
                viewpoints from which features will be extracted.
            camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
                to the viewpoints of target images, from which the rays will be sampled,
                and source images, which will be used for intersecting with target rays.
            fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
                foreground masks.
            mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
                regions in the input images (i.e. regions that do not correspond
                to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
                "mask_sample", rays will be sampled in the non-zero regions.
            depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
            sequence_name: A list of `B` strings corresponding to the sequence names
                from which images `image_rgb` were extracted. They are used to match
                target frames with relevant source frames.
            evaluation_mode: One of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION, which determines the settings used for
                rendering. Defaults to EvaluationMode.EVALUATION.
            **kwargs: Additional inputs whose meaning is defined by the concrete
                subclass; this base class does not use them.

        Returns:
            preds: A dictionary containing all outputs of the forward pass. All models should
                output an instance of `ImplicitronRender` in `preds["implicitron_render"]`.

        Raises:
            NotImplementedError: Always; concrete models must override this method.
        """
        raise NotImplementedError()