base_config.py 8.72 KB
Newer Older
Hongkun Yu's avatar
Hongkun Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base configurations to standardize experiments."""

import copy
Yeqing Li's avatar
Yeqing Li committed
18
19
import functools
from typing import Any, List, Mapping, Optional, Type
Hongkun Yu's avatar
Hongkun Yu committed
20
21
22
23
24
25
26
27
28
29

import dataclasses
import tensorflow as tf
import yaml

from official.modeling.hyperparams import params_dict


@dataclasses.dataclass
class Config(params_dict.ParamsDict):
  """The base configuration class that supports YAML/JSON based overrides.

  * It recursively enforces a whitelist of basic types and container types, so
    it avoids surprises with copy and reuse caused by unanticipated types.
  * It converts dicts to Configs inside a single level of sequence,
    e.g. for config = Config({'key': [{'a': 42}]}),
         type(config.key[0]) is Config rather than dict.
    Nested sequences (e.g. a list of tuples) are rejected with a TypeError.
  """

  # It's safe to add bytes and other immutable types here.
  IMMUTABLE_TYPES = (str, int, float, bool, type(None))
  # It's safe to add set, frozenset and other collections here.
  SEQUENCE_TYPES = (list, tuple)

  default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
  restrictions: dataclasses.InitVar[Optional[List[str]]] = None

  @classmethod
  def _isvalidsequence(cls, v):
    """Check if the input values are valid sequences.

    Args:
      v: Input sequence.

    Returns:
      True if `v` is an instance of cls.SEQUENCE_TYPES and its elements are
      either all of cls.IMMUTABLE_TYPES, all dicts, or all ParamsDicts. Mixed
      element kinds and nested sequences are not valid; an empty sequence is.
    """
    if not isinstance(v, cls.SEQUENCE_TYPES):
      return False
    return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
            all(isinstance(e, dict) for e in v) or
            all(isinstance(e, params_dict.ParamsDict) for e in v))

  @classmethod
  def _import_config(cls, v, subconfig_type):
    """Returns v with dicts converted to Configs, recursively.

    Args:
      v: The value to import.
      subconfig_type: The ParamsDict subclass used to wrap dict values.

    Returns:
      Immutable values unchanged, sequences rebuilt with their original
      sequence type, ParamsDicts deep-copied, and dicts wrapped in
      `subconfig_type`.

    Raises:
      TypeError: if `subconfig_type` is not a ParamsDict subclass, if `v` is
        an invalid sequence, or if `v` has an unsupported type.
    """
    if not issubclass(subconfig_type, params_dict.ParamsDict):
      raise TypeError(
          'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
              subconfig_type))
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      # Only support one layer of sequence.
      if not cls._isvalidsequence(v):
        raise TypeError(
            'Invalid sequence: only supports single level {!r} of {!r} or '
            'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
                                                    cls.IMMUTABLE_TYPES, v))
      import_fn = functools.partial(
          cls._import_config, subconfig_type=subconfig_type)
      return type(v)(map(import_fn, v))
    elif isinstance(v, params_dict.ParamsDict):
      # Deepcopy here is a temporary solution for preserving type in nested
      # Config object.
      return copy.deepcopy(v)
    elif isinstance(v, dict):
      return subconfig_type(v)
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _export_config(cls, v):
    """Returns v with Configs converted to dicts, recursively.

    Raises:
      TypeError: if `v` is a raw dict or has an unsupported type.
    """
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      return type(v)(map(cls._export_config, v))
    elif isinstance(v, params_dict.ParamsDict):
      return v.as_dict()
    elif isinstance(v, dict):
      raise TypeError('dict value not supported in converting.')
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _get_subconfig_type(cls, k) -> Type[params_dict.ParamsDict]:
    """Get element type by the field name.

    Args:
      k: the key/name of the field.

    Returns:
      Config as default. If a type annotation is found for `k` (declared on
      `cls` itself, or as a dataclass field inherited from a base class),
      1) returns the type of the annotation if it is subtype of Config;
      2) returns the element type if the annotation of `k` is List[SubType]
         or Tuple[SubType] and SubType is a subclass of ParamsDict.
      Annotations that are not plain classes (e.g. List[Any],
      List[Optional[SubType]], string forward references) fall back to Config.
    """

    def _safe_issubclass(candidate, parent):
      # `issubclass` raises TypeError for typing constructs such as
      # `typing.Any`, TypeVars or (on some Python versions) `list[int]`, any
      # of which may appear in a field annotation.
      try:
        return isinstance(candidate, type) and issubclass(candidate, parent)
      except TypeError:
        return False

    if k in cls.__annotations__:
      type_annotation = cls.__annotations__[k]  # pytype: disable=invalid-annotation
    else:
      # A subclass that declares its own annotations does not expose its
      # parents' annotations through `cls.__annotations__`, so look inherited
      # dataclass fields up through `dataclasses.fields` instead.
      inherited = {}
      if dataclasses.is_dataclass(cls):
        inherited = {f.name: f.type for f in dataclasses.fields(cls)}
      if k not in inherited:
        return Config
      type_annotation = inherited[k]

    # Directly Config subtype.
    if _safe_issubclass(type_annotation, Config):
      return type_annotation

    # Check if the field is a sequence of subtypes.
    field_type = getattr(type_annotation, '__origin__', None)
    if _safe_issubclass(field_type, cls.SEQUENCE_TYPES):
      # `__args__` may be missing or empty (e.g. bare `List`, `Tuple[()]`).
      type_args = getattr(type_annotation, '__args__', None) or (None,)
      element_type = type_args[0]
      if _safe_issubclass(element_type, params_dict.ParamsDict):
        return element_type
    return Config

  def __post_init__(self, default_params, restrictions, *args, **kwargs):
    # Forward the dataclass InitVars to ParamsDict's constructor.
    super().__init__(
        default_params=default_params,
        restrictions=restrictions,
        *args,
        **kwargs)

  def _set(self, k, v):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods. Dict values are converted to, or merged
    into, the Config subtype declared for `k`; every other value goes through
    `_import_config`.

    Args:
      k: key to set.
      v: value.

    Raises:
      TypeError: if `v` (or one of its elements) has an unsupported type.
    """
    subconfig_type = self._get_subconfig_type(k)
    if isinstance(v, dict):
      if k not in self.__dict__ or not self.__dict__[k]:
        # If the key does not exist or its current value is falsy (e.g. None),
        # a new Config-family object should be created for the key.
        self.__dict__[k] = subconfig_type(v)
      else:
        self.__dict__[k].override(v)
    else:
      self.__dict__[k] = self._import_config(v, subconfig_type)

  def __setattr__(self, k, v):
    """Sets attribute `k` through `_set`, refusing changes once locked.

    Raises:
      ValueError: if the Config is locked and `k` is not a reserved attribute.
    """
    if k not in self.RESERVED_ATTR:
      if getattr(self, '_locked', False):
        raise ValueError('The Config has been locked. No change is allowed.')
    self._set(k, v)

  def _override(self, override_dict, is_strict=True):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods. Nested dicts/ParamsDicts are merged into
    existing non-empty sub-configs recursively; other values replace the
    current value.

    Args:
      override_dict: dictionary of new values to write into this config.
      is_strict: If True, adding keys that do not already exist is an error.

    Raises:
      KeyError: overriding reserved keys or keys not exist (is_strict=True).
    """
    for k, v in sorted(override_dict.items()):
      if k in self.RESERVED_ATTR:
        raise KeyError('The key {!r} is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key {!r} does not exist in {!r}. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(
                             k, type(self)))
        else:
          self._set(k, v)
      else:
        if isinstance(v, dict) and self.__dict__[k]:
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self._set(k, v)

  def as_dict(self):
    """Returns a dict representation of params_dict.ParamsDict.

    For the nested params_dict.ParamsDict, a nested dict will be returned.
    Reserved attributes are excluded.
    """
    return {
        k: self._export_config(v)
        for k, v in self.__dict__.items()
        if k not in self.RESERVED_ATTR
    }

  def replace(self, **kwargs):
    """Overrides/returns a unlocked copy with the current config unchanged."""
    # pylint: disable=protected-access
    params = copy.deepcopy(self)
    params._locked = False
    params._override(kwargs, is_strict=True)
    # pylint: enable=protected-access
    return params

  @classmethod
  def from_yaml(cls, file_path: str):
    """Builds a config of type `cls` overridden by the YAML file at `file_path`.

    Note: This only works if the Config has all default values.
    """
    with tf.io.gfile.GFile(file_path, 'r') as f:
      # PyYAML >= 6 requires an explicit Loader (5.x warns without one).
      # FullLoader is what bare `yaml.load` used implicitly in 5.x; it is not
      # meant for untrusted input, so only load trusted config files here.
      loaded = yaml.load(f, Loader=yaml.FullLoader)
    config = cls()
    config.override(loaded)
    return config

  @classmethod
  def from_json(cls, file_path: str):
    """Wrapper for `from_yaml`; relies on the YAML parser accepting JSON."""
    return cls.from_yaml(file_path)

  @classmethod
  def from_args(cls, *args, **kwargs):
    """Builds a config from the given list of arguments.

    Positional `args` are paired, in order, with the names in
    `cls.__annotations__`; `kwargs` then override by name.
    NOTE(review): `cls.__annotations__` may not include fields inherited from
    base classes, so positional args only map onto fields declared on `cls`.
    """
    attributes = list(cls.__annotations__.keys())
    default_params = {a: p for a, p in zip(attributes, args)}
    default_params.update(kwargs)
    return cls(default_params)