Commit 4eda0048 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Export symbols of hyperparams module together.

PiperOrigin-RevId: 309457250
parent 55ec8194
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hyperparams package definition."""
from official.modeling.hyperparams.base_config import *
from official.modeling.hyperparams.params_dict import *
...@@ -15,12 +15,7 @@ ...@@ -15,12 +15,7 @@
# ============================================================================== # ==============================================================================
"""Runs an Image Classification model.""" """Runs an Image Classification model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os import os
import pprint import pprint
from typing import Any, Tuple, Text, Optional, Mapping from typing import Any, Tuple, Text, Optional, Mapping
...@@ -29,8 +24,8 @@ from absl import flags ...@@ -29,8 +24,8 @@ from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import performance from official.modeling import performance
from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
...@@ -186,7 +181,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues): ...@@ -186,7 +181,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
for param in overriding_configs: for param in overriding_configs:
logging.info('Overriding params: %s', param) logging.info('Overriding params: %s', param)
params = params_dict.override_params_dict(params, param, is_strict=True) params = hyperparams.override_params_dict(params, param, is_strict=True)
params.validate() params.validate()
params.lock() params.lock()
...@@ -290,7 +285,7 @@ def serialize_config(params: base_configs.ExperimentConfig, ...@@ -290,7 +285,7 @@ def serialize_config(params: base_configs.ExperimentConfig,
params_save_path = os.path.join(model_dir, 'params.yaml') params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path) logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir) tf.io.gfile.makedirs(model_dir)
params_dict.save_params_dict_to_yaml(params, params_save_path) hyperparams.save_params_dict_to_yaml(params, params_save_path)
def train_and_eval( def train_and_eval(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment