Commit ebdbfde0 authored by Ildar Salakhiev, committed by Facebook GitHub Bot

Extract BlobLoader class from JsonIndexDataset and move crop_by_bbox to FrameData

Summary:
Extracted the blob-loading logic into a separate BlobLoader class.
Added documentation for blob_loader.
Did some refactoring of fields.
For detailed steps and discussion, see:
https://github.com/facebookresearch/pytorch3d/pull/1463
https://github.com/fairinternal/pixar_replay/pull/160

Reviewed By: bottler

Differential Revision: D44061728

fbshipit-source-id: eefb21e9679003045d73729f96e6a93a1d4d2d51
parent c759fc56
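For downstream code, the visible effect of this refactor is the changed import locations shown in the hunks below. A minimal sketch of the new-style imports, with paths taken from the diff (the selection of names is illustrative, not exhaustive):

# Sketch of the new import layout introduced by this commit.
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.frame_data import FrameData  # moved out of dataset_base
from pytorch3d.implicitron.dataset.utils import (  # cropping/resizing helpers now live here
    bbox_xywh_to_xyxy,
    crop_around_box,
    resize_image,
)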
@@ -18,8 +18,9 @@ from torch.utils.data import (
     Sampler,
 )

-from .dataset_base import DatasetBase, FrameData
+from .dataset_base import DatasetBase
 from .dataset_map_provider import DatasetMap
+from .frame_data import FrameData
 from .scene_batch_sampler import SceneBatchSampler
 from .utils import is_known_frame_scalar
@@ -5,217 +5,27 @@
 # LICENSE file in the root directory of this source tree.

 from collections import defaultdict
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass
 from typing import (
-    Any,
     ClassVar,
     Dict,
     Iterable,
     Iterator,
     List,
-    Mapping,
     Optional,
     Sequence,
     Tuple,
     Type,
-    Union,
 )

-import numpy as np
 import torch
-from pytorch3d.renderer.camera_utils import join_cameras_as_batch
-from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
-from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
+from pytorch3d.implicitron.dataset.frame_data import FrameData
+from pytorch3d.implicitron.dataset.utils import GenericWorkaround

(The FrameData class below, together with _GenericWorkaround, is the block removed from dataset_base.py by this commit; FrameData now lives in frame_data.py and the workaround class, renamed GenericWorkaround, in utils.py.)

@dataclass
class FrameData(Mapping[str, Any]):
"""
A type of the elements returned by indexing the dataset object.
It can represent both individual frames and batches of thereof;
in this documentation, the sizes of tensors refer to single frames;
add the first batch dimension for the collation result.
Args:
frame_number: The number of the frame within its sequence.
0-based continuous integers.
sequence_name: The unique name of the frame's sequence.
sequence_category: The object category of the sequence.
frame_timestamp: The time elapsed since the start of a sequence in sec.
image_size_hw: The size of the image in pixels; (height, width) tensor
of shape (2,).
image_path: The qualified path to the loaded image (with dataset_root).
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
of the frame; elements are floats in [0, 1].
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
are a result of zero-padding of the image after cropping around
the object bounding box; elements are floats in {0.0, 1.0}.
depth_path: The qualified path to the frame's depth map.
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
of the frame; values correspond to distances from the camera;
use `depth_mask` and `mask_crop` to filter for valid pixels.
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
depth map that are valid for evaluation, they have been checked for
consistency across views; elements are floats in {0.0, 1.0}.
mask_path: A qualified path to the foreground probability mask.
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
pixels belonging to the captured object; elements are floats
in [0, 1].
bbox_xywh: The bounding box tightly enclosing the foreground object in the
format (x0, y0, width, height). The convention assumes that
`x0+width` and `y0+height` includes the boundary of the box.
I.e., to slice out the corresponding crop from an image tensor `I`
we execute `crop = I[..., y0:y0+height, x0:x0+width]`
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
in the original image coordinates in the format (x0, y0, width, height).
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
from `bbox_xywh` due to padding (which can happen e.g. due to
setting `JsonIndexDataset.box_crop_context > 0`)
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
corrected for cropping if it happened.
camera_quality_score: The score proportional to the confidence of the
frame's camera estimation (the higher the more accurate).
point_cloud_quality_score: The score proportional to the accuracy of the
frame's sequence point cloud (the higher the more accurate).
sequence_point_cloud_path: The path to the sequence's point cloud.
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
point cloud corresponding to the frame's sequence. When the object
represents a batch of frames, point clouds may be deduplicated;
see `sequence_point_cloud_idx`.
sequence_point_cloud_idx: Integer indices mapping frame indices to the
corresponding point clouds in `sequence_point_cloud`; to get the
corresponding point cloud to `image_rgb[i]`, use
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
frame_type: The type of the loaded frame specified in
`subset_lists_file`, if provided.
meta: A dict for storing additional frame information.
"""
frame_number: Optional[torch.LongTensor]
sequence_name: Union[str, List[str]]
sequence_category: Union[str, List[str]]
frame_timestamp: Optional[torch.Tensor] = None
image_size_hw: Optional[torch.Tensor] = None
image_path: Union[str, List[str], None] = None
image_rgb: Optional[torch.Tensor] = None
# masks out padding added due to cropping the square bit
mask_crop: Optional[torch.Tensor] = None
depth_path: Union[str, List[str], None] = None
depth_map: Optional[torch.Tensor] = None
depth_mask: Optional[torch.Tensor] = None
mask_path: Union[str, List[str], None] = None
fg_probability: Optional[torch.Tensor] = None
bbox_xywh: Optional[torch.Tensor] = None
crop_bbox_xywh: Optional[torch.Tensor] = None
camera: Optional[PerspectiveCameras] = None
camera_quality_score: Optional[torch.Tensor] = None
point_cloud_quality_score: Optional[torch.Tensor] = None
sequence_point_cloud_path: Union[str, List[str], None] = None
sequence_point_cloud: Optional[Pointclouds] = None
sequence_point_cloud_idx: Optional[torch.Tensor] = None
frame_type: Union[str, List[str], None] = None # known | unseen
meta: dict = field(default_factory=lambda: {})
def to(self, *args, **kwargs):
new_params = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
new_params[f.name] = value.to(*args, **kwargs)
else:
new_params[f.name] = value
return type(self)(**new_params)
def cpu(self):
return self.to(device=torch.device("cpu"))
def cuda(self):
return self.to(device=torch.device("cuda"))
# the following functions make sure **frame_data can be passed to functions
def __iter__(self):
for f in fields(self):
yield f.name
def __getitem__(self, key):
return getattr(self, key)
def __len__(self):
return len(fields(self))
@classmethod
def collate(cls, batch):
"""
Given a list objects `batch` of class `cls`, collates them into a batched
representation suitable for processing with deep networks.
"""
elem = batch[0]
if isinstance(elem, cls):
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
id_to_idx = defaultdict(list)
for i, pc_id in enumerate(pointcloud_ids):
id_to_idx[pc_id].append(i)
sequence_point_cloud = []
sequence_point_cloud_idx = -np.ones((len(batch),))
for i, ind in enumerate(id_to_idx.values()):
sequence_point_cloud_idx[ind] = i
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
assert (sequence_point_cloud_idx >= 0).all()
override_fields = {
"sequence_point_cloud": sequence_point_cloud,
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
}
# note that the pre-collate value of sequence_point_cloud_idx is unused
collated = {}
for f in fields(elem):
list_values = override_fields.get(
f.name, [getattr(d, f.name) for d in batch]
)
collated[f.name] = (
cls.collate(list_values)
if all(list_value is not None for list_value in list_values)
else None
)
return cls(**collated)
elif isinstance(elem, Pointclouds):
return join_pointclouds_as_batch(batch)
elif isinstance(elem, CamerasBase):
# TODO: don't store K; enforce working in NDC space
return join_cameras_as_batch(batch)
else:
return torch.utils.data._utils.collate.default_collate(batch)
class _GenericWorkaround:
"""
OmegaConf.structured has a weirdness when you try to apply
it to a dataclass whose first base class is a Generic which is not
Dict. The issue is with a function called get_dict_key_value_types
in omegaconf/_utils.py.
For example this fails:
@dataclass(eq=False)
class D(torch.utils.data.Dataset[int]):
a: int = 3
OmegaConf.structured(D)
We avoid the problem by adding this class as an extra base class.
"""
pass
 @dataclass(eq=False)
-class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
+class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]):
     """
     Base class to describe a dataset to be used with Implicitron.
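The collate classmethod documented above (kept on FrameData, now in frame_data.py) is what makes FrameData batchable. A hedged sketch of how it can be wired into a torch.utils.data.DataLoader; train_dataset here is a hypothetical dataset yielding FrameData elements:

# Sketch only: `train_dataset` is a hypothetical dataset returning FrameData items.
# FrameData.collate deduplicates per-sequence point clouds across the batch and
# fills `sequence_point_cloud_idx` accordingly (see the method shown above).
from torch.utils.data import DataLoader
from pytorch3d.implicitron.dataset.frame_data import FrameData

loader = DataLoader(train_dataset, batch_size=4, collate_fn=FrameData.collate)
batch = next(iter(loader))  # a single FrameData whose tensor fields gain a batch dim
batch = batch.cuda()        # .to()/.cuda()/.cpu() move every tensor-like field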
This diff is collapsed.
@@ -20,8 +20,9 @@ from pytorch3d.implicitron.tools.config import (
 )
 from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras

-from .dataset_base import DatasetBase, FrameData
+from .dataset_base import DatasetBase
 from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
+from .frame_data import FrameData
 from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN

 _SINGLE_SEQUENCE_NAME: str = "one_sequence"
@@ -69,7 +70,8 @@ class SingleSceneDataset(DatasetBase, Configurable):
             sequence_name=_SINGLE_SEQUENCE_NAME,
             sequence_category=self.object_name,
             camera=pose,
-            image_size_hw=torch.tensor(image.shape[1:]),
+            # pyre-ignore
+            image_size_hw=torch.tensor(image.shape[1:], dtype=torch.long),
             image_rgb=image,
             fg_probability=fg_probability,
             frame_type=frame_type,
@@ -55,6 +55,8 @@ class MaskAnnotation:
     path: str
     # (soft) number of pixels in the mask; sum(Prob(fg | pixel))
     mass: Optional[float] = None
+    # tight bounding box around the foreground mask
+    bounding_box_xywh: Optional[Tuple[float, float, float, float]] = None


 @dataclass
@@ -5,10 +5,18 @@
 # LICENSE file in the root directory of this source tree.

-from typing import List, Optional
+import functools
+import warnings
+from pathlib import Path
+from typing import List, Optional, Tuple, TypeVar, Union
+
+import numpy as np
 import torch
+from PIL import Image
+from pytorch3d.io import IO
+from pytorch3d.renderer.cameras import PerspectiveCameras
+from pytorch3d.structures.pointclouds import Pointclouds

 DATASET_TYPE_TRAIN = "train"
 DATASET_TYPE_TEST = "test"
@@ -16,6 +24,26 @@ DATASET_TYPE_KNOWN = "known"
DATASET_TYPE_UNKNOWN = "unseen"


class GenericWorkaround:
    """
    OmegaConf.structured has a weirdness when you try to apply
    it to a dataclass whose first base class is a Generic which is not
    Dict. The issue is with a function called get_dict_key_value_types
    in omegaconf/_utils.py.

    For example this fails:

    @dataclass(eq=False)
    class D(torch.utils.data.Dataset[int]):
        a: int = 3

    OmegaConf.structured(D)

    We avoid the problem by adding this class as an extra base class.
    """

    pass


def is_known_frame_scalar(frame_type: str) -> bool:
    """
    Given a single frame type corresponding to a single frame, return whether
@@ -52,3 +80,286 @@ def is_train_frame(
        dtype=torch.bool,
        device=device,
    )
def get_bbox_from_mask(
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
# bbox in xywh
masks_for_box = np.zeros_like(mask)
while masks_for_box.sum() <= 1.0:
masks_for_box = (mask > thr).astype(np.float32)
thr -= decrease_quant
if thr <= 0.0:
warnings.warn(
f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1
)
x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))
return x0, y0, x1 - x0, y1 - y0
def crop_around_box(
tensor: torch.Tensor, bbox: torch.Tensor, impath: str = ""
) -> torch.Tensor:
# bbox is xyxy, where the upper bound is corrected with +1
bbox = clamp_box_to_image_bounds_and_round(
bbox,
image_size_hw=tuple(tensor.shape[-2:]),
)
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
return tensor
def clamp_box_to_image_bounds_and_round(
bbox_xyxy: torch.Tensor,
image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
bbox_xyxy = bbox_xyxy.clone()
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
if not isinstance(bbox_xyxy, torch.LongTensor):
bbox_xyxy = bbox_xyxy.round().long()
return bbox_xyxy # pyre-ignore [7]
T = TypeVar("T", bound=torch.Tensor)
def bbox_xyxy_to_xywh(xyxy: T) -> T:
wh = xyxy[2:] - xyxy[:2]
xywh = torch.cat([xyxy[:2], wh])
return xywh # pyre-ignore
def get_clamp_bbox(
bbox: torch.Tensor,
box_crop_context: float = 0.0,
image_path: str = "",
) -> torch.Tensor:
# box_crop_context: rate of expansion for bbox
# returns possibly expanded bbox xyxy as float
bbox = bbox.clone() # do not edit bbox in place
# increase box size
if box_crop_context > 0.0:
c = box_crop_context
bbox = bbox.float()
bbox[0] -= bbox[2] * c / 2
bbox[1] -= bbox[3] * c / 2
bbox[2] += bbox[2] * c
bbox[3] += bbox[3] * c
if (bbox[2:] <= 1.0).any():
raise ValueError(
f"squashed image {image_path}!! The bounding box contains no pixels."
)
bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
bbox_xyxy = bbox_xywh_to_xyxy(bbox, clamp_size=2)
return bbox_xyxy
def rescale_bbox(
bbox: torch.Tensor,
orig_res: Union[Tuple[int, int], torch.LongTensor],
new_res: Union[Tuple[int, int], torch.LongTensor],
) -> torch.Tensor:
assert bbox is not None
assert np.prod(orig_res) > 1e-8
# average ratio of dimensions
# pyre-ignore
rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
return bbox * rel_size
def bbox_xywh_to_xyxy(
xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
xyxy = xywh.clone()
if clamp_size is not None:
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
xyxy[2:] += xyxy[:2]
return xyxy
def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
nz = np.flatnonzero(arr)
return nz[0], nz[-1] + 1
def resize_image(
image: Union[np.ndarray, torch.Tensor],
image_height: Optional[int],
image_width: Optional[int],
mode: str = "bilinear",
) -> Tuple[torch.Tensor, float, torch.Tensor]:
if type(image) == np.ndarray:
image = torch.from_numpy(image)
if image_height is None or image_width is None:
# skip the resizing
return image, 1.0, torch.ones_like(image[:1])
# takes numpy array or tensor, returns pytorch tensor
minscale = min(
image_height / image.shape[-2],
image_width / image.shape[-1],
)
imre = torch.nn.functional.interpolate(
image[None],
scale_factor=minscale,
mode=mode,
align_corners=False if mode == "bilinear" else None,
recompute_scale_factor=True,
)[0]
imre_ = torch.zeros(image.shape[0], image_height, image_width)
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
mask = torch.zeros(1, image_height, image_width)
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
return imre_, minscale, mask
def load_image(path: str) -> np.ndarray:
with Image.open(path) as pil_im:
im = np.array(pil_im.convert("RGB"))
im = im.transpose((2, 0, 1))
im = im.astype(np.float32) / 255.0
return im
def load_mask(path: str) -> np.ndarray:
with Image.open(path) as pil_im:
mask = np.array(pil_im)
mask = mask.astype(np.float32) / 255.0
return mask[None] # fake feature channel
def load_depth(path: str, scale_adjustment: float) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth file name "%s"' % path)
d = load_16big_png_depth(path) * scale_adjustment
d[~np.isfinite(d)] = 0.0
return d[None] # fake feature channel
def load_16big_png_depth(depth_png: str) -> np.ndarray:
with Image.open(depth_png) as depth_pil:
# the image is stored with 16-bit depth but PIL reads it as I (32 bit).
# we cast it to uint16, then reinterpret as float16, then cast to float32
depth = (
np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
.astype(np.float32)
.reshape((depth_pil.size[1], depth_pil.size[0]))
)
return depth
def load_1bit_png_mask(file: str) -> np.ndarray:
with Image.open(file) as pil_im:
mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
return mask
def load_depth_mask(path: str) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth mask file name "%s"' % path)
m = load_1bit_png_mask(path)
return m[None] # fake feature channel
def safe_as_tensor(data, dtype):
return torch.tensor(data, dtype=dtype) if data is not None else None
def _convert_ndc_to_pixels(
focal_length: torch.Tensor,
principal_point: torch.Tensor,
image_size_wh: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
half_image_size = image_size_wh / 2
rescale = half_image_size.min()
principal_point_px = half_image_size - principal_point * rescale
focal_length_px = focal_length * rescale
return focal_length_px, principal_point_px
def _convert_pixels_to_ndc(
focal_length_px: torch.Tensor,
principal_point_px: torch.Tensor,
image_size_wh: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
half_image_size = image_size_wh / 2
rescale = half_image_size.min()
principal_point = (half_image_size - principal_point_px) / rescale
focal_length = focal_length_px / rescale
return focal_length, principal_point
def adjust_camera_to_bbox_crop_(
camera: PerspectiveCameras,
image_size_wh: torch.Tensor,
clamp_bbox_xywh: torch.Tensor,
) -> None:
if len(camera) != 1:
raise ValueError("Adjusting currently works with singleton cameras camera only")
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
camera.focal_length[0],
camera.principal_point[0], # pyre-ignore
image_size_wh,
)
principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]
focal_length, principal_point_cropped = _convert_pixels_to_ndc(
focal_length_px,
principal_point_px_cropped,
clamp_bbox_xywh[2:],
)
camera.focal_length = focal_length[None]
camera.principal_point = principal_point_cropped[None] # pyre-ignore
def adjust_camera_to_image_scale_(
camera: PerspectiveCameras,
original_size_wh: torch.Tensor,
new_size_wh: torch.LongTensor,
) -> PerspectiveCameras:
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
camera.focal_length[0],
camera.principal_point[0], # pyre-ignore
original_size_wh,
)
# now scale and convert from pixels to NDC
image_size_wh_output = new_size_wh.float()
scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
focal_length_px_scaled = focal_length_px * scale
principal_point_px_scaled = principal_point_px * scale
focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
focal_length_px_scaled,
principal_point_px_scaled,
image_size_wh_output,
)
camera.focal_length = focal_length_scaled[None]
camera.principal_point = principal_point_scaled[None] # pyre-ignore
# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
pcl = IO().load_pointcloud(pcl_path)
if max_points > 0:
pcl = pcl.subsample(max_points)
return pcl
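Taken together, the helpers above implement the crop-and-resize pipeline that used to live inside JsonIndexDataset. A rough usage sketch, assuming the signatures shown above; fg_mask is a made-up example input:

# Rough sketch composing the bbox/image helpers above; not part of the diff.
import numpy as np
import torch
from pytorch3d.implicitron.dataset.utils import (
    crop_around_box,
    get_bbox_from_mask,
    get_clamp_bbox,
    resize_image,
)

fg_mask = np.zeros((200, 300), dtype=np.float32)
fg_mask[50:150, 100:250] = 1.0                        # fake foreground blob

x0, y0, w, h = get_bbox_from_mask(fg_mask, thr=0.4)   # tight xywh box around the mask
bbox_xywh = torch.tensor([x0, y0, w, h], dtype=torch.float)

# expand the tight box by 30% context and convert to a float xyxy box
bbox_xyxy = get_clamp_bbox(bbox_xywh, box_crop_context=0.3)

image = torch.from_numpy(fg_mask)[None]               # (1, H, W) tensor
cropped = crop_around_box(image, bbox_xyxy)           # clamps to image bounds, rounds, slices
resized, scale, mask_crop = resize_image(cropped, image_height=128, image_width=128)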
@@ -10,7 +10,7 @@ import torch
 from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
 from pytorch3d.structures import Pointclouds

-from .dataset_base import FrameData
+from .frame_data import FrameData
 from .json_index_dataset import JsonIndexDataset
@@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Un
 import numpy as np
 import torch
 import torch.nn.functional as F
-from pytorch3d.implicitron.dataset.dataset_base import FrameData
+from pytorch3d.implicitron.dataset.frame_data import FrameData
 from pytorch3d.implicitron.dataset.utils import is_train_frame
 from pytorch3d.implicitron.models.base_model import ImplicitronRender
 from pytorch3d.implicitron.tools import vis_utils
@@ -17,7 +17,8 @@ from pytorch3d.implicitron.dataset.data_loader_map_provider import (
     DoublePoolBatchSampler,
 )
-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
+from pytorch3d.implicitron.dataset.frame_data import FrameData
 from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@@ -9,11 +9,19 @@ import unittest
 import numpy as np
 import torch
-from pytorch3d.implicitron.dataset.json_index_dataset import (
-    _bbox_xywh_to_xyxy,
-    _bbox_xyxy_to_xywh,
-    _get_bbox_from_mask,
+
+from pytorch3d.implicitron.dataset.utils import (
+    bbox_xywh_to_xyxy,
+    bbox_xyxy_to_xywh,
+    clamp_box_to_image_bounds_and_round,
+    crop_around_box,
+    get_1d_bounds,
+    get_bbox_from_mask,
+    get_clamp_bbox,
+    rescale_bbox,
+    resize_image,
 )

 from tests.common_testing import TestCaseMixin
@@ -31,9 +39,9 @@ class TestBBox(TestCaseMixin, unittest.TestCase):
             ]
         )
         for bbox_xywh in bbox_xywh_list:
-            bbox_xyxy = _bbox_xywh_to_xyxy(bbox_xywh)
-            bbox_xywh_ = _bbox_xyxy_to_xywh(bbox_xyxy)
-            bbox_xyxy_ = _bbox_xywh_to_xyxy(bbox_xywh_)
+            bbox_xyxy = bbox_xywh_to_xyxy(bbox_xywh)
+            bbox_xywh_ = bbox_xyxy_to_xywh(bbox_xyxy)
+            bbox_xyxy_ = bbox_xywh_to_xyxy(bbox_xywh_)
             self.assertClose(bbox_xywh_, bbox_xywh)
             self.assertClose(bbox_xyxy, bbox_xyxy_)
@@ -47,8 +55,8 @@
             ]
         )
         for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_expected:
-            self.assertClose(_bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected)
-            self.assertClose(_bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh)
+            self.assertClose(bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected)
+            self.assertClose(bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh)

         clamp_amnt = 3
         bbox_xywh_to_xyxy_clamped_expected = torch.LongTensor(
@@ -61,7 +69,7 @@
         )
         for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_clamped_expected:
             self.assertClose(
-                _bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt),
+                bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt),
                 bbox_xyxy_expected,
             )
@@ -74,5 +82,61 @@
             ]
         ).astype(np.float32)
         expected_bbox_xywh = [2, 1, 2, 1]
-        bbox_xywh = _get_bbox_from_mask(mask, 0.5)
+        bbox_xywh = get_bbox_from_mask(mask, 0.5)
         self.assertClose(bbox_xywh, expected_bbox_xywh)
def test_crop_around_box(self):
bbox = torch.LongTensor([0, 1, 2, 3]) # (x_min, y_min, x_max, y_max)
image = torch.LongTensor(
[
[0, 0, 10, 20],
[10, 20, 5, 1],
[10, 20, 1, 1],
[5, 4, 0, 1],
]
)
cropped = crop_around_box(image, bbox)
self.assertClose(cropped, image[1:3, 0:2])
def test_clamp_box_to_image_bounds_and_round(self):
bbox = torch.LongTensor([0, 1, 10, 12])
image_size = (5, 6)
expected_clamped_bbox = torch.LongTensor([0, 1, image_size[1], image_size[0]])
clamped_bbox = clamp_box_to_image_bounds_and_round(bbox, image_size)
self.assertClose(clamped_bbox, expected_clamped_bbox)
def test_get_clamp_bbox(self):
bbox_xywh = torch.LongTensor([1, 1, 4, 5])
clamped_bbox_xyxy = get_clamp_bbox(bbox_xywh, box_crop_context=2)
# size multiplied by 2 and added coordinates
self.assertClose(clamped_bbox_xyxy, torch.Tensor([-3, -4, 9, 11]))
def test_rescale_bbox(self):
bbox = torch.Tensor([0.0, 1.0, 3.0, 4.0])
original_resolution = (4, 4)
new_resolution = (8, 8) # twice bigger
rescaled_bbox = rescale_bbox(bbox, original_resolution, new_resolution)
self.assertClose(bbox * 2, rescaled_bbox)
def test_get_1d_bounds(self):
array = [0, 1, 2]
bounds = get_1d_bounds(array)
# make nonzero 1d bounds of image
self.assertClose(bounds, [1, 3])
def test_resize_image(self):
image = np.random.rand(3, 300, 500) # rgb image 300x500
expected_shape = (150, 250)
resized_image, scale, mask_crop = resize_image(
image, image_height=expected_shape[0], image_width=expected_shape[1]
)
original_shape = image.shape[-2:]
expected_scale = min(
expected_shape[0] / original_shape[0], expected_shape[1] / original_shape[1]
)
self.assertEqual(scale, expected_scale)
self.assertEqual(resized_image.shape[-2:], expected_shape)
self.assertEqual(mask_crop.shape[-2:], expected_shape)
@@ -8,7 +8,7 @@ import os
 import unittest

 import torch
-from pytorch3d.implicitron.dataset.dataset_base import FrameData
+from pytorch3d.implicitron.dataset.frame_data import FrameData
 from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
     RenderedMeshDatasetMapProvider,
 )
@@ -13,8 +13,10 @@ import os
 import unittest

 import lpips
+import numpy as np
 import torch
-from pytorch3d.implicitron.dataset.dataset_base import FrameData
+
+from pytorch3d.implicitron.dataset.frame_data import FrameData
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
 from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
@@ -268,7 +270,7 @@ class TestEvaluation(unittest.TestCase):
         for metric in lower_better:
             m_better = eval_result[metric]
             m_worse = eval_result_bad[metric]
-            if m_better != m_better or m_worse != m_worse:
+            if np.isnan(m_better) or np.isnan(m_worse):
                 continue  # metric is missing, i.e. NaN
             _assert = (
                 self.assertLessEqual
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import gzip
import os
import unittest
from typing import List
import numpy as np
import torch
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
from pytorch3d.implicitron.dataset.utils import (
load_16big_png_depth,
load_1bit_png_mask,
load_depth,
load_depth_mask,
load_image,
load_mask,
safe_as_tensor,
)
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer.cameras import PerspectiveCameras
from tests.common_testing import TestCaseMixin
from tests.implicitron.common_resources import get_skateboard_data
class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
category = "skateboard"
stack = contextlib.ExitStack()
self.dataset_root, self.path_manager = stack.enter_context(
get_skateboard_data()
)
self.addCleanup(stack.close)
self.image_height = 768
self.image_width = 512
self.frame_data_builder = FrameDataBuilder(
image_height=self.image_height,
image_width=self.image_width,
dataset_root=self.dataset_root,
path_manager=self.path_manager,
)
# loading single frame annotation of dataset (see JsonIndexDataset._load_frames())
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
local_file = self.path_manager.get_local_path(frame_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
frame_annots_list = types.load_dataclass(
zipfile, List[types.FrameAnnotation]
)
self.frame_annotation = frame_annots_list[0]
sequence_annotations_file = os.path.join(
self.dataset_root, category, "sequence_annotations.jgz"
)
local_file = self.path_manager.get_local_path(sequence_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
seq_annots_list = types.load_dataclass(
zipfile, List[types.SequenceAnnotation]
)
seq_annots = {entry.sequence_name: entry for entry in seq_annots_list}
self.seq_annotation = seq_annots[self.frame_annotation.sequence_name]
point_cloud = self.seq_annotation.point_cloud
self.frame_data = FrameData(
frame_number=safe_as_tensor(self.frame_annotation.frame_number, torch.long),
frame_timestamp=safe_as_tensor(
self.frame_annotation.frame_timestamp, torch.float
),
sequence_name=self.frame_annotation.sequence_name,
sequence_category=self.seq_annotation.category,
camera_quality_score=safe_as_tensor(
self.seq_annotation.viewpoint_quality_score, torch.float
),
point_cloud_quality_score=safe_as_tensor(
point_cloud.quality_score, torch.float
)
if point_cloud is not None
else None,
)
def test_frame_data_builder_args(self):
# test that FrameDataBuilder works with get_default_args
get_default_args(FrameDataBuilder)
def test_fix_point_cloud_path(self):
"""Some files in Co3Dv2 have an accidental absolute path stored."""
original_path = "some_file_path"
modified_path = self.frame_data_builder._fix_point_cloud_path(original_path)
self.assertIn(original_path, modified_path)
self.assertIn(self.frame_data_builder.dataset_root, modified_path)
def test_load_and_adjust_frame_data(self):
self.frame_data.image_size_hw = safe_as_tensor(
self.frame_annotation.image.size, torch.long
)
self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw
(
self.frame_data.fg_probability,
self.frame_data.mask_path,
self.frame_data.bbox_xywh,
) = self.frame_data_builder._load_fg_probability(self.frame_annotation)
self.assertIsNotNone(self.frame_data.mask_path)
self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))
self.assertTrue(torch.is_tensor(self.frame_data.bbox_xywh))
# assert bboxes shape
self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4]))
(
self.frame_data.image_rgb,
self.frame_data.image_path,
) = self.frame_data_builder._load_images(
self.frame_annotation, self.frame_data.fg_probability
)
self.assertEqual(type(self.frame_data.image_rgb), np.ndarray)
self.assertIsNotNone(self.frame_data.image_path)
(
self.frame_data.depth_map,
depth_path,
self.frame_data.depth_mask,
) = self.frame_data_builder._load_mask_depth(
self.frame_annotation,
self.frame_data.fg_probability,
)
self.assertTrue(torch.is_tensor(self.frame_data.depth_map))
self.assertIsNotNone(depth_path)
self.assertTrue(torch.is_tensor(self.frame_data.depth_mask))
new_size = (self.image_height, self.image_width)
if self.frame_data_builder.box_crop:
self.frame_data.crop_by_metadata_bbox_(
self.frame_data_builder.box_crop_context,
)
# assert image and mask shapes after resize
self.frame_data.resize_frame_(
new_size_hw=torch.tensor(new_size, dtype=torch.long),
)
self.assertEqual(
self.frame_data.mask_crop.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.image_rgb.shape,
torch.Size([3, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.mask_crop.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.fg_probability.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.depth_map.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.assertEqual(
self.frame_data.depth_mask.shape,
torch.Size([1, self.image_height, self.image_width]),
)
self.frame_data.camera = self.frame_data_builder._get_pytorch3d_camera(
self.frame_annotation,
)
self.assertEqual(type(self.frame_data.camera), PerspectiveCameras)
def test_load_image(self):
path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
local_path = self.path_manager.get_local_path(path)
image = load_image(local_path)
self.assertEqual(image.dtype, np.float32)
self.assertLessEqual(np.max(image), 1.0)
self.assertGreaterEqual(np.min(image), 0.0)
def test_load_mask(self):
path = os.path.join(self.dataset_root, self.frame_annotation.mask.path)
mask = load_mask(path)
self.assertEqual(mask.dtype, np.float32)
self.assertLessEqual(np.max(mask), 1.0)
self.assertGreaterEqual(np.min(mask), 0.0)
def test_load_depth(self):
path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
depth_map = load_depth(path, self.frame_annotation.depth.scale_adjustment)
self.assertEqual(depth_map.dtype, np.float32)
self.assertEqual(len(depth_map.shape), 3)
def test_load_16big_png_depth(self):
path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
depth_map = load_16big_png_depth(path)
self.assertEqual(depth_map.dtype, np.float32)
self.assertEqual(len(depth_map.shape), 2)
def test_load_1bit_png_mask(self):
mask_path = os.path.join(
self.dataset_root, self.frame_annotation.depth.mask_path
)
mask = load_1bit_png_mask(mask_path)
self.assertEqual(mask.dtype, np.float32)
self.assertEqual(len(mask.shape), 2)
def test_load_depth_mask(self):
mask_path = os.path.join(
self.dataset_root, self.frame_annotation.depth.mask_path
)
mask = load_depth_mask(mask_path)
self.assertEqual(mask.dtype, np.float32)
self.assertEqual(len(mask.shape), 3)
@@ -17,7 +17,7 @@ import numpy as np
 import torch
 import torchvision
 from PIL import Image
-from pytorch3d.implicitron.dataset.dataset_base import FrameData
+from pytorch3d.implicitron.dataset.frame_data import FrameData
 from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
     JsonIndexDatasetMapProviderV2,
 )