Unverified commit cdd2142d, authored by Jeremy Reizenstein, committed by GitHub
Browse files

implicitron v0 (#1133)


Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
parent 0e377c68
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import List, Optional, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
Configurable,
ReplaceableBase,
_is_actually_dataclass,
_Registry,
expand_args_fields,
get_default_args,
get_default_args_field,
registry,
remove_unused_components,
run_auto_creation,
)
# Replaceable base class for which this file registers no implementations.
# test_registry_entries uses it to check that looking up "Banana" under an
# unrelated base class fails.
@dataclass
class Animal(ReplaceableBase):
    pass
# Replaceable base class of the fruit fixtures in this file; the tests look up
# its registered subclasses by name through `registry`.
class Fruit(ReplaceableBase):
    pass
# Registered Fruit whose fields have no defaults; in the default config they
# show up as missing ("???") values (see test_remove_unused_components).
@registry.register
class Banana(Fruit):
    pips: int
    spots: int
    bananame: str
# Registered Fruit with a single defaulted, configurable field.
@registry.register
class Pear(Fruit):
    n_pips: int = 13
# Fruit that is not registered with `registry`; MainTest.create_the_second_fruit
# constructs it directly.
class Pineapple(Fruit):
    pass
# Registered Fruit without fields; selected by name ("Orange") in several tests.
@registry.register
class Orange(Fruit):
    pass
# Registered Fruit without fields; selected by name ("Kiwi") in test_inheritance.
@registry.register
class Kiwi(Fruit):
    pass
# Registered subclass of Pear; test_registry_entries expects
# registry.get_all(Pear) to contain exactly this class.
@registry.register
class LargePear(Pear):
    pass
# Configurable with two replaceable Fruit members, used by
# test_simple_replacement and test_remove_unused_components.
class MainTest(Configurable):
    the_fruit: Fruit
    n_ids: int
    n_reps: int = 8
    the_second_fruit: Fruit

    # Hand-written creation function for `the_second_fruit`.
    # test_simple_replacement shows that the member ends up a Pineapple even
    # when the_second_fruit_class_type is set to "Pear".
    def create_the_second_fruit(self):
        # Pineapple must be processed before it can be instantiated.
        expand_args_fields(Pineapple)
        self.the_second_fruit = Pineapple()

    def __post_init__(self):
        # Runs the creation functions listed in MainTest._creation_functions.
        run_auto_creation(self)
class TestConfig(unittest.TestCase):
    """Tests of the implicitron config tools (Configurable, ReplaceableBase, registry).

    Statement order inside each test matters: registering classes and calling
    get_default_args / expand_args_fields have in-place side effects.
    """

    def test_is_actually_dataclass(self):
        # A subclass of a dataclass passes dataclasses.is_dataclass although it
        # has not itself been processed; _is_actually_dataclass tells them apart.
        @dataclass
        class A:
            pass

        self.assertTrue(_is_actually_dataclass(A))
        self.assertTrue(is_dataclass(A))

        class B(A):
            a: int

        self.assertFalse(_is_actually_dataclass(B))
        self.assertTrue(is_dataclass(B))

    def test_simple_replacement(self):
        # Choose the implementation of a replaceable member through the config.
        struct = get_default_args(MainTest)
        struct.n_ids = 9780
        struct.the_fruit_Pear_args.n_pips = 3
        struct.the_fruit_class_type = "Pear"
        struct.the_second_fruit_class_type = "Pear"

        main = MainTest(**struct)
        self.assertIsInstance(main.the_fruit, Pear)
        self.assertEqual(main.n_reps, 8)
        self.assertEqual(main.n_ids, 9780)
        self.assertEqual(main.the_fruit.n_pips, 3)
        # MainTest.create_the_second_fruit ignores the configured class type.
        self.assertIsInstance(main.the_second_fruit, Pineapple)

        # Editing `struct` must not have leaked into freshly generated defaults.
        struct2 = get_default_args(MainTest)
        self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13)

        self.assertEqual(
            MainTest._creation_functions,
            ("create_the_fruit", "create_the_second_fruit"),
        )

    def test_detect_bases(self):
        # testing the _base_class_from_class function
        self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase))
        self.assertIsNone(_Registry._base_class_from_class(MainTest))
        self.assertIs(_Registry._base_class_from_class(Fruit), Fruit)
        self.assertIs(_Registry._base_class_from_class(Pear), Fruit)

        class PricklyPear(Pear):
            pass

        self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit)

    def test_registry_entries(self):
        # Lookups by name, scoped to the base class that is passed in.
        self.assertIs(registry.get(Fruit, "Banana"), Banana)
        with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
            registry.get(Animal, "Banana")
        with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."):
            registry.get(Fruit, "PricklyPear")

        self.assertIs(registry.get(Pear, "Pear"), Pear)
        self.assertIs(registry.get(Pear, "LargePear"), LargePear)
        with self.assertRaisesRegex(ValueError, "Banana resolves to"):
            registry.get(Pear, "Banana")

        all_fruit = set(registry.get_all(Fruit))
        self.assertIn(Banana, all_fruit)
        self.assertIn(Pear, all_fruit)
        self.assertIn(LargePear, all_fruit)
        self.assertEqual(set(registry.get_all(Pear)), {LargePear})

        @registry.register
        class Apple(Fruit):
            pass

        @registry.register
        class CrabApple(Apple):
            pass

        self.assertEqual(set(registry.get_all(Apple)), {CrabApple})

        self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)

        # A class with no ReplaceableBase ancestor cannot be registered.
        with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):

            @registry.register
            class NotAFruit:
                pass

    def test_recursion(self):
        # A replaceable type which may contain a member of its own base type.
        class Shape(ReplaceableBase):
            pass

        @registry.register
        class Triangle(Shape):
            a: float = 5.0

        @registry.register
        class Square(Shape):
            a: float = 3.0

        @registry.register
        class LargeShape(Shape):
            inner: Shape

            def __post_init__(self):
                run_auto_creation(self)

        class ShapeContainer(Configurable):
            shape: Shape

        container = ShapeContainer(**get_default_args(ShapeContainer))
        # This is because ShapeContainer is missing __post_init__
        with self.assertRaises(AttributeError):
            container.shape

        class ShapeContainer2(Configurable):
            x: Shape
            x_class_type: str = "LargeShape"

            def __post_init__(self):
                self.x_LargeShape_args.inner_class_type = "Triangle"
                run_auto_creation(self)

        container2_args = get_default_args(ShapeContainer2)
        container2_args.x_LargeShape_args.inner_Triangle_args.a += 10
        self.assertIn("inner_Square_args", container2_args.x_LargeShape_args)
        # We do not perform expansion that would result in an infinite recursion,
        # so this member is not present.
        self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args)
        container2_args.x_LargeShape_args.inner_Square_args.a += 100
        container2 = ShapeContainer2(**container2_args)
        self.assertIsInstance(container2.x, LargeShape)
        self.assertIsInstance(container2.x.inner, Triangle)
        self.assertEqual(container2.x.inner.a, 15.0)

    def test_simpleclass_member(self):
        # Members which are not dataclasses are
        # tolerated. But it would be nice to be able to
        # configure them.
        class Foo:
            def __init__(self, a=1, b=2):
                self.a, self.b = a, b

        @dataclass()
        class Bar:
            aa: int = 9
            bb: int = 9

        class Container(Configurable):
            bar: Bar = Bar()
            # TODO make this work?
            # foo: Foo = Foo()
            fruit: Fruit
            fruit_class_type: str = "Orange"

            def __post_init__(self):
                run_auto_creation(self)

        self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
        container_args = get_default_args(Container)
        container = Container(**container_args)
        self.assertIsInstance(container.fruit, Orange)
        # self.assertIsInstance(container.bar, Bar)

        # Mutating the args of one default-constructed instance must not
        # affect defaults generated afterwards.
        container_defaulted = Container()
        container_defaulted.fruit_Pear_args.n_pips += 4

        container_args2 = get_default_args(Container)
        container = Container(**container_args2)
        self.assertEqual(container.fruit_Pear_args.n_pips, 13)

    def test_inheritance(self):
        # Replaceable members declared on a base and on its subclass are both
        # created; only the subclass's __post_init__ runs.
        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Orange"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        class LargeFruitBowl(FruitBowl):
            extra_fruit: Fruit
            extra_fruit_class_type: str = "Kiwi"

            def __post_init__(self):
                run_auto_creation(self)

        large_args = get_default_args(LargeFruitBowl)
        self.assertNotIn("extra_fruit", large_args)
        self.assertNotIn("main_fruit", large_args)
        large = LargeFruitBowl(**large_args)
        self.assertIsInstance(large.main_fruit, Orange)
        self.assertIsInstance(large.extra_fruit, Kiwi)

    def test_inheritance2(self):
        # This is a case where a class could contain an instance
        # of a subclass, which is ignored.
        class Parent(ReplaceableBase):
            pass

        class Main(Configurable):
            parent: Parent
            # Note - no __post_init__

        @registry.register
        class Derived(Parent, Main):
            pass

        args = get_default_args(Main)
        # Derived has been ignored in processing Main.
        self.assertCountEqual(args.keys(), ["parent_class_type"])

        main = Main(**args)

        with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."):
            run_auto_creation(main)

        main.parent_class_type = "Derived"
        # Illustrates that a dict works fine instead of a DictConfig.
        main.parent_Derived_args = {}
        with self.assertRaises(AttributeError):
            main.parent
        run_auto_creation(main)
        self.assertIsInstance(main.parent, Derived)

    def test_redefine(self):
        # Re-registering a class under an already registered name.
        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Grape"

            def __post_init__(self):
                run_auto_creation(self)

        @registry.register
        @dataclass
        class Grape(Fruit):
            large: bool = False

            def get_color(self):
                return "red"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        bowl_args = get_default_args(FruitBowl)

        @registry.register
        @dataclass
        class Grape(Fruit):  # noqa: F811
            large: bool = True

            def get_color(self):
                return "green"

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            bowl = FruitBowl(**bowl_args)
        self.assertIsInstance(bowl.main_fruit, Grape)

        # Redefining the same class won't help with defaults because encoded in args
        self.assertEqual(bowl.main_fruit.large, False)

        # But the override worked.
        self.assertEqual(bowl.main_fruit.get_color(), "green")

        # 2. Try redefining without the dataclass modifier
        # This relies on the fact that default creation processes the class.
        # (otherwise incomprehensible messages)
        @registry.register
        class Grape(Fruit):  # noqa: F811
            large: bool = True

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            bowl = FruitBowl(**bowl_args)

        # 3. Adding a new class doesn't get picked up, because the first
        # get_default_args call has frozen FruitBowl. This is intrinsic to
        # the way dataclass and expand_args_fields work in-place but
        # expand_args_fields is not pure - it depends on the registry.
        @registry.register
        class Fig(Fruit):
            pass

        bowl_args2 = get_default_args(FruitBowl)
        self.assertIn("main_fruit_Grape_args", bowl_args2)
        self.assertNotIn("main_fruit_Fig_args", bowl_args2)

        # TODO Is it possible to make this work?
        # bowl_args2["main_fruit_Fig_args"] = get_default_args(Fig)
        # bowl_args2.main_fruit_class_type = "Fig"
        # bowl2 = FruitBowl(**bowl_args2)  <= unexpected argument

        # Note that it is possible to use Fig if you can set
        # bowl2.main_fruit_Fig_args explicitly (not in bowl_args2)
        # before run_auto_creation happens. See test_inheritance2
        # for an example.

    def test_no_replacement(self):
        # Test of Configurables without ReplaceableBase
        class A(Configurable):
            n: int = 9

        class B(Configurable):
            a: A

            def __post_init__(self):
                run_auto_creation(self)

        class C(Configurable):
            b: B

            def __post_init__(self):
                run_auto_creation(self)

        c_args = get_default_args(C)
        c = C(**c_args)
        self.assertIsInstance(c.b.a, A)
        self.assertEqual(c.b.a.n, 9)

    def test_doc(self):
        # The case in the docstring.
        class A(ReplaceableBase):
            k: int = 1

        @registry.register
        class A1(A):
            m: int = 3

        @registry.register
        class A2(A):
            n: str = "2"

        class B(Configurable):
            a: A
            a_class_type: str = "A2"

            def __post_init__(self):
                run_auto_creation(self)

        b_args = get_default_args(B)
        self.assertNotIn("a", b_args)
        b = B(**b_args)
        self.assertEqual(b.a.n, "2")

    def test_raw_types(self):
        # Defaults of plain dataclasses, plain classes and functions can be
        # embedded as DictConfig fields through get_default_args_field.
        @dataclass
        class MyDataclass:
            int_field: int = 0
            none_field: Optional[int] = None
            float_field: float = 9.3
            bool_field: bool = True
            tuple_field: tuple = (3, True, "j")

        class SimpleClass:
            def __init__(self, tuple_member_=(3, 4)):
                self.tuple_member = tuple_member_

            def get_tuple(self):
                return self.tuple_member

        def f(*, a: int = 3, b: str = "kj"):
            self.assertEqual(a, 3)
            self.assertEqual(b, "kj")

        class C(Configurable):
            simple: DictConfig = get_default_args_field(SimpleClass)
            # simple2: SimpleClass2 = SimpleClass2()
            mydata: DictConfig = get_default_args_field(MyDataclass)
            a_tuple: Tuple[float] = (4.0, 3.0)
            f_args: DictConfig = get_default_args_field(f)

        args = get_default_args(C)
        c = C(**args)
        self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])

        mydata = MyDataclass(**c.mydata)
        simple = SimpleClass(**c.simple)

        # OmegaConf converts tuples to ListConfigs (which act like lists).
        self.assertEqual(simple.get_tuple(), [3, 4])
        self.assertIsInstance(simple.get_tuple(), ListConfig)
        self.assertEqual(c.a_tuple, [4.0, 3.0])
        self.assertIsInstance(c.a_tuple, ListConfig)
        self.assertEqual(mydata.tuple_field, (3, True, "j"))
        self.assertIsInstance(mydata.tuple_field, ListConfig)
        f(**c.f_args)

    def test_irrelevant_bases(self):
        class NotADataclass:
            # Like torch.nn.Module, this class contains annotations
            # but is not designed to be dataclass'd.
            # This test ensures that such classes, when inherited from,
            # are not accidentally expand_args_fields.
            a: int = 9
            b: int

        class LeftConfigured(Configurable, NotADataclass):
            left: int = 1

        class RightConfigured(NotADataclass, Configurable):
            right: int = 2

        class Outer(Configurable):
            left: LeftConfigured
            right: RightConfigured

            def __post_init__(self):
                run_auto_creation(self)

        outer = Outer(**get_default_args(Outer))
        self.assertEqual(outer.left.left, 1)
        self.assertEqual(outer.right.right, 2)
        # Confirms NotADataclass really could not have been processed.
        with self.assertRaisesRegex(TypeError, "non-default argument"):
            dataclass(NotADataclass)

    def test_unprocessed(self):
        # behavior of Configurable classes which need processing in __new__,
        class Unprocessed(Configurable):
            a: int = 9

        class UnprocessedReplaceable(ReplaceableBase):
            a: int = 1

        with self.assertWarnsRegex(UserWarning, "must be processed"):
            Unprocessed()
        with self.assertWarnsRegex(UserWarning, "must be processed"):
            UnprocessedReplaceable()

    def test_enum(self):
        # Test that enum values are kept, i.e. that OmegaConf's runtime checks
        # are in use.
        class A(Enum):
            B1 = "b1"
            B2 = "b2"

        class C(Configurable):
            a: A = A.B1

        base = get_default_args(C)
        replaced = OmegaConf.merge(base, {"a": "B2"})
        self.assertEqual(replaced.a, A.B2)
        with self.assertRaises(ValidationError):
            # You can't use a value which is not one of the
            # choices, even if it is the str representation
            # of one of the choices.
            OmegaConf.merge(base, {"a": "b2"})

        # The enum value survives a round trip through yaml.
        remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
        self.assertEqual(remerged.a, A.B1)

    def test_remove_unused_components(self):
        # remove_unused_components drops the *_args entries of the
        # implementations that were not selected by *_class_type.
        struct = get_default_args(MainTest)
        struct.n_ids = 32
        struct.the_fruit_class_type = "Pear"
        struct.the_second_fruit_class_type = "Banana"
        remove_unused_components(struct)
        expected_keys = [
            "n_ids",
            "n_reps",
            "the_fruit_Pear_args",
            "the_fruit_class_type",
            "the_second_fruit_Banana_args",
            "the_second_fruit_class_type",
        ]
        # The per-class args are nested mappings, so their entries must be
        # indented below the corresponding *_args key; otherwise they would be
        # parsed as extra top-level keys and contradict expected_keys.
        expected_yaml = textwrap.dedent(
            """\
            n_ids: 32
            n_reps: 8
            the_fruit_class_type: Pear
            the_fruit_Pear_args:
              n_pips: 13
            the_second_fruit_class_type: Banana
            the_second_fruit_Banana_args:
              pips: ???
              spots: ???
              bananame: ???
            """
        )
        self.assertEqual(sorted(struct.keys()), expected_keys)

        # Check that struct is what we expect
        expected = OmegaConf.create(expected_yaml)
        self.assertEqual(struct, expected)

        # Check that we get what we expect when writing to yaml.
        self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)

        # The same pruning applies to the config recovered from an instance.
        main = MainTest(**struct)
        instance_data = OmegaConf.structured(main)
        remove_unused_components(instance_data)
        self.assertEqual(sorted(instance_data.keys()), expected_keys)
        self.assertEqual(instance_data, expected)
# Plain dataclass fixture (not a Configurable) with one required field, one
# immutable default and one per-instance mutable default.
@dataclass(eq=False)
class MockDataclass:
    field_no_default: int
    field_primitive_type: int = 42
    # A factory gives every instance its own fresh empty list.
    field_reference_type: List[int] = field(default_factory=list)
class MockClassWithInit:  # noqa: B903
    # Plain-class counterpart of MockDataclass: the same three fields, taken
    # as __init__ arguments. The shared mutable default is intentional; it is
    # what test_get_default_args_readonly guards against being modified.
    def __init__(
        self,
        field_no_default: int,
        field_primitive_type: int = 42,
        field_reference_type: List[int] = [],  # noqa: B006
    ):
        (
            self.field_no_default,
            self.field_primitive_type,
            self.field_reference_type,
        ) = (field_no_default, field_primitive_type, field_reference_type)
class TestRawClasses(unittest.TestCase):
    """get_default_args applied to classes which are not Configurable."""

    def test_get_default_args(self):
        # Every generated default must match the attribute of a real instance.
        for mock_cls in (MockDataclass, MockClassWithInit):
            defaults = get_default_args(mock_cls)
            instance = mock_cls(field_no_default=0)
            # The required field has no default, so fill it in by hand.
            defaults.field_no_default = 0
            for key, expected in defaults.items():
                self.assertTrue(hasattr(instance, key))
                self.assertEqual(expected, getattr(instance, key))

    def test_get_default_args_readonly(self):
        # Mutating the returned defaults must not touch the class's own default.
        for mock_cls in (MockDataclass, MockClassWithInit):
            defaults = get_default_args(mock_cls)
            defaults["field_reference_type"].append(13)
            instance = mock_cls(field_no_default=0)
            self.assertEqual(instance.field_reference_type, [])
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField,
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
NeuralRadianceFieldImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
AngleWeightedIdentityFeatureAggregator,
AngleWeightedReductionFeatureAggregator,
)
from pytorch3d.implicitron.tools.config import (
get_default_args,
remove_unused_components,
)
if os.environ.get("FB_TEST", False):
from common_testing import get_tests_dir
else:
from tests.common_testing import get_tests_dir
# Directory from which test_create_gm_overrides reads its reference
# "overrides.yaml".
DATA_DIR = get_tests_dir() / "implicitron/data"
# When True, test_create_gm_overrides also dumps the yaml it generated to
# "overrides.yaml_" in DATA_DIR before comparing.
DEBUG: bool = False
# Tests the use of the config system in implicitron
class TestGenericModel(unittest.TestCase):
    """Construction of GenericModel from its default and overridden configs."""

    def setUp(self):
        # Show complete diffs when the yaml comparison fails.
        self.maxDiff = None

    def test_create_gm(self):
        # A model built from untouched defaults gets the default components.
        model = GenericModel(**get_default_args(GenericModel))

        self.assertIsInstance(model.renderer, MultiPassEmissionAbsorptionRenderer)
        self.assertIsInstance(
            model.feature_aggregator, AngleWeightedReductionFeatureAggregator
        )
        first_implicit_function = model._implicit_functions[0]._fn
        self.assertIsInstance(
            first_implicit_function, NeuralRadianceFieldImplicitFunction
        )
        self.assertIsInstance(model.sequence_autodecoder, Autodecoder)

        for absent_member in ("implicit_function", "image_feature_extractor"):
            self.assertFalse(hasattr(model, absent_member))

    def test_create_gm_overrides(self):
        # Swap three components by name and compare the resulting config
        # with the stored reference yaml.
        config = get_default_args(GenericModel)
        config.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
        config.implicit_function_class_type = "IdrFeatureField"
        config.renderer_class_type = "LSTMRenderer"
        model = GenericModel(**config)

        self.assertIsInstance(model.renderer, LSTMRenderer)
        self.assertIsInstance(
            model.feature_aggregator, AngleWeightedIdentityFeatureAggregator
        )
        first_implicit_function = model._implicit_functions[0]._fn
        self.assertIsInstance(first_implicit_function, IdrFeatureField)
        self.assertIsInstance(model.sequence_autodecoder, Autodecoder)
        self.assertFalse(hasattr(model, "implicit_function"))

        instance_config = OmegaConf.structured(model)
        remove_unused_components(instance_config)
        generated_yaml = OmegaConf.to_yaml(instance_config, sort_keys=False)
        if DEBUG:
            # Written under a different name so the reference is not clobbered.
            (DATA_DIR / "overrides.yaml_").write_text(generated_yaml)
        reference_yaml = (DATA_DIR / "overrides.yaml").read_text()
        self.assertEqual(generated_yaml, reference_yaml)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
import os
import unittest
import torch
import torchvision
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
from pytorch3d.vis.plotly_vis import plot_scene
from visdom import Visdom
if os.environ.get("FB_TEST", False):
from .common_resources import get_skateboard_data
else:
from common_resources import get_skateboard_data
class TestDatasetVisualize(unittest.TestCase):
    """Renders sequence point clouds from several dataset configurations.

    The renders are saved as PNG files next to this test file and, if a visdom
    server is reachable, also sent to visdom.
    """

    def setUp(self):
        if os.environ.get("INSIDE_RE_WORKER") is not None:
            raise unittest.SkipTest("Visdom not available")

        category = "skateboard"
        # The data context has to stay open for the whole test, so it is
        # closed through addCleanup rather than a `with` block.
        stack = contextlib.ExitStack()
        dataset_root, path_manager = stack.enter_context(get_skateboard_data())
        self.addCleanup(stack.close)
        frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
        sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
        self.image_size = 256

        # Arguments common to all dataset variants.
        shared_kwargs = {
            "frame_annotations_file": frame_file,
            "sequence_annotations_file": sequence_file,
            "dataset_root": dataset_root,
            "image_height": self.image_size,
            "load_point_clouds": True,
            "path_manager": path_manager,
        }
        self.datasets = {
            "simple": ImplicitronDataset(
                image_width=self.image_size,
                box_crop=True,
                **shared_kwargs,
            ),
            "nonsquare": ImplicitronDataset(
                image_width=self.image_size // 2,
                box_crop=True,
                **shared_kwargs,
            ),
            "nocrop": ImplicitronDataset(
                image_width=self.image_size // 2,
                box_crop=False,
                **shared_kwargs,
            ),
        }
        # Add a copy of each dataset with its annotations converted to the
        # "ndc_isotropic" intrinsics format.
        self.datasets.update(
            {
                k + "_newndc": _change_annotations_to_new_ndc(dataset)
                for k, dataset in self.datasets.items()
            }
        )
        self.visdom = Visdom()
        if not self.visdom.check_connection():
            print("Visdom server not running! Disabling visdom visualizations.")
            self.visdom = None

    def _render_one_pointcloud(self, point_cloud, cameras, render_size):
        """Render `point_cloud` from `cameras`; return the image clamped to [0, 1]."""
        (_image_render, _, _) = render_point_cloud_pytorch3d(
            cameras,
            point_cloud,
            render_size=render_size,
            point_radius=1e-2,
            topk=10,
            bg_color=0.0,
        )
        return _image_render.clamp(0.0, 1.0)

    def test_one(self):
        """Test dataset visualization."""
        for max_frames in (16, -1):
            for load_dataset_point_cloud in (True, False):
                for dataset_key in self.datasets:
                    self._gen_and_render_pointcloud(
                        max_frames, load_dataset_point_cloud, dataset_key
                    )

    def _gen_and_render_pointcloud(
        self, max_frames, load_dataset_point_cloud, dataset_key
    ):
        """Build the point cloud of the first sequence and render it.

        Args:
            max_frames: passed through to get_implicitron_sequence_pointcloud.
            load_dataset_point_cloud: passed through to
                get_implicitron_sequence_pointcloud.
            dataset_key: key into self.datasets selecting the dataset variant.
        """
        dataset = self.datasets[dataset_key]
        # load the point cloud of the first sequence
        sequence_show = list(dataset.seq_annots.keys())[0]
        device = torch.device("cuda:0")

        point_cloud, sequence_frame_data = get_implicitron_sequence_pointcloud(
            dataset,
            sequence_name=sequence_show,
            mask_points=True,
            max_frames=max_frames,
            num_workers=10,
            load_dataset_point_cloud=load_dataset_point_cloud,
        )

        # render on gpu
        point_cloud = point_cloud.to(device)
        cameras = sequence_frame_data.camera.to(device)

        # render the point_cloud from the viewpoint of loaded cameras
        images_render = torch.cat(
            [
                self._render_one_pointcloud(
                    point_cloud,
                    cameras[frame_i],
                    (
                        dataset.image_height,
                        dataset.image_width,
                    ),
                )
                for frame_i in range(len(cameras))
            ]
        ).cpu()

        # Ground truth and render side by side (concatenated along dim 3).
        images_gt_and_render = torch.cat(
            [sequence_frame_data.image_rgb, images_render], dim=3
        )

        # The dataset key is part of the file name so that the renders of the
        # different dataset variants do not overwrite one another.
        imfile = os.path.join(
            os.path.split(os.path.abspath(__file__))[0],
            "test_dataset_visualize"
            + f"_max_frames={max_frames}"
            + f"_load_pcl={load_dataset_point_cloud}"
            + f"_dataset={dataset_key}.png",
        )
        print(f"Exporting image {imfile}.")
        torchvision.utils.save_image(images_gt_and_render, imfile, nrow=2)

        if self.visdom is not None:
            test_name = f"{max_frames}_{load_dataset_point_cloud}_{dataset_key}"
            self.visdom.images(
                images_gt_and_render,
                env="test_dataset_visualize",
                win=f"pcl_renders_{test_name}",
                opts={"title": f"pcl_renders_{test_name}"},
            )
            plotlyplot = plot_scene(
                {
                    "scene_batch": {
                        "cameras": cameras,
                        "point_cloud": point_cloud,
                    }
                },
                camera_scale=1.0,
                pointcloud_max_points=10000,
                pointcloud_marker_size=1.0,
            )
            self.visdom.plotlyplot(
                plotlyplot,
                env="test_dataset_visualize",
                win=f"pcl_{test_name}",
            )
def _change_annotations_to_new_ndc(dataset):
    """Return a deep copy of `dataset` with viewpoints in "ndc_isotropic" format.

    For every entry of `frame_annots`, both focal lengths are replaced by the
    larger of the two and each principal point coordinate is rescaled by
    (largest focal length / original focal length on that axis).
    The input dataset is left untouched.
    """
    converted = copy.deepcopy(dataset)
    for entry in converted.frame_annots:
        viewpoint = entry["frame_annotation"].viewpoint
        viewpoint.intrinsics_format = "ndc_isotropic"
        # This assumes the focal length to be equal on x and y (ok for a test).
        focal = viewpoint.focal_length
        largest_focal = max(focal)
        centre = viewpoint.principal_point
        viewpoint.principal_point = tuple(
            centre[axis] * largest_focal / focal[axis] for axis in (0, 1)
        )
        viewpoint.focal_length = (largest_focal, largest_focal)
    return converted
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import torch
from pytorch3d.implicitron.tools.eval_video_trajectory import (
generate_eval_video_cameras,
)
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
from pytorch3d.transforms import axis_angle_to_matrix
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
class TestEvalCameras(TestCaseMixin, unittest.TestCase):
    """Tests of the evaluation video camera trajectory generation."""

    def setUp(self):
        torch.manual_seed(42)

    def test_circular(self):
        # Training cameras are placed at random azimuths and slightly rotated;
        # the generated trajectory is then checked to be roughly centred at
        # the origin with roughly unit distance from its centre.
        n_train_cameras = 10
        n_test_cameras = 100

        random_azimuths = torch.rand(n_train_cameras) * 360
        R, T = look_at_view_transform(azim=random_azimuths)

        amplitude = 0.01
        small_rotations = axis_angle_to_matrix(
            torch.rand(n_train_cameras, 3) * amplitude
        )
        R_jiggled = torch.bmm(R, small_rotations)

        cameras_train = PerspectiveCameras(R=R_jiggled, T=T)
        cameras_test = generate_eval_video_cameras(
            cameras_train, trajectory_type="circular_lsq_fit", trajectory_scale=1.0
        )

        positions_test = cameras_test.get_camera_center()
        center = positions_test.mean(0)
        self.assertClose(center, torch.zeros(3), atol=0.1)

        distances_from_center = (positions_test - center).norm(dim=[1])
        self.assertClose(
            distances_from_center,
            torch.ones(n_test_cameras),
            atol=0.1,
        )
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
import dataclasses
import math
import os
import unittest
import lpips
import torch
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
)
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
if os.environ.get("FB_TEST", False):
from .common_resources import get_skateboard_data, provide_lpips_vgg
else:
from common_resources import get_skateboard_data, provide_lpips_vgg
class TestEvaluation(unittest.TestCase):
    """Tests of the depth / PSNR metrics and of eval_batch.

    The tests draw random numbers after seeding in setUp, so the order of
    the random calls below is kept fixed.
    """

    def setUp(self):
        # initialize evaluation dataset/dataloader
        torch.manual_seed(42)

        # The data context must stay open during the test; close it on cleanup.
        resources = contextlib.ExitStack()
        dataset_root, path_manager = resources.enter_context(get_skateboard_data())
        self.addCleanup(resources.close)

        category = "skateboard"
        self.image_size = 256
        self.dataset = ImplicitronDataset(
            frame_annotations_file=os.path.join(
                dataset_root, category, "frame_annotations.jgz"
            ),
            sequence_annotations_file=os.path.join(
                dataset_root, category, "sequence_annotations.jgz"
            ),
            dataset_root=dataset_root,
            image_height=self.image_size,
            image_width=self.image_size,
            box_crop=True,
            path_manager=path_manager,
        )
        self.bg_color = 0.0

        # init the lpips model for eval
        provide_lpips_vgg()
        self.lpips_model = lpips.LPIPS(net="vgg")

    def test_eval_depth(self):
        """
        Check that eval_depth correctly masks errors and that, for get_best_scale=True,
        the error with scaled prediction equals the error without scaling the
        predicted depth. Finally, test that the error values are as expected
        for prediction and gt differing by a constant offset.
        """
        gt = (torch.randn(10, 1, 300, 400, device="cuda") * 5.0).clamp(0.0)
        mask = (torch.rand_like(gt) > 0.5).type_as(gt)

        def depth_errors(prediction, crop, eval_mask, best_scale=True):
            # Thin wrapper fixing the ground truth; returns (mse, abs) errors.
            return eval_depth(
                prediction,
                gt,
                crop=crop,
                mask=eval_mask,
                get_best_scale=best_scale,
            )

        for diff in 10 ** torch.linspace(-5, 0, 6):
            for crop in (0, 5):
                pred = gt + (torch.rand_like(gt) - 0.5) * 2 * diff

                # scaled prediction test
                mse_depth, abs_depth = depth_errors(pred, crop, mask)
                mse_depth_scale, abs_depth_scale = depth_errors(
                    pred * 10.0, crop, mask
                )
                for unscaled, scaled in (
                    (mse_depth, mse_depth_scale),
                    (abs_depth, abs_depth_scale),
                ):
                    self.assertAlmostEqual(
                        float(unscaled.sum()), float(scaled.sum()), delta=1e-4
                    )

                # error masking test: the error lives only where mask == 0
                pred_masked_err = gt + (torch.rand_like(gt) + diff) * (1 - mask)
                mse_depth_masked, abs_depth_masked = depth_errors(
                    pred_masked_err, crop, mask
                )
                for masked_error in (mse_depth_masked, abs_depth_masked):
                    self.assertAlmostEqual(
                        float(masked_error.sum()), 0.0, delta=1e-4
                    )

                # with the inverted mask the same error must be visible
                mse_depth_unmasked, abs_depth_unmasked = depth_errors(
                    pred_masked_err, crop, 1 - mask
                )
                self.assertGreater(
                    float(mse_depth_unmasked.sum()),
                    float(diff ** 2),
                )
                self.assertGreater(
                    float(abs_depth_unmasked.sum()),
                    float(diff),
                )

                # tests with constant error
                pred_fix_diff = gt + diff * mask
                for _mask_gt in (mask, None):
                    mse_depth_fix_diff, abs_depth_fix_diff = depth_errors(
                        pred_fix_diff, crop, _mask_gt, best_scale=False
                    )
                    if _mask_gt is None:
                        err_mask = (gt > 0.0).float() * mask
                        gt_cropped = gt
                        if crop > 0:
                            err_mask = err_mask[:, :, crop:-crop, crop:-crop]
                            gt_cropped = gt[:, :, crop:-crop, crop:-crop]
                        gt_mass = (gt_cropped > 0.0).float().sum(dim=(1, 2, 3))
                        expected_err_abs = (
                            diff * err_mask.sum(dim=(1, 2, 3)) / (gt_mass)
                        )
                        expected_err_mse = diff * expected_err_abs
                    else:
                        expected_err_abs = diff
                        expected_err_mse = diff ** 2

                    for measured, expected in (
                        (abs_depth_fix_diff, expected_err_abs),
                        (mse_depth_fix_diff, expected_err_mse),
                    ):
                        self.assertTrue(
                            torch.allclose(
                                measured,
                                expected * torch.ones_like(measured),
                                atol=1e-4,
                            )
                        )

    def test_psnr(self):
        """
        Compare against opencv and check that the psnr is above
        the minimum possible value.
        """
        import cv2

        def to_uint8(image):
            return (image * 255).to(torch.uint8)

        im1 = torch.rand(100, 3, 256, 256).cuda()
        im1_uint8 = to_uint8(im1)
        im1_rounded = im1_uint8.float() / 255
        for max_diff in 10 ** torch.linspace(-5, 0, 6):
            noise = (torch.rand_like(im1) - 0.5) * 2 * max_diff
            im2 = (im1 + noise).clamp(0.0, 1.0)
            im2_uint8 = to_uint8(im2)
            im2_rounded = im2_uint8.float() / 255

            # check that our psnr matches the output of opencv
            psnr = calc_psnr(im1_rounded, im2_rounded)
            # some versions of cv2 can only take uint8 input
            psnr_cv2 = cv2.PSNR(
                im1_uint8.cpu().numpy(),
                im2_uint8.cpu().numpy(),
            )
            self.assertAlmostEqual(float(psnr), float(psnr_cv2), delta=1e-4)

            # check that all PSNRs are bigger than the minimum possible PSNR
            max_mse = max_diff ** 2
            min_psnr = 10 * math.log10(1.0 / max_mse)
            for _im1, _im2 in zip(im1, im2):
                single_psnr = calc_psnr(_im1, _im2)
                self.assertGreaterEqual(float(single_psnr) + 1e-6, min_psnr)

    def _one_sequence_test(
        self,
        seq_dataset,
        n_batches=2,
        min_batch_size=5,
        max_batch_size=10,
    ):
        # Evaluates ModelDBIR on random batches of one sequence and checks that
        # a heavily corrupted prediction never scores better.

        def draw_batch():
            size = torch.randint(low=min_batch_size, high=max_batch_size, size=(1,))
            return torch.randperm(len(seq_dataset))[:size]

        # form a list of random batches
        batch_indices = [draw_batch() for _ in range(n_batches)]

        loader = torch.utils.data.DataLoader(
            seq_dataset,
            # batch_size=1,
            shuffle=False,
            batch_sampler=batch_indices,
            collate_fn=FrameData.collate,
        )

        model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
        model.cuda()
        self.lpips_model.cuda()

        # For each compared metric: True if a lower value is better.
        lower_better = {
            "psnr": False,
            "psnr_fg": False,
            "depth_abs_fg": True,
            "iou": False,
            "rgb_l1": True,
            "rgb_l1_fg": True,
        }

        for frame_data in loader:
            self.assertIsNone(frame_data.frame_type)
            self.assertIsNotNone(frame_data.image_rgb)
            # override the frame_type
            n_known = len(frame_data.image_rgb) - 1
            frame_data.frame_type = ["train_unseen"] + ["train_known"] * n_known

            # move frame_data to gpu
            frame_data = dataclass_to_cuda_(frame_data)
            preds = model(**dataclasses.asdict(frame_data))

            nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
            eval_result = eval_batch(
                frame_data,
                nvs_prediction,
                bg_color=self.bg_color,
                lpips_model=self.lpips_model,
            )

            # Make a terribly bad NVS prediction and check that this is worse
            # than the DBIR prediction.
            nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
            nvs_prediction_bad.depth_render += (
                torch.randn_like(nvs_prediction.depth_render) * 100.0
            )
            nvs_prediction_bad.image_render += (
                torch.randn_like(nvs_prediction.image_render) * 100.0
            )
            random_mask = torch.randn_like(nvs_prediction.mask_render) > 0.0
            nvs_prediction_bad.mask_render = random_mask.float()
            eval_result_bad = eval_batch(
                frame_data,
                nvs_prediction_bad,
                bg_color=self.bg_color,
                lpips_model=self.lpips_model,
            )

            for metric, lower_is_better in lower_better.items():
                m_better = eval_result[metric]
                m_worse = eval_result_bad[metric]
                if m_better != m_better or m_worse != m_worse:
                    continue  # metric is missing, i.e. NaN
                if lower_is_better:
                    self.assertLessEqual(m_better, m_worse)
                else:
                    self.assertGreaterEqual(m_better, m_worse)

    def test_full_eval(self, n_sequences=5):
        """Test evaluation."""
        all_sequence_indices = list(self.dataset.seq_to_idx.values())
        for idx in all_sequence_indices[:n_sequences]:
            seq_dataset = torch.utils.data.Subset(self.dataset, idx)
            self._one_sequence_test(seq_dataset)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
class TestGenericModel(unittest.TestCase):
    """Smoke test of the forward pass of a default-configured GenericModel."""

    def test_gm(self):
        """Run GenericModel in training and eval mode and check its basic outputs."""
        # Simple test of a forward pass of the default GenericModel.
        # Use the second GPU when there is one, otherwise fall back to the first
        # GPU so that the test does not crash on single-GPU machines.
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
        expand_args_fields(GenericModel)
        model = GenericModel()
        model.to(device)

        n_train_cameras = 2
        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        cameras = PerspectiveCameras(R=R, T=T, device=device)

        # TODO: make these default to None?
        defaulted_args = {
            "fg_probability": None,
            "depth_map": None,
            "mask_crop": None,
            "sequence_name": None,
        }

        # With no target image the model is expected to warn that it has no
        # main objective.
        with self.assertWarnsRegex(UserWarning, "No main objective found"):
            model(
                camera=cameras,
                evaluation_mode=EvaluationMode.TRAINING,
                **defaulted_args,
                image_rgb=None,
            )

        # With a target image a positive training objective must be produced.
        target_image_rgb = torch.rand(
            (n_train_cameras, 3, model.render_image_height, model.render_image_width),
            device=device,
        )
        train_preds = model(
            camera=cameras,
            evaluation_mode=EvaluationMode.TRAINING,
            image_rgb=target_image_rgb,
            **defaulted_args,
        )
        self.assertGreater(train_preds["objective"].item(), 0)

        model.eval()
        with torch.no_grad():
            # TODO: perhaps this warning should be skipped in eval mode?
            with self.assertWarnsRegex(UserWarning, "No main objective found"):
                eval_preds = model(
                    camera=cameras[0],
                    **defaulted_args,
                    image_rgb=None,
                )
        # A single camera is rendered, so the batch dimension of the output is 1.
        self.assertEqual(
            eval_preds["images_render"].shape,
            (1, 3, model.render_image_height, model.render_image_width),
        )
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import torch
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
from pytorch3d.renderer import RayBundle
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
    """Checks RayPointRefiner on a ray bundle with uniform weights."""

    def test_simple(self):
        """Compare deterministic refinement to a closed form and sanity-check random refinement."""
        length = 15
        n_pts_per_ray = 10
        batch_shape = (3, 25)

        for add_input_samples in (False, True):
            deterministic_refiner = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=False,
                add_input_samples=add_input_samples,
            )
            # Every ray carries the lengths 0, 1, ..., length - 1 with equal weights.
            lengths = torch.arange(length, dtype=torch.float32).expand(
                *batch_shape, length
            )
            bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None)
            weights = torch.ones(*batch_shape, length)
            refined = deterministic_refiner(bundle, weights)

            # Fields that were None on the input must stay None.
            for passthrough in (refined.directions, refined.origins, refined.xys):
                self.assertIsNone(passthrough)

            expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray).expand(
                *batch_shape, n_pts_per_ray
            )
            full_expected = expected
            if add_input_samples:
                # The input samples are merged with the new ones and re-sorted.
                full_expected, _ = torch.cat((lengths, expected), dim=-1).sort()
            self.assertClose(refined.lengths, full_expected)

            random_refiner = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=True,
                add_input_samples=add_input_samples,
            )
            lengths_random = random_refiner(bundle, weights).lengths
            self.assertEqual(lengths_random.shape, full_expected.shape)
            if not add_input_samples:
                self.assertGreater(lengths_random.min().item(), 0.5)
                self.assertLess(lengths_random.max().item(), length - 1.5)

            # Check sorted: consecutive samples must be strictly increasing.
            gaps = lengths_random[..., 1:] - lengths_random[..., :-1]
            self.assertTrue((gaps > 0).all())
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import torch
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
SRNHyperNetImplicitFunction,
SRNImplicitFunction,
SRNPixelGenerator,
)
from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer import RayBundle
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
# Leading (batch) dimension of the ray tensors and of the global code used below.
_BATCH_SIZE: int = 3
# Number of rays per batch element (second dimension of the ray tensors).
_N_RAYS: int = 100
# Number of length samples per ray (last dimension of the bundle's `lengths`).
_N_POINTS_ON_RAY: int = 10
class TestSRN(TestCaseMixin, unittest.TestCase):
    """Shape and optimization smoke tests for the SRN implicit functions."""

    def setUp(self) -> None:
        torch.manual_seed(42)
        # NOTE(review): presumably called for the side effect of preparing these
        # configurable classes for direct construction in the tests -- confirm.
        get_default_args(SRNHyperNetImplicitFunction)
        get_default_args(SRNImplicitFunction)

    def test_pixel_generator(self):
        """SRNPixelGenerator can be constructed with its default arguments."""
        SRNPixelGenerator()

    def _get_bundle(self, *, device) -> RayBundle:
        """Return a RayBundle of random origins, directions and lengths on `device`."""
        origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
        bundle = RayBundle(
            lengths=lengths, origins=origins, directions=directions, xys=None
        )
        return bundle

    def _run_optim_steps(self, wrapper, optimizer, bundle, global_code, n_steps=3):
        """Run `n_steps` optimizer steps on the sum of the densities returned by `wrapper`."""
        for _step in range(n_steps):
            optimizer.zero_grad()
            wrapper.bind_args(global_code=global_code)
            rays_densities, _rays_colors = wrapper(bundle)
            wrapper.unbind_args()
            loss = rays_densities.sum()
            loss.backward()
            optimizer.step()

    def test_srn_implicit_function(self):
        """The plain SRN implicit function returns per-point features and no colors."""
        implicit_function = SRNImplicitFunction()
        device = torch.device("cpu")
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(bundle)
        out_features = implicit_function.raymarch_function.out_features
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    def test_srn_hypernet_implicit_function(self):
        """The hypernet variant returns per-point features and no colors."""
        # TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
        latent_dim_hypernet = 39
        hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
        device = torch.device("cuda:0")
        implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args)
        implicit_function.to(device)
        global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(bundle, global_code=global_code)
        out_features = implicit_function.hypernet.out_features
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    def test_srn_hypernet_implicit_function_optim(self):
        """Optimize one hypernet model, then resume optimization on a copy of it."""
        # Test optimization loop, requiring that the cache is properly
        # cleared in new_args_bound
        latent_dim_hypernet = 39
        hyper_args = {"latent_dim_hypernet": latent_dim_hypernet}
        device = torch.device("cuda:0")
        global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
        bundle = self._get_bundle(device=device)

        implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
        implicit_function2 = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
        implicit_function.to(device)
        implicit_function2.to(device)

        wrapper = ImplicitFunctionWrapper(implicit_function)
        optimizer = torch.optim.Adam(implicit_function.parameters())
        self._run_optim_steps(wrapper, optimizer, bundle, global_code)

        # The second wrapper must drive `implicit_function2`: `optimizer2` owns
        # that model's parameters, so wrapping the first model would leave them
        # without gradients and make every `optimizer2.step()` a no-op.
        wrapper2 = ImplicitFunctionWrapper(implicit_function2)
        optimizer2 = torch.optim.Adam(implicit_function2.parameters())
        implicit_function2.load_state_dict(implicit_function.state_dict())
        optimizer2.load_state_dict(optimizer.state_dict())
        self._run_optim_steps(wrapper2, optimizer2, bundle, global_code)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
import unittest
from typing import Dict, List, NamedTuple, Tuple
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.types import FrameAnnotation
class _NT(NamedTuple):
    """Single-field named tuple used to test parsing of dataclasses nested in a NamedTuple."""

    annot: FrameAnnotation
class TestDatasetTypes(unittest.TestCase):
    """Tests for converting between dataset annotation dataclasses and plain dicts."""

    def setUp(self):
        # A fully populated annotation reused by all tests.
        self.entry = FrameAnnotation(
            frame_number=23,
            sequence_name="1",
            frame_timestamp=1.2,
            image=types.ImageAnnotation(path="/tmp/1.jpg", size=(224, 224)),
            mask=types.MaskAnnotation(path="/tmp/1.png", mass=42.0),
            viewpoint=types.ViewpointAnnotation(
                R=(
                    (1, 0, 0),
                    (1, 0, 0),
                    (1, 0, 0),
                ),
                T=(0, 0, 0),
                principal_point=(100, 100),
                focal_length=(200, 200),
            ),
        )

    def test_asdict_rec(self):
        """`_asdict_rec` on a list must agree with `dataclasses.asdict` on its element."""
        first = [dataclasses.asdict(self.entry)]
        second = types._asdict_rec([self.entry])
        self.assertEqual(first, second)

    def test_parsing(self):
        """Test that we handle collections enclosing dataclasses."""
        dct = dataclasses.asdict(self.entry)

        parsed = types._dataclass_from_dict(dct, FrameAnnotation)
        self.assertEqual(parsed, self.entry)

        # namedtuple
        parsed = types._dataclass_from_dict(_NT(dct), _NT)
        self.assertEqual(parsed.annot, self.entry)

        # tuple
        parsed = types._dataclass_from_dict((dct,), Tuple[FrameAnnotation])
        self.assertEqual(parsed, (self.entry,))

        # list
        parsed = types._dataclass_from_dict([dct], List[FrameAnnotation])
        self.assertEqual(parsed, [self.entry])

        # dict
        parsed = types._dataclass_from_dict({"k": dct}, Dict[str, FrameAnnotation])
        self.assertEqual(parsed, {"k": self.entry})

    def test_parsing_vectorized(self):
        """The list-based parser must agree with the scalar one for each container kind."""
        dct = dataclasses.asdict(self.entry)

        self._compare_with_scalar(dct, FrameAnnotation)
        self._compare_with_scalar(_NT(dct), _NT)
        self._compare_with_scalar((dct,), Tuple[FrameAnnotation])
        self._compare_with_scalar([dct], List[FrameAnnotation])
        self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation])

    def _compare_with_scalar(self, obj, typeannot, repeat=3):
        """
        Parse `repeat` copies of `obj` with the vectorized parser and check that
        it returns `repeat` results, each equal to the scalar parser's result.

        Args:
            obj: the raw (dict-based) object to parse.
            typeannot: the type annotation to parse `obj` into.
            repeat: number of copies of `obj` fed to the vectorized parser.
        """
        dict_list = [obj] * repeat
        # Materialize the result so that it can be both measured and iterated.
        vect_output = list(types._dataclass_list_from_dict_list(dict_list, typeannot))
        # One parsed result is expected per input element.
        self.assertEqual(len(vect_output), repeat)
        gt = types._dataclass_from_dict(obj, typeannot)
        self.assertTrue(all(res == gt for res in vect_output))
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import pytorch3d as pt3d
import torch
from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler
from pytorch3d.implicitron.tools.config import expand_args_fields
class TestViewsampling(unittest.TestCase):
    """Tests of ViewSampler against a naive reference and against known outputs."""

    def setUp(self):
        torch.manual_seed(42)
        expand_args_fields(ViewSampler)

    def _init_view_sampler_problem(self, random_masks):
        """
        Builds a small view-sampling problem:
            - 4 source views; the first two belong to 'seq1', the other two to 'seq2'.
            - 3 point sets belonging to 'seq1', 'seq2' and 'seq2' respectively.
            - in every point set the first 50 points project into the source
              views, while the last 50 are moved far away so that they land in
              no projection plane.
            - the n-th source view carries a 7x100x50 feature tensor filled
              with the value `n+1`.
            - the source-view masks are random binary (`random_masks==True`)
              or all ones (`random_masks==False`).
            - the source cameras are spread uniformly over a unit circle in
              the x-z plane and look at (0,0,0).

        Returns:
            The tuple `(pts, camera, feats, masks, seq_id_camera, seq_id_pts)`.
        """
        seq_id_camera = ["seq1", "seq1", "seq2", "seq2"]
        seq_id_pts = ["seq1", "seq2", "seq2"]
        pts_batch, n_pts, n_views = 3, 100, 4
        fdim, H, W = 7, 100, 50

        # points that land into the projection planes of all cameras
        unit_directions = torch.nn.functional.normalize(
            torch.randn(pts_batch, n_pts // 2, 3, device="cuda"),
            dim=-1,
        )
        pts_inside = unit_directions * 0.1

        # move the outside points far above the scene
        pts_outside = pts_inside.clone()
        pts_outside[:, :, 1] += 1e8
        pts = torch.cat([pts_inside, pts_outside], dim=1)

        azimuths = torch.linspace(0, 360, n_views + 1)[:n_views]
        R, T = pt3d.renderer.look_at_view_transform(
            dist=1.0,
            elev=0.0,
            azim=azimuths,
            degrees=True,
            device=pts.device,
        )
        camera = pt3d.renderer.PerspectiveCameras(
            R=R,
            T=T,
            focal_length=R.new_ones(n_views, 2),
            principal_point=R.new_zeros(n_views, 2),
            device=pts.device,
        )

        # One constant value per view: 1, 2, ..., n_views.
        per_view_value = (
            torch.arange(n_views, device=pts.device, dtype=pts.dtype) + 1
        )
        feats = {"feats": per_view_value[:, None, None, None].repeat(1, fdim, H, W)}

        random_field = torch.rand(
            n_views, 1, H, W, device=pts.device, dtype=pts.dtype
        )
        masks = (random_field > 0.5).type_as(R)
        if not random_masks:
            masks[:] = 1.0

        return pts, camera, feats, masks, seq_id_camera, seq_id_pts

    def test_compare_with_naive(self):
        """
        Checks that the efficient ViewSampler module and the naive reference
        implementation produce the same features and masks.
        """
        problem = self._init_view_sampler_problem(True)
        pts, camera, feats, masks, seq_id_camera, seq_id_pts = problem

        for masked_sampling in (True, False):
            feats_sampled_n, masks_sampled_n = _view_sample_naive(
                pts,
                seq_id_pts,
                camera,
                seq_id_camera,
                feats,
                masks,
                masked_sampling,
            )
            # make sure we generate the constructor for ViewSampler
            expand_args_fields(ViewSampler)
            view_sampler = ViewSampler(masked_sampling=masked_sampling)
            feats_sampled, masks_sampled = view_sampler(
                pts=pts,
                seq_id_pts=seq_id_pts,
                camera=camera,
                seq_id_camera=seq_id_camera,
                feats=feats,
                masks=masks,
            )
            for k in feats_sampled.keys():
                feats_match = torch.allclose(feats_sampled[k], feats_sampled_n[k])
                self.assertTrue(feats_match)
            masks_match = torch.allclose(masks_sampled, masks_sampled_n)
            self.assertTrue(masks_match)

    def test_viewsampling(self):
        """
        Builds a view-sampling problem whose outcome is known in advance and
        checks the ViewSampler's output against it.
        """
        problem = self._init_view_sampler_problem(False)
        pts, camera, feats, masks, seq_id_camera, seq_id_pts = problem

        expand_args_fields(ViewSampler)

        pts_batch, n_pts = pts.shape[0], pts.shape[1]
        n_views = camera.R.shape[0]
        feat_dim = feats["feats"].shape[1]
        # The second half of each point set was moved away from all cameras.
        n_pts_away = n_pts // 2

        for masked_sampling in (True, False):
            view_sampler = ViewSampler(masked_sampling=masked_sampling)
            feats_sampled, masks_sampled = view_sampler(
                pts=pts,
                seq_id_pts=seq_id_pts,
                camera=camera,
                seq_id_camera=seq_id_camera,
                feats=feats,
                masks=masks,
            )

            for pts_i in range(pts_batch):
                for view_i in range(n_views):
                    # Start from the answer for points / cameras coming from
                    # different sequences: everything is zero.
                    gt_masks = pts.new_zeros(n_pts, 1)
                    gt_feats = pts.new_zeros(n_pts, feat_dim)
                    if seq_id_pts[pts_i] == seq_id_camera[view_i]:
                        # Visible points pick up the constant value of the view.
                        gt_feats[:n_pts_away] = view_i + 1
                        gt_masks[:n_pts_away] = 1.0
                        if not masked_sampling:
                            # Without masked sampling the far-away points keep
                            # a mask of one as well.
                            gt_masks[n_pts_away:] = 1.0

                    for k in feats_sampled:
                        feats_match = torch.allclose(
                            feats_sampled[k][pts_i, view_i],
                            gt_feats,
                        )
                        self.assertTrue(feats_match)
                    masks_match = torch.allclose(
                        masks_sampled[pts_i, view_i],
                        gt_masks,
                    )
                    self.assertTrue(masks_match)
def _view_sample_naive(
pts,
seq_id_pts,
camera,
seq_id_camera,
feats,
masks,
masked_sampling,
):
"""
A naive implementation of the forward pass of ViewSampler.
Refer to ViewSampler's docstring for description of the arguments.
"""
pts_batch = pts.shape[0]
n_views = camera.R.shape[0]
n_pts = pts.shape[1]
feats_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)]
masks_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)]
for pts_i in range(pts_batch):
for view_i in range(n_views):
if seq_id_pts[pts_i] != seq_id_camera[view_i]:
# points/cameras come from different sequences
feats_sampled_ = {
k: f.new_zeros(n_pts, f.shape[1]) for k, f in feats.items()
}
masks_sampled_ = masks.new_zeros(n_pts, 1)
else:
# same sequence of pts and cameras -> sample
feats_sampled_, masks_sampled_ = _sample_one_view_naive(
camera[view_i],
pts[pts_i],
{k: f[view_i] for k, f in feats.items()},
masks[view_i],
masked_sampling,
sampling_mode="bilinear",
)
feats_sampled[pts_i][view_i] = feats_sampled_
masks_sampled[pts_i][view_i] = masks_sampled_
masks_sampled_cat = torch.stack([torch.stack(m) for m in masks_sampled])
feats_sampled_cat = {}
for k in feats_sampled[0][0].keys():
feats_sampled_cat[k] = torch.stack(
[torch.stack([f_[k] for f_ in f]) for f in feats_sampled]
)
return feats_sampled_cat, masks_sampled_cat
def _sample_one_view_naive(
    camera,
    pts,
    feats,
    masks,
    masked_sampling,
    sampling_mode="bilinear",
):
    """
    Samples the features and the mask of one source view at the projections
    of `pts`.

    Returns:
        feats_sampled: dict with the keys of `feats` holding the sampled features.
        masks_sampled: the sampled mask values when `masked_sampling` is set,
            otherwise a column of ones with one row per point.
    """
    projected = camera.transform_points(pts[None])
    # Drop the last coordinate and add a leading dimension for grid sampling.
    proj_ndc = projected[None, ..., :-1]  # 1 x 1 x n_pts x 2

    feats_sampled = {}
    for name, feat_map in feats.items():
        sampled = pt3d.renderer.ndc_grid_sample(
            feat_map[None], proj_ndc, mode=sampling_mode
        )
        feats_sampled[name] = sampled.permute(0, 3, 1, 2)[0, :, :, 0]  # n_pts x dim

    if masked_sampling:
        sampled_mask = pt3d.renderer.ndc_grid_sample(
            masks[None],
            proj_ndc,
            mode=sampling_mode,
            align_corners=False,
        )
        masks_sampled = sampled_mask[0, 0, 0, :][:, None]
    else:
        masks_sampled = proj_ndc.new_ones(pts.shape[0], 1)

    return feats_sampled, masks_sampled
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment