Commit f45893b8 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

clean raysampler args

Summary: Don't copy from one part of config to another, rather do the copy within GenericModel.

Reviewed By: davnov134

Differential Revision: D38248828

fbshipit-source-id: ff8af985c37ea1f7df9e0aa0a45a58df34c3f893
parent 5f069dbb
...@@ -191,10 +191,6 @@ model_factory_ImplicitronModelFactory_args: ...@@ -191,10 +191,6 @@ model_factory_ImplicitronModelFactory_args:
init_scale: 1.0 init_scale: 1.0
ignore_input: false ignore_input: false
raysampler_AdaptiveRaySampler_args: raysampler_AdaptiveRaySampler_args:
image_width: 400
image_height: 400
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64 n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64 n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024 n_rays_per_image_sampled_from_mask: 1024
...@@ -206,10 +202,6 @@ model_factory_ImplicitronModelFactory_args: ...@@ -206,10 +202,6 @@ model_factory_ImplicitronModelFactory_args:
- 0.0 - 0.0
- 0.0 - 0.0
raysampler_NearFarRaySampler_args: raysampler_NearFarRaySampler_args:
image_width: 400
image_height: 400
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64 n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64 n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024 n_rays_per_image_sampled_from_mask: 1024
......
...@@ -9,7 +9,7 @@ import json ...@@ -9,7 +9,7 @@ import json
import os import os
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
from omegaconf import DictConfig, open_dict from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
expand_args_fields, expand_args_fields,
registry, registry,
......
...@@ -11,7 +11,7 @@ import os ...@@ -11,7 +11,7 @@ import os
import warnings import warnings
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
from omegaconf import DictConfig, open_dict from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.dataset_map_provider import ( from pytorch3d.implicitron.dataset.dataset_map_provider import (
DatasetMap, DatasetMap,
DatasetMapProviderBase, DatasetMapProviderBase,
......
...@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import tqdm import tqdm
from omegaconf import DictConfig
from pytorch3d.implicitron.models.metrics import ( from pytorch3d.implicitron.models.metrics import (
RegularizationMetricsBase, RegularizationMetricsBase,
ViewMetricsBase, ViewMetricsBase,
...@@ -27,7 +28,7 @@ from pytorch3d.implicitron.tools.config import ( ...@@ -27,7 +28,7 @@ from pytorch3d.implicitron.tools.config import (
run_auto_creation, run_auto_creation,
) )
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import RayBundle, utils as rend_utils from pytorch3d.renderer import RayBundle, utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom from visdom import Visdom
...@@ -615,20 +616,29 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -615,20 +616,29 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
self.image_feature_extractor.get_feat_dims() self.image_feature_extractor.get_feat_dims()
) )
@classmethod
def raysampler_tweak_args(cls, type, args: DictConfig) -> None:
"""
We don't expose certain fields of the raysampler because we want to set
them from our own members.
"""
del args["sampling_mode_training"]
del args["sampling_mode_evaluation"]
del args["image_width"]
del args["image_height"]
def create_raysampler(self): def create_raysampler(self):
extra_args = {
"sampling_mode_training": self.sampling_mode_training,
"sampling_mode_evaluation": self.sampling_mode_evaluation,
"image_width": self.render_image_width,
"image_height": self.render_image_height,
}
raysampler_args = getattr( raysampler_args = getattr(
self, "raysampler_" + self.raysampler_class_type + "_args" self, "raysampler_" + self.raysampler_class_type + "_args"
) )
setattr_if_hasattr(
raysampler_args, "sampling_mode_training", self.sampling_mode_training
)
setattr_if_hasattr(
raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
)
setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)( self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
**raysampler_args **raysampler_args, **extra_args
) )
def create_renderer(self): def create_renderer(self):
......
...@@ -157,15 +157,6 @@ def cat_dataclass(batch, tensor_collator: Callable): ...@@ -157,15 +157,6 @@ def cat_dataclass(batch, tensor_collator: Callable):
return type(elem)(**collated) return type(elem)(**collated)
def setattr_if_hasattr(obj, name, value):
"""
Same as setattr(obj, name, value), but does nothing in case `name` is
not an attribe of `obj`.
"""
if hasattr(obj, name):
setattr(obj, name, value)
class Timer: class Timer:
""" """
A simple class for timing execution. A simple class for timing execution.
......
...@@ -56,10 +56,6 @@ global_encoder_SequenceAutodecoder_args: ...@@ -56,10 +56,6 @@ global_encoder_SequenceAutodecoder_args:
init_scale: 1.0 init_scale: 1.0
ignore_input: false ignore_input: false
raysampler_AdaptiveRaySampler_args: raysampler_AdaptiveRaySampler_args:
image_width: 400
image_height: 400
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64 n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64 n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024 n_rays_per_image_sampled_from_mask: 1024
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment