Commit c0f88e04 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

make ExperimentConfig Configurable

Summary: Preparing for pluggables in experiment.py

Reviewed By: davnov134

Differential Revision: D36830674

fbshipit-source-id: eab499d1bc19c690798fbf7da547544df7e88fa5
parent 62752832
...@@ -53,7 +53,7 @@ import os ...@@ -53,7 +53,7 @@ import os
import random import random
import time import time
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import field
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
import hydra import hydra
...@@ -73,7 +73,9 @@ from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as eval ...@@ -73,7 +73,9 @@ from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as eval
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
Configurable,
enable_get_default_args, enable_get_default_args,
expand_args_fields,
get_default_args_field, get_default_args_field,
remove_unused_components, remove_unused_components,
) )
...@@ -671,8 +673,7 @@ def _seed_all_random_engines(seed: int): ...@@ -671,8 +673,7 @@ def _seed_all_random_engines(seed: int):
random.seed(seed) random.seed(seed)
@dataclass(eq=False) class ExperimentConfig(Configurable):
class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel) generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer) solver_args: DictConfig = get_default_args_field(init_optimizer)
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource) data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
...@@ -705,6 +706,8 @@ class ExperimentConfig: ...@@ -705,6 +706,8 @@ class ExperimentConfig:
) )
expand_args_fields(ExperimentConfig)
if __name__ == "__main__": if __name__ == "__main__":
cs = hydra.core.config_store.ConfigStore.instance() cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig) cs.store(name="default_config", node=ExperimentConfig)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment