Commit ab1ffea0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

[Cleanup] Remove hyperparams/config_definitions.py. The classes are only for...

[Cleanup] Remove hyperparams/config_definitions.py. The classes are only for deprecated models. Move the usages to individual legacy models and we will remove them finally.

PiperOrigin-RevId: 408997960
parent 79b6de8e
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common configuration settings."""
# pylint:disable=wildcard-import
import dataclasses
from official.core.config_definitions import *
from official.modeling.hyperparams import base_config
# TODO(hongkuny): These configs are used in models that are going to deprecate.
# Once those models are removed, we should delete this file to avoid confusion.
# Users should not use this file anymore.
@dataclasses.dataclass
class TensorboardConfig(base_config.Config):
"""Configuration for Tensorboard.
Attributes:
track_lr: Whether or not to track the learning rate in Tensorboard. Defaults
to True.
write_model_weights: Whether or not to write the model weights as images in
Tensorboard. Defaults to False.
"""
track_lr: bool = True
write_model_weights: bool = False
@dataclasses.dataclass
class CallbacksConfig(base_config.Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True
enable_time_history: bool = True
...@@ -16,11 +16,10 @@ ...@@ -16,11 +16,10 @@
import dataclasses import dataclasses
import os import os
from typing import List, Optional, Union from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import common from official.vision.beta.configs import common
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.projects.basnet.configs import basnet as exp_cfg from official.projects.basnet.configs import basnet as exp_cfg
......
...@@ -21,11 +21,10 @@ deeplab v3 segmentation head. ...@@ -21,11 +21,10 @@ deeplab v3 segmentation head.
import dataclasses import dataclasses
import os import os
from typing import Optional from typing import Optional
from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import backbones from official.vision.beta.configs import backbones
from official.vision.beta.configs import common from official.vision.beta.configs import common
from official.vision.beta.configs import decoders from official.vision.beta.configs import decoders
......
...@@ -19,8 +19,8 @@ from absl import logging ...@@ -19,8 +19,8 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import dataset_fn from official.common import dataset_fn
from official.core import config_definitions as cfg
from official.core import task_factory from official.core import task_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.projects.edgetpu.vision.configs import semantic_segmentation_config as exp_cfg from official.projects.edgetpu.vision.configs import semantic_segmentation_config as exp_cfg
from official.projects.edgetpu.vision.configs import semantic_segmentation_searched_config as searched_cfg from official.projects.edgetpu.vision.configs import semantic_segmentation_searched_config as searched_cfg
from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v1_model from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v1_model
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
import dataclasses import dataclasses
from typing import List, Optional, Union from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.projects.volumetric_models.configs import backbones from official.projects.volumetric_models.configs import backbones
from official.projects.volumetric_models.configs import decoders from official.projects.volumetric_models.configs import decoders
from official.vision.beta.configs import common from official.vision.beta.configs import common
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg from official.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
......
...@@ -13,12 +13,31 @@ ...@@ -13,12 +13,31 @@
# limitations under the License. # limitations under the License.
"""Ranking Model configuration definition.""" """Ranking Model configuration definition."""
from typing import Optional, List, Union
import dataclasses import dataclasses
from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling.hyperparams import config_definitions as cfg
@dataclasses.dataclass
class CallbacksConfig(hyperparams.Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to True.
enable_tensorboard: Whether or not to enable TensorBoard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True
enable_time_history: bool = True
@dataclasses.dataclass @dataclasses.dataclass
...@@ -126,7 +145,6 @@ class TrainerConfig(cfg.TrainerConfig): ...@@ -126,7 +145,6 @@ class TrainerConfig(cfg.TrainerConfig):
use_orbit: Whether to use orbit library with custom training loop or use_orbit: Whether to use orbit library with custom training loop or
compile/fit API. compile/fit API.
enable_metrics_in_training: Whether to enable metrics during training. enable_metrics_in_training: Whether to enable metrics during training.
tensorboard: An instance of TensorboardConfig.
time_history: Config of TimeHistory callback. time_history: Config of TimeHistory callback.
optimizer_config: An `OptimizerConfig` instance for embedding optimizer. optimizer_config: An `OptimizerConfig` instance for embedding optimizer.
Defaults to None. Defaults to None.
...@@ -135,10 +153,9 @@ class TrainerConfig(cfg.TrainerConfig): ...@@ -135,10 +153,9 @@ class TrainerConfig(cfg.TrainerConfig):
# Sets validation steps to be -1 to evaluate the entire dataset. # Sets validation steps to be -1 to evaluate the entire dataset.
validation_steps: int = -1 validation_steps: int = -1
validation_interval: int = 70000 validation_interval: int = 70000
callbacks: cfg.CallbacksConfig = cfg.CallbacksConfig() callbacks: CallbacksConfig = CallbacksConfig()
use_orbit: bool = False use_orbit: bool = False
enable_metrics_in_training: bool = True enable_metrics_in_training: bool = True
tensorboard: cfg.TensorboardConfig = cfg.TensorboardConfig()
time_history: TimeHistoryConfig = TimeHistoryConfig(log_steps=5000) time_history: TimeHistoryConfig = TimeHistoryConfig(log_steps=5000)
optimizer_config: OptimizationConfig = OptimizationConfig() optimizer_config: OptimizationConfig = OptimizationConfig()
......
...@@ -19,11 +19,10 @@ import os ...@@ -19,11 +19,10 @@ import os
from typing import List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import common from official.vision.beta.configs import common
from official.vision.beta.configs import decoders from official.vision.beta.configs import decoders
from official.vision.beta.configs import backbones from official.vision.beta.configs import backbones
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.vision import beta from official.vision import beta
from official.vision.beta.configs import semantic_segmentation as exp_cfg from official.vision.beta.configs import semantic_segmentation as exp_cfg
......
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
import dataclasses import dataclasses
import os import os
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import common from official.vision.beta.configs import common
from official.vision.beta.projects.centernet.configs import backbones from official.vision.beta.projects.centernet.configs import backbones
......
...@@ -18,7 +18,7 @@ from typing import List ...@@ -18,7 +18,7 @@ from typing import List
import tensorflow as tf import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg from official.core import config_definitions as cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.serving import detection from official.vision.beta.serving import detection
......
...@@ -19,9 +19,8 @@ import abc ...@@ -19,9 +19,8 @@ import abc
from typing import Dict, List, Mapping, Optional, Text from typing import Dict, List, Mapping, Optional, Text
import tensorflow as tf import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import export_base from official.core import export_base
from official.modeling.hyperparams import config_definitions as cfg
class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta): class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
......
...@@ -14,19 +14,49 @@ ...@@ -14,19 +14,49 @@
# Lint as: python3 # Lint as: python3
"""Definitions for high level configuration groups..""" """Definitions for high level configuration groups.."""
from typing import Any, List, Mapping, Optional
import dataclasses import dataclasses
from typing import Any, List, Mapping, Optional
from official.core import config_definitions from official.core import config_definitions
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling.hyperparams import config_definitions as legacy_cfg
CallbacksConfig = legacy_cfg.CallbacksConfig
TensorboardConfig = legacy_cfg.TensorboardConfig
RuntimeConfig = config_definitions.RuntimeConfig RuntimeConfig = config_definitions.RuntimeConfig
@dataclasses.dataclass
class TensorBoardConfig(hyperparams.Config):
"""Configuration for TensorBoard.
Attributes:
track_lr: Whether or not to track the learning rate in TensorBoard. Defaults
to True.
write_model_weights: Whether or not to write the model weights as images in
TensorBoard. Defaults to False.
"""
track_lr: bool = True
write_model_weights: bool = False
@dataclasses.dataclass
class CallbacksConfig(hyperparams.Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to True.
enable_tensorboard: Whether or not to enable TensorBoard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True
enable_time_history: bool = True
@dataclasses.dataclass @dataclasses.dataclass
class ExportConfig(hyperparams.Config): class ExportConfig(hyperparams.Config):
"""Configuration for exports. """Configuration for exports.
...@@ -74,7 +104,7 @@ class TrainConfig(hyperparams.Config): ...@@ -74,7 +104,7 @@ class TrainConfig(hyperparams.Config):
inferred based on the number of images and batch size. Defaults to None. inferred based on the number of images and batch size. Defaults to None.
callbacks: An instance of CallbacksConfig. callbacks: An instance of CallbacksConfig.
metrics: An instance of MetricsConfig. metrics: An instance of MetricsConfig.
tensorboard: An instance of TensorboardConfig. tensorboard: An instance of TensorBoardConfig.
set_epoch_loop: Whether or not to set `steps_per_execution` to set_epoch_loop: Whether or not to set `steps_per_execution` to
equal the number of training steps in `model.compile`. This reduces the equal the number of training steps in `model.compile`. This reduces the
number of callbacks run per epoch which significantly improves end-to-end number of callbacks run per epoch which significantly improves end-to-end
...@@ -85,7 +115,7 @@ class TrainConfig(hyperparams.Config): ...@@ -85,7 +115,7 @@ class TrainConfig(hyperparams.Config):
steps: int = None steps: int = None
callbacks: CallbacksConfig = CallbacksConfig() callbacks: CallbacksConfig = CallbacksConfig()
metrics: MetricsConfig = None metrics: MetricsConfig = None
tensorboard: TensorboardConfig = TensorboardConfig() tensorboard: TensorBoardConfig = TensorBoardConfig()
time_history: TimeHistoryConfig = TimeHistoryConfig() time_history: TimeHistoryConfig = TimeHistoryConfig()
set_epoch_loop: bool = False set_epoch_loop: bool = False
......
...@@ -52,7 +52,7 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig): ...@@ -52,7 +52,7 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
enable_checkpoint_and_export=True, enable_tensorboard=True), enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'], metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100), time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig( tensorboard=base_configs.TensorBoardConfig(
track_lr=True, write_model_weights=False), track_lr=True, write_model_weights=False),
set_epoch_loop=False) set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig( evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
...@@ -84,7 +84,7 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig): ...@@ -84,7 +84,7 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
enable_checkpoint_and_export=True, enable_tensorboard=True), enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'], metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100), time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig( tensorboard=base_configs.TensorBoardConfig(
track_lr=True, write_model_weights=False), track_lr=True, write_model_weights=False),
set_epoch_loop=False) set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig( evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment