OpenDAS / pytorch3d

Commit cdd2142d (unverified), authored Mar 21, 2022 by Jeremy Reizenstein, committed by GitHub on Mar 21, 2022.

implicitron v0 (#1133)

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>

Parent: 0e377c68

Changes: 90 files in the full commit; this page shows 10 changed files with 1827 additions and 0 deletions (+1827, -0).
tests/implicitron/test_config.py             +610  -0
tests/implicitron/test_config_use.py          +81  -0
tests/implicitron/test_dataset_visualize.py  +191  -0
tests/implicitron/test_eval_cameras.py        +48  -0
tests/implicitron/test_evaluation.py         +290  -0
tests/implicitron/test_forward_pass.py        +67  -0
tests/implicitron/test_ray_point_refiner.py   +63  -0
tests/implicitron/test_srn.py                +114  -0
tests/implicitron/test_types.py               +93  -0
tests/implicitron/test_viewsampling.py       +270  -0
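The files in this commit exercise implicitron's plug-in configuration system (`pytorch3d.implicitron.tools.config`). As context for reading the diffs, here is a minimal sketch, not part of the commit, of the pattern the tests rely on: a `ReplaceableBase` subclass defines a pluggable slot, `registry.register` adds implementations, and a `Configurable` owner selects and configures one through generated `<field>_class_type` and `<field>_<Impl>_args` keys. The names mirror `test_config.py` below, and the generated keys are the ones its assertions check.

```python
# Minimal sketch of the config pattern the tests below exercise
# (mirrors test_config.py; not part of the commit itself).
from pytorch3d.implicitron.tools.config import (
    Configurable,
    ReplaceableBase,
    get_default_args,
    registry,
    run_auto_creation,
)


class Fruit(ReplaceableBase):  # a pluggable "slot" type
    pass


@registry.register
class Pear(Fruit):  # one registered implementation
    n_pips: int = 13


class Bowl(Configurable):
    the_fruit: Fruit  # instantiated by run_auto_creation

    def __post_init__(self):
        run_auto_creation(self)


args = get_default_args(Bowl)        # DictConfig of defaults
args.the_fruit_class_type = "Pear"   # choose the implementation
args.the_fruit_Pear_args.n_pips = 3  # configure that implementation
bowl = Bowl(**args)
assert isinstance(bowl.the_fruit, Pear) and bowl.the_fruit.n_pips == 3
```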
tests/implicitron/test_config.py (new file, mode 100644)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import List, Optional, Tuple

from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
    Configurable,
    ReplaceableBase,
    _is_actually_dataclass,
    _Registry,
    expand_args_fields,
    get_default_args,
    get_default_args_field,
    registry,
    remove_unused_components,
    run_auto_creation,
)


@dataclass
class Animal(ReplaceableBase):
    pass


class Fruit(ReplaceableBase):
    pass


@registry.register
class Banana(Fruit):
    pips: int
    spots: int
    bananame: str


@registry.register
class Pear(Fruit):
    n_pips: int = 13


class Pineapple(Fruit):
    pass


@registry.register
class Orange(Fruit):
    pass


@registry.register
class Kiwi(Fruit):
    pass


@registry.register
class LargePear(Pear):
    pass


class MainTest(Configurable):
    the_fruit: Fruit
    n_ids: int
    n_reps: int = 8
    the_second_fruit: Fruit

    def create_the_second_fruit(self):
        expand_args_fields(Pineapple)
        self.the_second_fruit = Pineapple()

    def __post_init__(self):
        run_auto_creation(self)


class TestConfig(unittest.TestCase):
    def test_is_actually_dataclass(self):
        @dataclass
        class A:
            pass

        self.assertTrue(_is_actually_dataclass(A))
        self.assertTrue(is_dataclass(A))

        class B(A):
            a: int

        self.assertFalse(_is_actually_dataclass(B))
        self.assertTrue(is_dataclass(B))

    def test_simple_replacement(self):
        struct = get_default_args(MainTest)
        struct.n_ids = 9780
        struct.the_fruit_Pear_args.n_pips = 3
        struct.the_fruit_class_type = "Pear"
        struct.the_second_fruit_class_type = "Pear"
        main = MainTest(**struct)
        self.assertIsInstance(main.the_fruit, Pear)
        self.assertEqual(main.n_reps, 8)
        self.assertEqual(main.n_ids, 9780)
        self.assertEqual(main.the_fruit.n_pips, 3)
        self.assertIsInstance(main.the_second_fruit, Pineapple)

        struct2 = get_default_args(MainTest)
        self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13)
        self.assertEqual(
            MainTest._creation_functions,
            ("create_the_fruit", "create_the_second_fruit"),
        )

    def test_detect_bases(self):
        # testing the _base_class_from_class function
        self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase))
        self.assertIsNone(_Registry._base_class_from_class(MainTest))
        self.assertIs(_Registry._base_class_from_class(Fruit), Fruit)
        self.assertIs(_Registry._base_class_from_class(Pear), Fruit)

        class PricklyPear(Pear):
            pass

        self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit)

    def test_registry_entries(self):
        self.assertIs(registry.get(Fruit, "Banana"), Banana)
        with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
            registry.get(Animal, "Banana")
        with self.assertRaisesRegex(
            ValueError, "PricklyPear has not been registered."
        ):
            registry.get(Fruit, "PricklyPear")
        self.assertIs(registry.get(Pear, "Pear"), Pear)
        self.assertIs(registry.get(Pear, "LargePear"), LargePear)
        with self.assertRaisesRegex(ValueError, "Banana resolves to"):
            registry.get(Pear, "Banana")

        all_fruit = set(registry.get_all(Fruit))
        self.assertIn(Banana, all_fruit)
        self.assertIn(Pear, all_fruit)
        self.assertIn(LargePear, all_fruit)
        self.assertEqual(set(registry.get_all(Pear)), {LargePear})

        @registry.register
        class Apple(Fruit):
            pass

        @registry.register
        class CrabApple(Apple):
            pass

        self.assertEqual(set(registry.get_all(Apple)), {CrabApple})
        self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)

        with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):

            @registry.register
            class NotAFruit:
                pass

    def test_recursion(self):
        class Shape(ReplaceableBase):
            pass

        @registry.register
        class Triangle(Shape):
            a: float = 5.0

        @registry.register
        class Square(Shape):
            a: float = 3.0

        @registry.register
        class LargeShape(Shape):
            inner: Shape

            def __post_init__(self):
                run_auto_creation(self)

        class ShapeContainer(Configurable):
            shape: Shape

        container = ShapeContainer(**get_default_args(ShapeContainer))
        # This is because ShapeContainer is missing __post_init__
        with self.assertRaises(AttributeError):
            container.shape

        class ShapeContainer2(Configurable):
            x: Shape
            x_class_type: str = "LargeShape"

            def __post_init__(self):
                self.x_LargeShape_args.inner_class_type = "Triangle"
                run_auto_creation(self)

        container2_args = get_default_args(ShapeContainer2)
        container2_args.x_LargeShape_args.inner_Triangle_args.a += 10
        self.assertIn("inner_Square_args", container2_args.x_LargeShape_args)
        # We do not perform expansion that would result in an infinite recursion,
        # so this member is not present.
        self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args)
        container2_args.x_LargeShape_args.inner_Square_args.a += 100
        container2 = ShapeContainer2(**container2_args)
        self.assertIsInstance(container2.x, LargeShape)
        self.assertIsInstance(container2.x.inner, Triangle)
        self.assertEqual(container2.x.inner.a, 15.0)

    def test_simpleclass_member(self):
        # Members which are not dataclasses are
        # tolerated. But it would be nice to be able to
        # configure them.
        class Foo:
            def __init__(self, a=1, b=2):
                self.a, self.b = a, b

        @dataclass()
        class Bar:
            aa: int = 9
            bb: int = 9

        class Container(Configurable):
            bar: Bar = Bar()
            # TODO make this work?
            # foo: Foo = Foo()
            fruit: Fruit
            fruit_class_type: str = "Orange"

            def __post_init__(self):
                run_auto_creation(self)

        self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
        container_args = get_default_args(Container)
        container = Container(**container_args)
        self.assertIsInstance(container.fruit, Orange)
        # self.assertIsInstance(container.bar, Bar)

        container_defaulted = Container()
        container_defaulted.fruit_Pear_args.n_pips += 4

        container_args2 = get_default_args(Container)
        container = Container(**container_args2)
        self.assertEqual(container.fruit_Pear_args.n_pips, 13)

    def test_inheritance(self):
        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Orange"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        class LargeFruitBowl(FruitBowl):
            extra_fruit: Fruit
            extra_fruit_class_type: str = "Kiwi"

            def __post_init__(self):
                run_auto_creation(self)

        large_args = get_default_args(LargeFruitBowl)
        self.assertNotIn("extra_fruit", large_args)
        self.assertNotIn("main_fruit", large_args)
        large = LargeFruitBowl(**large_args)
        self.assertIsInstance(large.main_fruit, Orange)
        self.assertIsInstance(large.extra_fruit, Kiwi)

    def test_inheritance2(self):
        # This is a case where a class could contain an instance
        # of a subclass, which is ignored.
        class Parent(ReplaceableBase):
            pass

        class Main(Configurable):
            parent: Parent
            # Note - no __post_init__

        @registry.register
        class Derived(Parent, Main):
            pass

        args = get_default_args(Main)
        # Derived has been ignored in processing Main.
        self.assertCountEqual(args.keys(), ["parent_class_type"])

        main = Main(**args)
        with self.assertRaisesRegex(
            ValueError, "UNDEFAULTED has not been registered."
        ):
            run_auto_creation(main)
        main.parent_class_type = "Derived"
        # Illustrates that a dict works fine instead of a DictConfig.
        main.parent_Derived_args = {}
        with self.assertRaises(AttributeError):
            main.parent
        run_auto_creation(main)
        self.assertIsInstance(main.parent, Derived)

    def test_redefine(self):
        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Grape"

            def __post_init__(self):
                run_auto_creation(self)

        @registry.register
        @dataclass
        class Grape(Fruit):
            large: bool = False

            def get_color(self):
                return "red"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        bowl_args = get_default_args(FruitBowl)

        @registry.register
        @dataclass
        class Grape(Fruit):  # noqa: F811
            large: bool = True

            def get_color(self):
                return "green"

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            bowl = FruitBowl(**bowl_args)
        self.assertIsInstance(bowl.main_fruit, Grape)

        # Redefining the same class won't help with defaults because encoded in args
        self.assertEqual(bowl.main_fruit.large, False)
        # But the override worked.
        self.assertEqual(bowl.main_fruit.get_color(), "green")

        # 2. Try redefining without the dataclass modifier
        # This relies on the fact that default creation processes the class.
        # (otherwise incomprehensible messages)
        @registry.register
        class Grape(Fruit):  # noqa: F811
            large: bool = True

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            bowl = FruitBowl(**bowl_args)

        # 3. Adding a new class doesn't get picked up, because the first
        # get_default_args call has frozen FruitBowl. This is intrinsic to
        # the way dataclass and expand_args_fields work in-place but
        # expand_args_fields is not pure - it depends on the registry.
        @registry.register
        class Fig(Fruit):
            pass

        bowl_args2 = get_default_args(FruitBowl)
        self.assertIn("main_fruit_Grape_args", bowl_args2)
        self.assertNotIn("main_fruit_Fig_args", bowl_args2)

        # TODO Is it possible to make this work?
        # bowl_args2["main_fruit_Fig_args"] = get_default_args(Fig)
        # bowl_args2.main_fruit_class_type = "Fig"
        # bowl2 = FruitBowl(**bowl_args2) <= unexpected argument

        # Note that it is possible to use Fig if you can set
        # bowl2.main_fruit_Fig_args explicitly (not in bowl_args2)
        # before run_auto_creation happens. See test_inheritance2
        # for an example.

    def test_no_replacement(self):
        # Test of Configurables without ReplaceableBase
        class A(Configurable):
            n: int = 9

        class B(Configurable):
            a: A

            def __post_init__(self):
                run_auto_creation(self)

        class C(Configurable):
            b: B

            def __post_init__(self):
                run_auto_creation(self)

        c_args = get_default_args(C)
        c = C(**c_args)
        self.assertIsInstance(c.b.a, A)
        self.assertEqual(c.b.a.n, 9)

    def test_doc(self):
        # The case in the docstring.
        class A(ReplaceableBase):
            k: int = 1

        @registry.register
        class A1(A):
            m: int = 3

        @registry.register
        class A2(A):
            n: str = "2"

        class B(Configurable):
            a: A
            a_class_type: str = "A2"

            def __post_init__(self):
                run_auto_creation(self)

        b_args = get_default_args(B)
        self.assertNotIn("a", b_args)
        b = B(**b_args)
        self.assertEqual(b.a.n, "2")

    def test_raw_types(self):
        @dataclass
        class MyDataclass:
            int_field: int = 0
            none_field: Optional[int] = None
            float_field: float = 9.3
            bool_field: bool = True
            tuple_field: tuple = (3, True, "j")

        class SimpleClass:
            def __init__(self, tuple_member_=(3, 4)):
                self.tuple_member = tuple_member_

            def get_tuple(self):
                return self.tuple_member

        def f(*, a: int = 3, b: str = "kj"):
            self.assertEqual(a, 3)
            self.assertEqual(b, "kj")

        class C(Configurable):
            simple: DictConfig = get_default_args_field(SimpleClass)
            # simple2: SimpleClass2 = SimpleClass2()
            mydata: DictConfig = get_default_args_field(MyDataclass)
            a_tuple: Tuple[float] = (4.0, 3.0)
            f_args: DictConfig = get_default_args_field(f)

        args = get_default_args(C)
        c = C(**args)
        self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])

        mydata = MyDataclass(**c.mydata)
        simple = SimpleClass(**c.simple)

        # OmegaConf converts tuples to ListConfigs (which act like lists).
        self.assertEqual(simple.get_tuple(), [3, 4])
        self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
        self.assertEqual(c.a_tuple, [4.0, 3.0])
        self.assertTrue(isinstance(c.a_tuple, ListConfig))
        self.assertEqual(mydata.tuple_field, (3, True, "j"))
        self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
        f(**c.f_args)

    def test_irrelevant_bases(self):
        class NotADataclass:
            # Like torch.nn.Module, this class contains annotations
            # but is not designed to be dataclass'd.
            # This test ensures that such classes, when inherited from,
            # are not accidentally expand_args_fields'ed.
            a: int = 9
            b: int

        class LeftConfigured(Configurable, NotADataclass):
            left: int = 1

        class RightConfigured(NotADataclass, Configurable):
            right: int = 2

        class Outer(Configurable):
            left: LeftConfigured
            right: RightConfigured

            def __post_init__(self):
                run_auto_creation(self)

        outer = Outer(**get_default_args(Outer))
        self.assertEqual(outer.left.left, 1)
        self.assertEqual(outer.right.right, 2)
        with self.assertRaisesRegex(TypeError, "non-default argument"):
            dataclass(NotADataclass)

    def test_unprocessed(self):
        # behavior of Configurable classes which need processing in __new__,
        class Unprocessed(Configurable):
            a: int = 9

        class UnprocessedReplaceable(ReplaceableBase):
            a: int = 1

        with self.assertWarnsRegex(UserWarning, "must be processed"):
            Unprocessed()
        with self.assertWarnsRegex(UserWarning, "must be processed"):
            UnprocessedReplaceable()

    def test_enum(self):
        # Test that enum values are kept, i.e. that OmegaConf's runtime checks
        # are in use.
        class A(Enum):
            B1 = "b1"
            B2 = "b2"

        class C(Configurable):
            a: A = A.B1

        base = get_default_args(C)
        replaced = OmegaConf.merge(base, {"a": "B2"})
        self.assertEqual(replaced.a, A.B2)
        with self.assertRaises(ValidationError):
            # You can't use a value which is not one of the
            # choices, even if it is the str representation
            # of one of the choices.
            OmegaConf.merge(base, {"a": "b2"})
        remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
        self.assertEqual(remerged.a, A.B1)

    def test_remove_unused_components(self):
        struct = get_default_args(MainTest)
        struct.n_ids = 32
        struct.the_fruit_class_type = "Pear"
        struct.the_second_fruit_class_type = "Banana"
        remove_unused_components(struct)
        expected_keys = [
            "n_ids",
            "n_reps",
            "the_fruit_Pear_args",
            "the_fruit_class_type",
            "the_second_fruit_Banana_args",
            "the_second_fruit_class_type",
        ]
        expected_yaml = textwrap.dedent(
            """\
            n_ids: 32
            n_reps: 8
            the_fruit_class_type: Pear
            the_fruit_Pear_args:
              n_pips: 13
            the_second_fruit_class_type: Banana
            the_second_fruit_Banana_args:
              pips: ???
              spots: ???
              bananame: ???
            """
        )
        self.assertEqual(sorted(struct.keys()), expected_keys)
        # Check that struct is what we expect
        expected = OmegaConf.create(expected_yaml)
        self.assertEqual(struct, expected)
        # Check that we get what we expect when writing to yaml.
        self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)

        main = MainTest(**struct)
        instance_data = OmegaConf.structured(main)
        remove_unused_components(instance_data)
        self.assertEqual(sorted(instance_data.keys()), expected_keys)
        self.assertEqual(instance_data, expected)


@dataclass(eq=False)
class MockDataclass:
    field_no_default: int
    field_primitive_type: int = 42
    field_reference_type: List[int] = field(default_factory=lambda: [])


class MockClassWithInit:  # noqa: B903
    def __init__(
        self,
        field_no_default: int,
        field_primitive_type: int = 42,
        field_reference_type: List[int] = [],  # noqa: B006
    ):
        self.field_no_default = field_no_default
        self.field_primitive_type = field_primitive_type
        self.field_reference_type = field_reference_type


class TestRawClasses(unittest.TestCase):
    def test_get_default_args(self):
        for cls in [MockDataclass, MockClassWithInit]:
            dataclass_defaults = get_default_args(cls)
            inst = cls(field_no_default=0)
            dataclass_defaults.field_no_default = 0
            for name, val in dataclass_defaults.items():
                self.assertTrue(hasattr(inst, name))
                self.assertEqual(val, getattr(inst, name))

    def test_get_default_args_readonly(self):
        for cls in [MockDataclass, MockClassWithInit]:
            dataclass_defaults = get_default_args(cls)
            dataclass_defaults["field_reference_type"].append(13)
            inst = cls(field_no_default=0)
            self.assertEqual(inst.field_reference_type, [])
```
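Because these configs are plain OmegaConf objects, they serialize to YAML and back; `test_enum` and `test_remove_unused_components` above rely on exactly this. A short sketch of the round trip, assuming the definitions from `test_config.py` are in scope:

```python
# Sketch: YAML round trip of a config, using the same OmegaConf calls as the
# tests above (MainTest is the class defined in test_config.py).
from omegaconf import OmegaConf

cfg = get_default_args(MainTest)
cfg.n_ids = 32
cfg.the_fruit_class_type = "Pear"
cfg.the_second_fruit_class_type = "Banana"
text = OmegaConf.to_yaml(cfg, sort_keys=False)           # dump to YAML text
reloaded = OmegaConf.merge(cfg, OmegaConf.create(text))  # parse and merge back
assert reloaded.the_fruit_class_type == "Pear"
```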
tests/implicitron/test_config_use.py (new file, mode 100644)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest

from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
    IdrFeatureField,
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
    NeuralRadianceFieldImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
from pytorch3d.implicitron.models.renderer.multipass_ea import (
    MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
    AngleWeightedIdentityFeatureAggregator,
    AngleWeightedReductionFeatureAggregator,
)
from pytorch3d.implicitron.tools.config import (
    get_default_args,
    remove_unused_components,
)

if os.environ.get("FB_TEST", False):
    from common_testing import get_tests_dir
else:
    from tests.common_testing import get_tests_dir

DATA_DIR = get_tests_dir() / "implicitron/data"
DEBUG: bool = False

# Tests the use of the config system in implicitron


class TestGenericModel(unittest.TestCase):
    def setUp(self):
        self.maxDiff = None

    def test_create_gm(self):
        args = get_default_args(GenericModel)
        gm = GenericModel(**args)
        self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
        self.assertIsInstance(
            gm.feature_aggregator, AngleWeightedReductionFeatureAggregator
        )
        self.assertIsInstance(
            gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
        )
        self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
        self.assertFalse(hasattr(gm, "implicit_function"))
        self.assertFalse(hasattr(gm, "image_feature_extractor"))

    def test_create_gm_overrides(self):
        args = get_default_args(GenericModel)
        args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
        args.implicit_function_class_type = "IdrFeatureField"
        args.renderer_class_type = "LSTMRenderer"
        gm = GenericModel(**args)
        self.assertIsInstance(gm.renderer, LSTMRenderer)
        self.assertIsInstance(
            gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator
        )
        self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
        self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
        self.assertFalse(hasattr(gm, "implicit_function"))

        instance_args = OmegaConf.structured(gm)
        remove_unused_components(instance_args)
        yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
        if DEBUG:
            (DATA_DIR / "overrides.yaml_").write_text(yaml)
        self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())
```
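The final assertion in `test_create_gm_overrides` is a golden-file comparison, and the `DEBUG` flag above is the update hook: with `DEBUG = True` the test writes the freshly generated YAML to `overrides.yaml_` beside the reference file. A sketch of promoting it after a legitimate config change (one plausible workflow, not prescribed by the commit):

```python
# Sketch: after running the test with DEBUG = True and reviewing the diff,
# replace the golden file with the regenerated one.
import os

os.replace(DATA_DIR / "overrides.yaml_", DATA_DIR / "overrides.yaml")
```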
tests/implicitron/test_dataset_visualize.py (new file, mode 100644)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
import os
import unittest

import torch
import torchvision
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
from pytorch3d.vis.plotly_vis import plot_scene
from visdom import Visdom

if os.environ.get("FB_TEST", False):
    from .common_resources import get_skateboard_data
else:
    from common_resources import get_skateboard_data


class TestDatasetVisualize(unittest.TestCase):
    def setUp(self):
        if os.environ.get("INSIDE_RE_WORKER") is not None:
            raise unittest.SkipTest("Visdom not available")
        category = "skateboard"
        stack = contextlib.ExitStack()
        dataset_root, path_manager = stack.enter_context(get_skateboard_data())
        self.addCleanup(stack.close)
        frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
        sequence_file = os.path.join(
            dataset_root, category, "sequence_annotations.jgz"
        )
        self.image_size = 256
        self.datasets = {
            "simple": ImplicitronDataset(
                frame_annotations_file=frame_file,
                sequence_annotations_file=sequence_file,
                dataset_root=dataset_root,
                image_height=self.image_size,
                image_width=self.image_size,
                box_crop=True,
                load_point_clouds=True,
                path_manager=path_manager,
            ),
            "nonsquare": ImplicitronDataset(
                frame_annotations_file=frame_file,
                sequence_annotations_file=sequence_file,
                dataset_root=dataset_root,
                image_height=self.image_size,
                image_width=self.image_size // 2,
                box_crop=True,
                load_point_clouds=True,
                path_manager=path_manager,
            ),
            "nocrop": ImplicitronDataset(
                frame_annotations_file=frame_file,
                sequence_annotations_file=sequence_file,
                dataset_root=dataset_root,
                image_height=self.image_size,
                image_width=self.image_size // 2,
                box_crop=False,
                load_point_clouds=True,
                path_manager=path_manager,
            ),
        }
        self.datasets.update(
            {
                k + "_newndc": _change_annotations_to_new_ndc(dataset)
                for k, dataset in self.datasets.items()
            }
        )
        self.visdom = Visdom()
        if not self.visdom.check_connection():
            print("Visdom server not running! Disabling visdom visualizations.")
            self.visdom = None

    def _render_one_pointcloud(self, point_cloud, cameras, render_size):
        (_image_render, _, _) = render_point_cloud_pytorch3d(
            cameras,
            point_cloud,
            render_size=render_size,
            point_radius=1e-2,
            topk=10,
            bg_color=0.0,
        )
        return _image_render.clamp(0.0, 1.0)

    def test_one(self):
        """Test dataset visualization."""
        for max_frames in (16, -1):
            for load_dataset_point_cloud in (True, False):
                for dataset_key in self.datasets:
                    self._gen_and_render_pointcloud(
                        max_frames, load_dataset_point_cloud, dataset_key
                    )

    def _gen_and_render_pointcloud(
        self, max_frames, load_dataset_point_cloud, dataset_key
    ):
        dataset = self.datasets[dataset_key]
        # load the point cloud of the first sequence
        sequence_show = list(dataset.seq_annots.keys())[0]
        device = torch.device("cuda:0")

        point_cloud, sequence_frame_data = get_implicitron_sequence_pointcloud(
            dataset,
            sequence_name=sequence_show,
            mask_points=True,
            max_frames=max_frames,
            num_workers=10,
            load_dataset_point_cloud=load_dataset_point_cloud,
        )

        # render on gpu
        point_cloud = point_cloud.to(device)
        cameras = sequence_frame_data.camera.to(device)

        # render the point_cloud from the viewpoint of loaded cameras
        images_render = torch.cat(
            [
                self._render_one_pointcloud(
                    point_cloud,
                    cameras[frame_i],
                    (
                        dataset.image_height,
                        dataset.image_width,
                    ),
                )
                for frame_i in range(len(cameras))
            ]
        ).cpu()
        images_gt_and_render = torch.cat(
            [sequence_frame_data.image_rgb, images_render], dim=3
        )

        imfile = os.path.join(
            os.path.split(os.path.abspath(__file__))[0],
            "test_dataset_visualize"
            + f"_max_frames={max_frames}"
            + f"_load_pcl={load_dataset_point_cloud}.png",
        )
        print(f"Exporting image {imfile}.")
        torchvision.utils.save_image(images_gt_and_render, imfile, nrow=2)

        if self.visdom is not None:
            test_name = f"{max_frames}_{load_dataset_point_cloud}_{dataset_key}"
            self.visdom.images(
                images_gt_and_render,
                env="test_dataset_visualize",
                win=f"pcl_renders_{test_name}",
                opts={"title": f"pcl_renders_{test_name}"},
            )
            plotlyplot = plot_scene(
                {
                    "scene_batch": {
                        "cameras": cameras,
                        "point_cloud": point_cloud,
                    }
                },
                camera_scale=1.0,
                pointcloud_max_points=10000,
                pointcloud_marker_size=1.0,
            )
            self.visdom.plotlyplot(
                plotlyplot,
                env="test_dataset_visualize",
                win=f"pcl_{test_name}",
            )


def _change_annotations_to_new_ndc(dataset):
    dataset = copy.deepcopy(dataset)
    for frame in dataset.frame_annots:
        vp = frame["frame_annotation"].viewpoint
        vp.intrinsics_format = "ndc_isotropic"
        # this assumes the focal length to be equal on x and y (ok for a test)
        max_flength = max(vp.focal_length)
        vp.principal_point = (
            vp.principal_point[0] * max_flength / vp.focal_length[0],
            vp.principal_point[1] * max_flength / vp.focal_length[1],
        )
        vp.focal_length = (
            max_flength,
            max_flength,
        )
    return dataset
```
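The helper `_change_annotations_to_new_ndc` keeps the camera unchanged while switching the annotation to the `ndc_isotropic` convention: the focal length becomes isotropic at the larger of the two values, and each principal point coordinate is rescaled by `max_flength / focal_length` on its axis. A small worked check of that arithmetic, with illustrative values:

```python
# Worked example of the rescaling done in _change_annotations_to_new_ndc.
focal_length = (2.0, 1.0)
principal_point = (0.1, 0.2)
max_flength = max(focal_length)  # 2.0
new_pp = (
    principal_point[0] * max_flength / focal_length[0],  # 0.1 * 2.0 / 2.0 = 0.1
    principal_point[1] * max_flength / focal_length[1],  # 0.2 * 2.0 / 1.0 = 0.4
)
assert new_pp == (0.1, 0.4)
assert (max_flength, max_flength) == (2.0, 2.0)
```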
tests/implicitron/test_eval_cameras.py (new file, mode 100644)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest

import torch
from pytorch3d.implicitron.tools.eval_video_trajectory import (
    generate_eval_video_cameras,
)
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
from pytorch3d.transforms import axis_angle_to_matrix

if os.environ.get("FB_TEST", False):
    from common_testing import TestCaseMixin
else:
    from tests.common_testing import TestCaseMixin


class TestEvalCameras(TestCaseMixin, unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)

    def test_circular(self):
        n_train_cameras = 10
        n_test_cameras = 100
        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        amplitude = 0.01
        R_jiggled = torch.bmm(
            R, axis_angle_to_matrix(torch.rand(n_train_cameras, 3) * amplitude)
        )
        cameras_train = PerspectiveCameras(R=R_jiggled, T=T)
        cameras_test = generate_eval_video_cameras(
            cameras_train, trajectory_type="circular_lsq_fit", trajectory_scale=1.0
        )
        positions_test = cameras_test.get_camera_center()
        center = positions_test.mean(0)
        self.assertClose(center, torch.zeros(3), atol=0.1)
        self.assertClose(
            (positions_test - center).norm(dim=[1]),
            torch.ones(n_test_cameras),
            atol=0.1,
        )
```
tests/implicitron/test_evaluation.py (new file, mode 100644)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
import dataclasses
import math
import os
import unittest

import lpips
import torch
from pytorch3d.implicitron.dataset.implicitron_dataset import (
    FrameData,
    ImplicitronDataset,
)
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_

if os.environ.get("FB_TEST", False):
    from .common_resources import get_skateboard_data, provide_lpips_vgg
else:
    from common_resources import get_skateboard_data, provide_lpips_vgg


class TestEvaluation(unittest.TestCase):
    def setUp(self):
        # initialize evaluation dataset/dataloader
        torch.manual_seed(42)

        stack = contextlib.ExitStack()
        dataset_root, path_manager = stack.enter_context(get_skateboard_data())
        self.addCleanup(stack.close)

        category = "skateboard"
        frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
        sequence_file = os.path.join(
            dataset_root, category, "sequence_annotations.jgz"
        )
        self.image_size = 256
        self.dataset = ImplicitronDataset(
            frame_annotations_file=frame_file,
            sequence_annotations_file=sequence_file,
            dataset_root=dataset_root,
            image_height=self.image_size,
            image_width=self.image_size,
            box_crop=True,
            path_manager=path_manager,
        )
        self.bg_color = 0.0

        # init the lpips model for eval
        provide_lpips_vgg()
        self.lpips_model = lpips.LPIPS(net="vgg")

    def test_eval_depth(self):
        """
        Check that eval_depth correctly masks errors and that, for get_best_scale=True,
        the error with scaled prediction equals the error without scaling the
        predicted depth. Finally, test that the error values are as expected
        for prediction and gt differing by a constant offset.
        """
        gt = (torch.randn(10, 1, 300, 400, device="cuda") * 5.0).clamp(0.0)
        mask = (torch.rand_like(gt) > 0.5).type_as(gt)

        for diff in 10 ** torch.linspace(-5, 0, 6):
            for crop in (0, 5):
                pred = gt + (torch.rand_like(gt) - 0.5) * 2 * diff

                # scaled prediction test
                mse_depth, abs_depth = eval_depth(
                    pred,
                    gt,
                    crop=crop,
                    mask=mask,
                    get_best_scale=True,
                )
                mse_depth_scale, abs_depth_scale = eval_depth(
                    pred * 10.0,
                    gt,
                    crop=crop,
                    mask=mask,
                    get_best_scale=True,
                )
                self.assertAlmostEqual(
                    float(mse_depth.sum()), float(mse_depth_scale.sum()), delta=1e-4
                )
                self.assertAlmostEqual(
                    float(abs_depth.sum()), float(abs_depth_scale.sum()), delta=1e-4
                )

                # error masking test
                pred_masked_err = gt + (torch.rand_like(gt) + diff) * (1 - mask)
                mse_depth_masked, abs_depth_masked = eval_depth(
                    pred_masked_err,
                    gt,
                    crop=crop,
                    mask=mask,
                    get_best_scale=True,
                )
                self.assertAlmostEqual(
                    float(mse_depth_masked.sum()), float(0.0), delta=1e-4
                )
                self.assertAlmostEqual(
                    float(abs_depth_masked.sum()), float(0.0), delta=1e-4
                )
                mse_depth_unmasked, abs_depth_unmasked = eval_depth(
                    pred_masked_err,
                    gt,
                    crop=crop,
                    mask=1 - mask,
                    get_best_scale=True,
                )
                self.assertGreater(
                    float(mse_depth_unmasked.sum()),
                    float(diff ** 2),
                )
                self.assertGreater(
                    float(abs_depth_unmasked.sum()),
                    float(diff),
                )

                # tests with constant error
                pred_fix_diff = gt + diff * mask
                for _mask_gt in (mask, None):
                    mse_depth_fix_diff, abs_depth_fix_diff = eval_depth(
                        pred_fix_diff,
                        gt,
                        crop=crop,
                        mask=_mask_gt,
                        get_best_scale=False,
                    )
                    if _mask_gt is not None:
                        expected_err_abs = diff
                        expected_err_mse = diff ** 2
                    else:
                        err_mask = (gt > 0.0).float() * mask
                        if crop > 0:
                            err_mask = err_mask[:, :, crop:-crop, crop:-crop]
                            gt_cropped = gt[:, :, crop:-crop, crop:-crop]
                        else:
                            gt_cropped = gt
                        gt_mass = (gt_cropped > 0.0).float().sum(dim=(1, 2, 3))
                        expected_err_abs = (
                            diff * err_mask.sum(dim=(1, 2, 3)) / (gt_mass)
                        )
                        expected_err_mse = diff * expected_err_abs
                    self.assertTrue(
                        torch.allclose(
                            abs_depth_fix_diff,
                            expected_err_abs * torch.ones_like(abs_depth_fix_diff),
                            atol=1e-4,
                        )
                    )
                    self.assertTrue(
                        torch.allclose(
                            mse_depth_fix_diff,
                            expected_err_mse * torch.ones_like(mse_depth_fix_diff),
                            atol=1e-4,
                        )
                    )

    def test_psnr(self):
        """
        Compare against opencv and check that the psnr is above
        the minimum possible value.
        """
        import cv2

        im1 = torch.rand(100, 3, 256, 256).cuda()
        im1_uint8 = (im1 * 255).to(torch.uint8)
        im1_rounded = im1_uint8.float() / 255
        for max_diff in 10 ** torch.linspace(-5, 0, 6):
            im2 = im1 + (torch.rand_like(im1) - 0.5) * 2 * max_diff
            im2 = im2.clamp(0.0, 1.0)
            im2_uint8 = (im2 * 255).to(torch.uint8)
            im2_rounded = im2_uint8.float() / 255
            # check that our psnr matches the output of opencv
            psnr = calc_psnr(im1_rounded, im2_rounded)
            # some versions of cv2 can only take uint8 input
            psnr_cv2 = cv2.PSNR(
                im1_uint8.cpu().numpy(),
                im2_uint8.cpu().numpy(),
            )
            self.assertAlmostEqual(float(psnr), float(psnr_cv2), delta=1e-4)

            # check that all PSNRs are bigger than the minimum possible PSNR
            max_mse = max_diff ** 2
            min_psnr = 10 * math.log10(1.0 / max_mse)
            for _im1, _im2 in zip(im1, im2):
                _psnr = calc_psnr(_im1, _im2)
                self.assertGreaterEqual(float(_psnr) + 1e-6, min_psnr)

    def _one_sequence_test(
        self,
        seq_dataset,
        n_batches=2,
        min_batch_size=5,
        max_batch_size=10,
    ):
        # form a list of random batches
        batch_indices = []
        for _ in range(n_batches):
            batch_size = torch.randint(
                low=min_batch_size, high=max_batch_size, size=(1,)
            )
            batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])

        loader = torch.utils.data.DataLoader(
            seq_dataset,
            # batch_size=1,
            shuffle=False,
            batch_sampler=batch_indices,
            collate_fn=FrameData.collate,
        )

        model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
        model.cuda()
        self.lpips_model.cuda()

        for frame_data in loader:
            self.assertIsNone(frame_data.frame_type)
            self.assertIsNotNone(frame_data.image_rgb)
            # override the frame_type
            frame_data.frame_type = [
                "train_unseen",
                *(["train_known"] * (len(frame_data.image_rgb) - 1)),
            ]

            # move frame_data to gpu
            frame_data = dataclass_to_cuda_(frame_data)
            preds = model(**dataclasses.asdict(frame_data))
            nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
            eval_result = eval_batch(
                frame_data,
                nvs_prediction,
                bg_color=self.bg_color,
                lpips_model=self.lpips_model,
            )

            # Make a terribly bad NVS prediction and check that this is worse
            # than the DBIR prediction.
            nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
            nvs_prediction_bad.depth_render += (
                torch.randn_like(nvs_prediction.depth_render) * 100.0
            )
            nvs_prediction_bad.image_render += (
                torch.randn_like(nvs_prediction.image_render) * 100.0
            )
            nvs_prediction_bad.mask_render = (
                torch.randn_like(nvs_prediction.mask_render) > 0.0
            ).float()
            eval_result_bad = eval_batch(
                frame_data,
                nvs_prediction_bad,
                bg_color=self.bg_color,
                lpips_model=self.lpips_model,
            )

            lower_better = {
                "psnr": False,
                "psnr_fg": False,
                "depth_abs_fg": True,
                "iou": False,
                "rgb_l1": True,
                "rgb_l1_fg": True,
            }

            for metric in lower_better.keys():
                m_better = eval_result[metric]
                m_worse = eval_result_bad[metric]
                if m_better != m_better or m_worse != m_worse:
                    continue  # metric is missing, i.e. NaN
                _assert = (
                    self.assertLessEqual
                    if lower_better[metric]
                    else self.assertGreaterEqual
                )
                _assert(m_better, m_worse)

    def test_full_eval(self, n_sequences=5):
        """Test evaluation."""
        for _, idx in list(self.dataset.seq_to_idx.items())[:n_sequences]:
            seq_dataset = torch.utils.data.Subset(self.dataset, idx)
            self._one_sequence_test(seq_dataset)
```
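The lower bound in `test_psnr` follows from the standard definition of PSNR for images in [0, 1]: PSNR = 10 * log10(1 / MSE), and a per-pixel perturbation of at most `max_diff` caps the MSE at `max_diff ** 2`. A self-contained check of that reasoning, independent of implicitron's `calc_psnr`:

```python
# Self-contained verification of the PSNR bound used in test_psnr.
import math

import torch


def psnr(a, b):
    # standard PSNR for signals with peak value 1.0
    mse = ((a - b) ** 2).mean().item()
    return 10 * math.log10(1.0 / mse)


im1 = torch.rand(3, 64, 64)
max_diff = 0.1
im2 = (im1 + (torch.rand_like(im1) - 0.5) * 2 * max_diff).clamp(0.0, 1.0)
# |im2 - im1| <= max_diff everywhere, so MSE <= max_diff**2 and
# PSNR >= 10 * log10(1 / max_diff**2).
assert psnr(im1, im2) >= 10 * math.log10(1.0 / max_diff ** 2)
```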
tests/implicitron/test_forward_pass.py (new file, mode 100644)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform


class TestGenericModel(unittest.TestCase):
    def test_gm(self):
        # Simple test of a forward pass of the default GenericModel.
        device = torch.device("cuda:1")
        expand_args_fields(GenericModel)
        model = GenericModel()
        model.to(device)

        n_train_cameras = 2
        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        cameras = PerspectiveCameras(R=R, T=T, device=device)

        # TODO: make these default to None?
        defaulted_args = {
            "fg_probability": None,
            "depth_map": None,
            "mask_crop": None,
            "sequence_name": None,
        }

        with self.assertWarnsRegex(UserWarning, "No main objective found"):
            model(
                camera=cameras,
                evaluation_mode=EvaluationMode.TRAINING,
                **defaulted_args,
                image_rgb=None,
            )
        target_image_rgb = torch.rand(
            (n_train_cameras, 3, model.render_image_height, model.render_image_width),
            device=device,
        )
        train_preds = model(
            camera=cameras,
            evaluation_mode=EvaluationMode.TRAINING,
            image_rgb=target_image_rgb,
            **defaulted_args,
        )
        self.assertGreater(train_preds["objective"].item(), 0)

        model.eval()
        with torch.no_grad():
            # TODO: perhaps this warning should be skipped in eval mode?
            with self.assertWarnsRegex(UserWarning, "No main objective found"):
                eval_preds = model(
                    camera=cameras[0],
                    **defaulted_args,
                    image_rgb=None,
                )
        self.assertEqual(
            eval_preds["images_render"].shape,
            (1, 3, model.render_image_height, model.render_image_width),
        )
```
tests/implicitron/test_ray_point_refiner.py (new file, mode 100644)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest

import torch
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
from pytorch3d.renderer import RayBundle

if os.environ.get("FB_TEST", False):
    from common_testing import TestCaseMixin
else:
    from tests.common_testing import TestCaseMixin


class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
    def test_simple(self):
        length = 15
        n_pts_per_ray = 10

        for add_input_samples in [False, True]:
            ray_point_refiner = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=False,
                add_input_samples=add_input_samples,
            )

            lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
            bundle = RayBundle(
                lengths=lengths, origins=None, directions=None, xys=None
            )
            weights = torch.ones(3, 25, length)
            refined = ray_point_refiner(bundle, weights)

            self.assertIsNone(refined.directions)
            self.assertIsNone(refined.origins)
            self.assertIsNone(refined.xys)
            expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray)
            expected = expected.expand(3, 25, n_pts_per_ray)
            if add_input_samples:
                full_expected = torch.cat((lengths, expected), dim=-1).sort()[0]
            else:
                full_expected = expected
            self.assertClose(refined.lengths, full_expected)

            ray_point_refiner_random = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=True,
                add_input_samples=add_input_samples,
            )
            refined_random = ray_point_refiner_random(bundle, weights)
            lengths_random = refined_random.lengths
            self.assertEqual(lengths_random.shape, full_expected.shape)
            if not add_input_samples:
                self.assertGreater(lengths_random.min().item(), 0.5)
                self.assertLess(lengths_random.max().item(), length - 1.5)
            # Check sorted
            self.assertTrue(
                (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
            )
```
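The expected depths in `test_simple` can be derived by hand: with uniform weights over input depths 0 to 14, the probability mass is spread evenly over the 14 unit-width bins, so deterministic re-sampling of 10 points plausibly spaces them evenly from the first bin midpoint (0.5) to the last (`length - 1.5 = 13.5`), which is exactly the `linspace` the test asserts. A quick numeric confirmation of that derivation (this checks the arithmetic, not the refiner itself):

```python
# Numeric check of the expected refined depths asserted in test_simple.
import torch

length, n_pts_per_ray = 15, 10
expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray)
assert expected[0].item() == 0.5 and expected[-1].item() == 13.5
# consecutive samples are (13.5 - 0.5) / 9 apart
step = (13.5 - 0.5) / (n_pts_per_ray - 1)
assert torch.allclose(expected[1:] - expected[:-1], torch.full((9,), step))
```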
tests/implicitron/test_srn.py (new file, mode 100644)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest

import torch
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
    SRNHyperNetImplicitFunction,
    SRNImplicitFunction,
    SRNPixelGenerator,
)
from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer import RayBundle

if os.environ.get("FB_TEST", False):
    from common_testing import TestCaseMixin
else:
    from tests.common_testing import TestCaseMixin

_BATCH_SIZE: int = 3
_N_RAYS: int = 100
_N_POINTS_ON_RAY: int = 10


class TestSRN(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(42)
        get_default_args(SRNHyperNetImplicitFunction)
        get_default_args(SRNImplicitFunction)

    def test_pixel_generator(self):
        SRNPixelGenerator()

    def _get_bundle(self, *, device) -> RayBundle:
        origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
        bundle = RayBundle(
            lengths=lengths, origins=origins, directions=directions, xys=None
        )
        return bundle

    def test_srn_implicit_function(self):
        implicit_function = SRNImplicitFunction()
        device = torch.device("cpu")
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(bundle)
        out_features = implicit_function.raymarch_function.out_features
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    def test_srn_hypernet_implicit_function(self):
        # TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
        latent_dim_hypernet = 39
        hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
        device = torch.device("cuda:0")
        implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args)
        implicit_function.to(device)
        global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(
            bundle, global_code=global_code
        )
        out_features = implicit_function.hypernet.out_features
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    def test_srn_hypernet_implicit_function_optim(self):
        # Test optimization loop, requiring that the cache is properly
        # cleared in new_args_bound
        latent_dim_hypernet = 39
        hyper_args = {"latent_dim_hypernet": latent_dim_hypernet}
        device = torch.device("cuda:0")
        global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
        bundle = self._get_bundle(device=device)

        implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
        implicit_function2 = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
        implicit_function.to(device)
        implicit_function2.to(device)

        wrapper = ImplicitFunctionWrapper(implicit_function)
        optimizer = torch.optim.Adam(implicit_function.parameters())
        for _step in range(3):
            optimizer.zero_grad()
            wrapper.bind_args(global_code=global_code)
            rays_densities, _rays_colors = wrapper(bundle)
            wrapper.unbind_args()
            loss = rays_densities.sum()
            loss.backward()
            optimizer.step()

        wrapper2 = ImplicitFunctionWrapper(implicit_function)
        optimizer2 = torch.optim.Adam(implicit_function2.parameters())
        implicit_function2.load_state_dict(implicit_function.state_dict())
        optimizer2.load_state_dict(optimizer.state_dict())
        for _step in range(3):
            optimizer2.zero_grad()
            wrapper2.bind_args(global_code=global_code)
            rays_densities, _rays_colors = wrapper2(bundle)
            wrapper2.unbind_args()
            loss = rays_densities.sum()
            loss.backward()
            optimizer2.step()
```
tests/implicitron/test_types.py (new file, mode 100644)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
import unittest
from typing import Dict, List, NamedTuple, Tuple

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.types import FrameAnnotation


class _NT(NamedTuple):
    annot: FrameAnnotation


class TestDatasetTypes(unittest.TestCase):
    def setUp(self):
        self.entry = FrameAnnotation(
            frame_number=23,
            sequence_name="1",
            frame_timestamp=1.2,
            image=types.ImageAnnotation(path="/tmp/1.jpg", size=(224, 224)),
            mask=types.MaskAnnotation(path="/tmp/1.png", mass=42.0),
            viewpoint=types.ViewpointAnnotation(
                R=(
                    (1, 0, 0),
                    (1, 0, 0),
                    (1, 0, 0),
                ),
                T=(0, 0, 0),
                principal_point=(100, 100),
                focal_length=(200, 200),
            ),
        )

    def test_asdict_rec(self):
        first = [dataclasses.asdict(self.entry)]
        second = types._asdict_rec([self.entry])
        self.assertEqual(first, second)

    def test_parsing(self):
        """Test that we handle collections enclosing dataclasses."""
        dct = dataclasses.asdict(self.entry)

        parsed = types._dataclass_from_dict(dct, FrameAnnotation)
        self.assertEqual(parsed, self.entry)

        # namedtuple
        parsed = types._dataclass_from_dict(_NT(dct), _NT)
        self.assertEqual(parsed.annot, self.entry)

        # tuple
        parsed = types._dataclass_from_dict((dct,), Tuple[FrameAnnotation])
        self.assertEqual(parsed, (self.entry,))

        # list
        parsed = types._dataclass_from_dict(
            [
                dct,
            ],
            List[FrameAnnotation],
        )
        self.assertEqual(
            parsed,
            [
                self.entry,
            ],
        )

        # dict
        parsed = types._dataclass_from_dict({"k": dct}, Dict[str, FrameAnnotation])
        self.assertEqual(parsed, {"k": self.entry})

    def test_parsing_vectorized(self):
        dct = dataclasses.asdict(self.entry)
        self._compare_with_scalar(dct, FrameAnnotation)
        self._compare_with_scalar(_NT(dct), _NT)
        self._compare_with_scalar((dct,), Tuple[FrameAnnotation])
        self._compare_with_scalar([dct], List[FrameAnnotation])
        self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation])

    def _compare_with_scalar(self, obj, typeannot, repeat=3):
        input = [obj] * 3
        vect_output = types._dataclass_list_from_dict_list(input, typeannot)
        self.assertEqual(len(input), repeat)
        gt = types._dataclass_from_dict(obj, typeannot)
        self.assertTrue(all(res == gt for res in vect_output))
```
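The private helpers tested above (`types._dataclass_from_dict` and its vectorized counterpart) automate the recursion that `dataclasses.asdict` performs in the other direction. A standalone illustration of the round-trip property, done by hand for a two-level case using only the standard library (hypothetical classes, not implicitron code):

```python
# Hand-rolled version of the round trip that types._dataclass_from_dict
# automates: asdict flattens nested dataclasses to dicts; rebuilding
# recurses on the field type annotations.
import dataclasses
from typing import Tuple


@dataclasses.dataclass
class Image:
    path: str
    size: Tuple[int, int]


@dataclasses.dataclass
class Frame:
    frame_number: int
    image: Image


frame = Frame(frame_number=23, image=Image(path="/tmp/1.jpg", size=(224, 224)))
dct = dataclasses.asdict(frame)  # {'frame_number': 23, 'image': {...}}
rebuilt = Frame(frame_number=dct["frame_number"], image=Image(**dct["image"]))
assert rebuilt == frame
```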
tests/implicitron/test_viewsampling.py (new file, mode 100644)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import pytorch3d as pt3d
import torch
from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler
from pytorch3d.implicitron.tools.config import expand_args_fields


class TestViewsampling(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)
        expand_args_fields(ViewSampler)

    def _init_view_sampler_problem(self, random_masks):
        """
        Generates a view-sampling problem:
        - 4 source views, 1st/2nd from the first sequence 'seq1', the rest from 'seq2'
        - 3 sets of 3D points from sequences 'seq1', 'seq2', 'seq2' respectively.
        - first 50 points in each batch correctly project to the source views,
          while the remaining 50 do not land in any projection plane.
        - each source view is labeled with image feature tensors of shape 7x100x50,
          where all elements of the n-th tensor are set to `n+1`.
        - the elements of the source view masks are either set to random binary number
          (if `random_masks==True`), or all set to 1 (`random_masks==False`).
        - the source view cameras are uniformly distributed on a unit circle
          in the x-z plane and look at (0,0,0).
        """
        seq_id_camera = ["seq1", "seq1", "seq2", "seq2"]
        seq_id_pts = ["seq1", "seq2", "seq2"]
        pts_batch = 3
        n_pts = 100
        n_views = 4
        fdim = 7
        H = 100
        W = 50
        # points that land into the projection planes of all cameras
        pts_inside = (
            torch.nn.functional.normalize(
                torch.randn(pts_batch, n_pts // 2, 3, device="cuda"),
                dim=-1,
            )
            * 0.1
        )
        # move the outside points far above the scene
        pts_outside = pts_inside.clone()
        pts_outside[:, :, 1] += 1e8
        pts = torch.cat([pts_inside, pts_outside], dim=1)

        R, T = pt3d.renderer.look_at_view_transform(
            dist=1.0,
            elev=0.0,
            azim=torch.linspace(0, 360, n_views + 1)[:n_views],
            degrees=True,
            device=pts.device,
        )
        focal_length = R.new_ones(n_views, 2)
        principal_point = R.new_zeros(n_views, 2)
        camera = pt3d.renderer.PerspectiveCameras(
            R=R,
            T=T,
            focal_length=focal_length,
            principal_point=principal_point,
            device=pts.device,
        )

        feats_map = torch.arange(n_views, device=pts.device, dtype=pts.dtype) + 1
        feats = {"feats": feats_map[:, None, None, None].repeat(1, fdim, H, W)}

        masks = (
            torch.rand(n_views, 1, H, W, device=pts.device, dtype=pts.dtype) > 0.5
        ).type_as(R)
        if not random_masks:
            masks[:] = 1.0

        return pts, camera, feats, masks, seq_id_camera, seq_id_pts

    def test_compare_with_naive(self):
        """
        Compares the outputs of the efficient ViewSampler module with a
        naive implementation.
        """
        (
            pts,
            camera,
            feats,
            masks,
            seq_id_camera,
            seq_id_pts,
        ) = self._init_view_sampler_problem(True)

        for masked_sampling in (True, False):
            feats_sampled_n, masks_sampled_n = _view_sample_naive(
                pts,
                seq_id_pts,
                camera,
                seq_id_camera,
                feats,
                masks,
                masked_sampling,
            )

            # make sure we generate the constructor for ViewSampler
            expand_args_fields(ViewSampler)
            view_sampler = ViewSampler(masked_sampling=masked_sampling)
            feats_sampled, masks_sampled = view_sampler(
                pts=pts,
                seq_id_pts=seq_id_pts,
                camera=camera,
                seq_id_camera=seq_id_camera,
                feats=feats,
                masks=masks,
            )

            for k in feats_sampled.keys():
                self.assertTrue(torch.allclose(feats_sampled[k], feats_sampled_n[k]))
            self.assertTrue(torch.allclose(masks_sampled, masks_sampled_n))

    def test_viewsampling(self):
        """
        Generates a viewsampling problem with predictable outcome, and compares
        the ViewSampler's output to the expected result.
        """
        (
            pts,
            camera,
            feats,
            masks,
            seq_id_camera,
            seq_id_pts,
        ) = self._init_view_sampler_problem(False)

        expand_args_fields(ViewSampler)

        for masked_sampling in (True, False):
            view_sampler = ViewSampler(masked_sampling=masked_sampling)
            feats_sampled, masks_sampled = view_sampler(
                pts=pts,
                seq_id_pts=seq_id_pts,
                camera=camera,
                seq_id_camera=seq_id_camera,
                feats=feats,
                masks=masks,
            )

            n_views = camera.R.shape[0]
            n_pts = pts.shape[1]
            feat_dim = feats["feats"].shape[1]
            pts_batch = pts.shape[0]
            n_pts_away = n_pts // 2

            for pts_i in range(pts_batch):
                for view_i in range(n_views):
                    if seq_id_pts[pts_i] != seq_id_camera[view_i]:
                        # points / cameras come from different sequences
                        gt_masks = pts.new_zeros(n_pts, 1)
                        gt_feats = pts.new_zeros(n_pts, feat_dim)
                    else:
                        gt_masks = pts.new_ones(n_pts, 1)
                        gt_feats = pts.new_ones(n_pts, feat_dim) * (view_i + 1)
                        gt_feats[n_pts_away:] = 0.0
                        if masked_sampling:
                            gt_masks[n_pts_away:] = 0.0
                    for k in feats_sampled:
                        self.assertTrue(
                            torch.allclose(
                                feats_sampled[k][pts_i, view_i],
                                gt_feats,
                            )
                        )
                    self.assertTrue(
                        torch.allclose(
                            masks_sampled[pts_i, view_i],
                            gt_masks,
                        )
                    )


def _view_sample_naive(
    pts,
    seq_id_pts,
    camera,
    seq_id_camera,
    feats,
    masks,
    masked_sampling,
):
    """
    A naive implementation of the forward pass of ViewSampler.
    Refer to ViewSampler's docstring for description of the arguments.
    """
    pts_batch = pts.shape[0]
    n_views = camera.R.shape[0]
    n_pts = pts.shape[1]

    feats_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)]
    masks_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)]

    for pts_i in range(pts_batch):
        for view_i in range(n_views):
            if seq_id_pts[pts_i] != seq_id_camera[view_i]:
                # points/cameras come from different sequences
                feats_sampled_ = {
                    k: f.new_zeros(n_pts, f.shape[1]) for k, f in feats.items()
                }
                masks_sampled_ = masks.new_zeros(n_pts, 1)
            else:
                # same sequence of pts and cameras -> sample
                feats_sampled_, masks_sampled_ = _sample_one_view_naive(
                    camera[view_i],
                    pts[pts_i],
                    {k: f[view_i] for k, f in feats.items()},
                    masks[view_i],
                    masked_sampling,
                    sampling_mode="bilinear",
                )
            feats_sampled[pts_i][view_i] = feats_sampled_
            masks_sampled[pts_i][view_i] = masks_sampled_

    masks_sampled_cat = torch.stack([torch.stack(m) for m in masks_sampled])
    feats_sampled_cat = {}
    for k in feats_sampled[0][0].keys():
        feats_sampled_cat[k] = torch.stack(
            [torch.stack([f_[k] for f_ in f]) for f in feats_sampled]
        )
    return feats_sampled_cat, masks_sampled_cat


def _sample_one_view_naive(
    camera,
    pts,
    feats,
    masks,
    masked_sampling,
    sampling_mode="bilinear",
):
    """
    Sample a single source view.
    """
    proj_ndc = camera.transform_points(pts[None])[None, ..., :-1]  # 1 x 1 x n_pts x 2
    feats_sampled = {
        k: pt3d.renderer.ndc_grid_sample(f[None], proj_ndc, mode=sampling_mode)
        .permute(0, 3, 1, 2)[0, :, :, 0]
        for k, f in feats.items()
    }  # n_pts x dim
    if not masked_sampling:
        n_pts = pts.shape[0]
        masks_sampled = proj_ndc.new_ones(n_pts, 1)
    else:
        masks_sampled = pt3d.renderer.ndc_grid_sample(
            masks[None],
            proj_ndc,
            mode=sampling_mode,
            align_corners=False,
        )[0, 0, 0, :][:, None]
    return feats_sampled, masks_sampled
```