config_util.py 3.71 KB
Newer Older
Christopher Shallue's avatar
Christopher Shallue committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for configurations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os.path

import tensorflow as tf


def parse_json(json_string_or_file):
  """Parses values from a JSON string or JSON file.

  This function is useful for command line flags containing configuration
  overrides. Using this function, the flag can be passed either as a JSON string
  (e.g. '{"learning_rate": 1.0}') or the path to a JSON configuration file.

  The argument is first parsed as literal JSON. Only if that fails is it
  treated as a file path, opened with tf.gfile.Open and parsed with json.load.

  Args:
    json_string_or_file: A JSON serialized string OR the path to a JSON file.

  Returns:
    The parsed JSON; a dictionary when the input encodes a JSON object.

  Raises:
    ValueError: If the JSON could not be parsed, either as a literal string or
      as the contents of the named file, or if the argument is not literal
      JSON and opening it as a file raises tf.gfile.FileError.
  """
  # First, attempt to parse the string as a JSON dict.
  try:
    json_dict = json.loads(json_string_or_file)
  except ValueError as literal_json_parsing_error:
    try:
      # Otherwise, try to use it as a path to a JSON file.
      with tf.gfile.Open(json_string_or_file) as f:
        json_dict = json.load(f)
    except ValueError as json_file_parsing_error:
      # str(err) rather than err.message: BaseException.message does not exist
      # in Python 3, so reading it would raise AttributeError and mask the
      # parsing error.
      raise ValueError("Unable to parse the content of the json file {}. "
                       "Parsing error: {}.".format(
                           json_string_or_file, str(json_file_parsing_error)))
    except tf.gfile.FileError:
      # NOTE(review): confirm that tf.gfile.FileError is what the TF version
      # in use raises for a missing file; TF 1.x file I/O may raise
      # tf.errors.NotFoundError instead, which would escape this handler.
      message = ("Unable to parse the input parameter neither as literal "
                 "JSON nor as the name of a file that exists.\n"
                 "JSON parsing error: {}\n\n Input parameter:\n{}.".format(
                     str(literal_json_parsing_error), json_string_or_file))
      raise ValueError(message)

  return json_dict


66
67
68
69
70
71
72
73
def to_json(config):
  """Serializes a configuration object to a JSON string indented by 2 spaces.

  If `config` has a callable `to_json` attribute, serialization is delegated to
  `config.to_json(indent=2)`. Otherwise `config` is passed to `json.dumps`.

  Args:
    config: A JSON-serializable value, or an object with a callable `to_json`
      attribute that accepts an `indent` keyword argument.

  Returns:
    The JSON string.
  """
  serializer = getattr(config, "to_json", None)
  if not callable(serializer):
    return json.dumps(config, indent=2)
  return serializer(indent=2)


Christopher Shallue's avatar
Christopher Shallue committed
74
75
76
77
78
79
80
def log_and_save_config(config, output_dir):
  """Logs a configuration and saves it to `output_dir`/config.json.

  The configuration is serialized once with to_json(). The resulting text is
  logged with tf.logging.info and then written to config.json inside
  `output_dir`. tf.gfile.MakeDirs is called on `output_dir` before writing.

  Args:
    config: A JSON-serializable object.
    output_dir: Destination directory.
  """
  serialized = to_json(config)
  tf.logging.info("config: %s", serialized)

  tf.gfile.MakeDirs(output_dir)
  destination = os.path.join(output_dir, "config.json")
  with tf.gfile.Open(destination, "w") as out_file:
    out_file.write(serialized)


def unflatten(flat_config):
  """Transforms a flat configuration dictionary into a nested dictionary.

  Example:
    {
      "a": 1,
      "b.c": 2,
      "b.d.e": 3,
      "b.d.f": 4,
    }
  would be transformed to:
    {
      "a": 1,
      "b": {
        "c": 2,
        "d": {
          "e": 3,
          "f": 4,
        }
      }
    }

  Keys are processed in the iteration order of `flat_config`. A later plain
  key (e.g. "a") replaces any sub-dictionary already built for it from dotted
  keys (e.g. "a.b"). A dotted key whose prefix already holds a non-dict value
  fails with an exception.

  Args:
    flat_config: A dictionary with strings as keys where nested configuration
      parameters are represented with period-separated names.

  Returns:
    A dictionary nested according to the keys of the input dictionary.
  """
  nested = {}
  for dotted_name, value in flat_config.items():
    parts = dotted_name.split(".")
    # Walk (creating as needed) one sub-dictionary per leading name component.
    node = nested
    for part in parts[:-1]:
      node = node.setdefault(part, {})
    node[parts[-1]] = value
  return nested