Commit 879495d3 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

omegaconf 2.2.2 compatibility

Summary: OmegaConf 2.2.2 doesn't like heterogenous tuples or Sequence or Set members. Workaround this.

Reviewed By: shapovalov

Differential Revision: D37278736

fbshipit-source-id: 123e6657947f5b27514910e4074c92086a457a2a
parent 5c1ca757
......@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Sequence
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
......@@ -86,7 +86,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
num_workers: int = 0
dataset_len: int = 1000
dataset_len_val: int = 1
images_per_seq_options: Sequence[int] = (2,)
images_per_seq_options: Tuple[int, ...] = (2,)
sample_consecutive_frames: bool = False
consecutive_frames_max_gap: int = 0
consecutive_frames_max_gap_seconds: float = 0.1
......
......@@ -122,9 +122,9 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
subsets: Optional[List[str]] = None
limit_to: int = 0
limit_sequences_to: int = 0
pick_sequence: Sequence[str] = ()
exclude_sequence: Sequence[str] = ()
limit_category_to: Sequence[int] = ()
pick_sequence: Tuple[str, ...] = ()
exclude_sequence: Tuple[str, ...] = ()
limit_category_to: Tuple[int, ...] = ()
dataset_root: str = ""
load_images: bool = True
load_depths: bool = True
......
......@@ -7,7 +7,7 @@
import json
import os
from typing import Dict, List, Sequence, Tuple, Type
from typing import Dict, List, Tuple, Type
from omegaconf import DictConfig, open_dict
from pytorch3d.implicitron.tools.config import (
......@@ -98,7 +98,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
dataset_root: str = _CO3D_DATASET_ROOT
n_frames_per_sequence: int = -1
test_on_train: bool = False
restrict_sequence_name: Sequence[str] = ()
restrict_sequence_name: Tuple[str, ...] = ()
test_restrict_sequence_id: int = -1
assert_single_seq: bool = False
only_test_set: bool = False
......
......@@ -3,7 +3,7 @@
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import math
from typing import Sequence
from typing import Tuple
import torch
from pytorch3d.implicitron.tools.config import registry
......@@ -53,10 +53,10 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
feature_vector_size: int = 3
d_in: int = 3
d_out: int = 1
dims: Sequence[int] = (512, 512, 512, 512, 512, 512, 512, 512)
dims: Tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512, 512)
geometric_init: bool = True
bias: float = 1.0
skip_in: Sequence[int] = ()
skip_in: Tuple[int, ...] = ()
weight_norm: bool = True
n_harmonic_functions_xyz: int = 0
pooled_feature_dim: int = 0
......
......@@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
......@@ -176,7 +176,7 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
the stack of source-view-specific features to a single feature.
"""
reduction_functions: Sequence[ReductionFunction] = (
reduction_functions: Tuple[ReductionFunction, ...] = (
ReductionFunction.AVG,
ReductionFunction.STD,
)
......@@ -269,7 +269,7 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
used when calculating the angle-based aggregation weights.
"""
reduction_functions: Sequence[ReductionFunction] = (
reduction_functions: Tuple[ReductionFunction, ...] = (
ReductionFunction.AVG,
ReductionFunction.STD,
)
......
......@@ -9,7 +9,7 @@ import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
......@@ -484,16 +484,14 @@ class TestConfig(unittest.TestCase):
none_field: Optional[int] = None
float_field: float = 9.3
bool_field: bool = True
tuple_field: tuple = (3, True, "j")
tuple_field: Tuple[int, ...] = (3,)
class SimpleClass:
def __init__(
self,
tuple_member_: Tuple[int, int] = (3, 4),
set_member_: Set[int] = {2}, # noqa
):
self.tuple_member = tuple_member_
self.set_member = set_member_
def get_tuple(self):
return self.tuple_member
......@@ -524,11 +522,9 @@ class TestConfig(unittest.TestCase):
self.assertEqual(simple.get_tuple(), [3, 4])
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
# get_default_args converts sets to ListConfigs (which act like lists).
self.assertEqual(simple.set_member, [2])
self.assertTrue(isinstance(simple.set_member, ListConfig))
self.assertEqual(c.a_tuple, [4.0, 3.0])
self.assertTrue(isinstance(c.a_tuple, ListConfig))
self.assertEqual(mydata.tuple_field, (3, True, "j"))
self.assertEqual(mydata.tuple_field, (3,))
self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
f(**c.f_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment