# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Note: The #noqa comments below are for unused imports of pluggable implementations
# which are part of implicitron. They ensure that the registry is prepopulated.

import warnings
from logging import Logger
from typing import Any, Dict, Optional, Tuple

import torch
import tqdm
from pytorch3d.common.compat import prod

from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle

from pytorch3d.implicitron.tools import image_utils

from pytorch3d.implicitron.tools.utils import cat_dataclass


def preprocess_input(
    image_rgb: Optional[torch.Tensor],
    fg_probability: Optional[torch.Tensor],
    depth_map: Optional[torch.Tensor],
    mask_images: bool,
    mask_depths: bool,
    mask_threshold: float,
    bg_color: Tuple[float, float, float],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Helper function to preprocess the input images and optional depth maps
    to apply masking if required.

    Args:
        image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
            corresponding to the source viewpoints from which features will be extracted.
        fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
            of foreground masks with values in [0, 1].
        depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
        mask_images: Whether or not to mask the RGB image background given the
            foreground mask (the `fg_probability` argument of `GenericModel.forward`).
        mask_depths: Whether or not to mask the depth image background given the
            foreground mask (the `fg_probability` argument of `GenericModel.forward`).
        mask_threshold: If greater than 0.0, the foreground mask is
            thresholded by this value before being applied to the RGB/depth images.
        bg_color: RGB values for setting the background color of the input image
            if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own
            way to determine the background color of its output, unrelated to this.

    Returns:
        Modified image_rgb, fg_mask, depth_map
    """
    if image_rgb is not None and image_rgb.ndim == 3:
        # The FrameData object is used for both frames and batches of frames,
        # and a user might get this error if those were confused.
        # Perhaps a user has a FrameData `fd` representing a single frame and
        # wrote something like `model(**fd)` instead of
        # `model(**fd.collate([fd]))`.
        raise ValueError(
            "Model received unbatched inputs. "
            + "Perhaps they came from a FrameData which had not been collated."
        )

    fg_mask = fg_probability
    if fg_mask is not None and mask_threshold > 0.0:
        # threshold masks
        warnings.warn("Thresholding masks!")
        fg_mask = (fg_mask >= mask_threshold).type_as(fg_mask)

    if mask_images and fg_mask is not None and image_rgb is not None:
        # mask the image
        warnings.warn("Masking images!")
        image_rgb = image_utils.mask_background(
            image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(bg_color)
        )

    if mask_depths and fg_mask is not None and depth_map is not None:
        # mask the depths
        assert (
            mask_threshold > 0.0
        ), "Depths should be masked only with thresholded masks"
        warnings.warn("Masking depths!")
        depth_map = depth_map * fg_mask

    return image_rgb, fg_mask, depth_map
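

# Example usage (an illustrative sketch, not part of the original module; all
# tensors below are hypothetical placeholders):
#
#     rgb = torch.rand(2, 3, 128, 128)
#     fg = torch.rand(2, 1, 128, 128)
#     depth = torch.rand(2, 1, 128, 128)
#     rgb, fg, depth = preprocess_input(
#         rgb, fg, depth,
#         mask_images=True,
#         mask_depths=True,
#         mask_threshold=0.5,
#         bg_color=(0.0, 0.0, 0.0),
#     )
#
# With `mask_threshold=0.5` the foreground mask is binarized first; the RGB
# background is then replaced by `bg_color` and the depth map is zeroed
# outside the foreground.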


def log_loss_weights(loss_weights: Dict[str, float], logger: Logger) -> None:
    """
    Print a table of the loss weights.
    """
    loss_weights_message = (
        "-------\nloss_weights:\n"
        + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in loss_weights.items())
        + "-------"
    )
    logger.info(loss_weights_message)


def weighted_sum_losses(
    preds: Dict[str, torch.Tensor], loss_weights: Dict[str, float]
) -> Optional[torch.Tensor]:
    """
    A helper function to compute the overall loss as the dot product
    of individual loss values with the corresponding weights.
    """
    losses_weighted = [
        preds[k] * float(w)
        for k, w in loss_weights.items()
        if (k in preds and w != 0.0)
    ]
    if len(losses_weighted) == 0:
        warnings.warn("No main objective found.")
        return None
    loss = sum(losses_weighted)
    assert torch.is_tensor(loss)
    # pyre-fixme[7]: Expected `Optional[Tensor]` but got `int`.
    return loss
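

# Example (an illustrative sketch, not part of the original module): with
# hypothetical predictions and weights
#
#     preds = {"loss_rgb": torch.tensor(0.5), "loss_mask": torch.tensor(0.2)}
#     loss_weights = {"loss_rgb": 1.0, "loss_mask": 0.0, "loss_depth": 2.0}
#
# only `loss_rgb` contributes: zero-weighted terms and weights whose key is
# missing from `preds` are skipped, so `weighted_sum_losses(preds, loss_weights)`
# returns `tensor(0.5)`.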


def apply_chunked(func, chunk_generator, tensor_collator):
    """
    Helper function to apply a function on a sequence of
    chunked inputs yielded by a generator and collate
    the result.
    """
    processed_chunks = [
        func(*chunk_args, **chunk_kwargs)
        for chunk_args, chunk_kwargs in chunk_generator
    ]

    return cat_dataclass(processed_chunks, tensor_collator)


def chunk_generator(
    chunk_size: int,
    ray_bundle: ImplicitronRayBundle,
    chunked_inputs: Dict[str, torch.Tensor],
    tqdm_trigger_threshold: int,
    *args,
    **kwargs,
):
    """
    Helper function which yields chunks of rays from the
    input ray_bundle, to be used when the number of rays is
    large and will not fit in memory for rendering.
    """
    (
        batch_size,
        *spatial_dim,
        n_pts_per_ray,
    ) = ray_bundle.lengths.shape  # B x ... x n_pts_per_ray
    if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
        raise ValueError(
            f"chunk_size_grid ({chunk_size}) should be divisible "
            f"by n_pts_per_ray ({n_pts_per_ray})"
        )

    n_rays = prod(spatial_dim)
    # special handling for raytracing-based methods
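    # `-(-a // b)` is ceiling division of `a` by `b`: first find how many chunks
    # are needed to cover all ray points, then how many rays go into each chunk.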
    n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
    chunk_size_in_rays = -(-n_rays // n_chunks)

    iter = range(0, n_rays, chunk_size_in_rays)
    if len(iter) >= tqdm_trigger_threshold:
        iter = tqdm.tqdm(iter)

    def _safe_slice(
        tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
    ) -> Any:
        return tensor[start_idx:end_idx] if tensor is not None else None

    for start_idx in iter:
        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
        bins = (
            None
            if ray_bundle.bins is None
            else ray_bundle.bins.reshape(batch_size, n_rays, n_pts_per_ray + 1)[
                :, start_idx:end_idx
            ]
        )
        pixel_radii_2d = (
            None
            if ray_bundle.pixel_radii_2d is None
            else ray_bundle.pixel_radii_2d.reshape(batch_size, -1, 1)[
                :, start_idx:end_idx
            ]
        )
        ray_bundle_chunk = ImplicitronRayBundle(
            origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
            directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
                :, start_idx:end_idx
            ],
            lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
                :, start_idx:end_idx
            ],
            xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
            bins=bins,
            pixel_radii_2d=pixel_radii_2d,
            camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
            camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
        )
        extra_args = kwargs.copy()
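        # Flatten the spatial dims of every chunked input to `(B, C, n_rays)` and
        # take the same slice of rays as the ray bundle chunk above.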
        for k, v in chunked_inputs.items():
            extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
        yield [ray_bundle_chunk, *args], extra_args
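

# Example (an illustrative sketch, not part of the original module): a renderer
# whose per-chunk forward is `render_fn(ray_bundle_chunk, **extra_args)` and
# which returns a dataclass of tensors could be evaluated in memory-bounded
# chunks via
#
#     out = apply_chunked(
#         render_fn,
#         chunk_generator(
#             chunk_size=4096,
#             ray_bundle=ray_bundle,
#             chunked_inputs={},
#             tqdm_trigger_threshold=16,
#         ),
#         tensor_collator=lambda chunks: torch.cat(chunks, dim=1),
#     )
#
# Here `render_fn`, `ray_bundle` and the collator are hypothetical; each item
# yielded by `chunk_generator` is an `(args, kwargs)` pair that `apply_chunked`
# unpacks into `render_fn`, and `cat_dataclass` reassembles the per-chunk
# outputs along the ray dimension using the collator.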