"vscode:/vscode.git/clone" did not exist on "fc1aa1940d03f193dcd150595e58f3e4dbb720d4"
Commit 4eda0048 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Export symbols of hyperparams module together.

PiperOrigin-RevId: 309457250
parent 55ec8194
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hyperparams package definition."""
from official.modeling.hyperparams.base_config import *
from official.modeling.hyperparams.params_dict import *
......@@ -15,12 +15,7 @@
# ==============================================================================
"""Runs an Image Classification model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pprint
from typing import Any, Tuple, Text, Optional, Mapping
......@@ -29,8 +24,8 @@ from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import performance
from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
......@@ -186,7 +181,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
for param in overriding_configs:
logging.info('Overriding params: %s', param)
params = params_dict.override_params_dict(params, param, is_strict=True)
params = hyperparams.override_params_dict(params, param, is_strict=True)
params.validate()
params.lock()
......@@ -290,7 +285,7 @@ def serialize_config(params: base_configs.ExperimentConfig,
params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir)
params_dict.save_params_dict_to_yaml(params, params_save_path)
hyperparams.save_params_dict_to_yaml(params, params_save_path)
def train_and_eval(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment