base_config.py 11 KB
Newer Older
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Hongkun Yu's avatar
Hongkun Yu committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Hongkun Yu's avatar
Hongkun Yu committed
14

Hongkun Yu's avatar
Hongkun Yu committed
15
16
"""Base configurations to standardize experiments."""
import copy
Hongkun Yu's avatar
Hongkun Yu committed
17
import dataclasses
Yeqing Li's avatar
Yeqing Li committed
18
import functools
Hongkun Yu's avatar
Hongkun Yu committed
19
import inspect
Yeqing Li's avatar
Yeqing Li committed
20
from typing import Any, List, Mapping, Optional, Type
Hongkun Yu's avatar
Hongkun Yu committed
21

Hongkun Yu's avatar
Hongkun Yu committed
22
from absl import logging
Hongkun Yu's avatar
Hongkun Yu committed
23
24
25
26
27
import tensorflow as tf
import yaml

from official.modeling.hyperparams import params_dict

Hongkun Yu's avatar
Hongkun Yu committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
_BOUND = set()


def bind(config_cls):
  """Creates a decorator that attaches a builder to `config_cls`.

  The decorated object is recorded as the class attribute
  `config_cls._BUILDER`. A class is stored unchanged. A plain function is
  stored behind a thin method-style wrapper that discards its first positional
  argument (the config instance when looked up through one) before forwarding
  the remaining arguments. The decorator hands the builder back unchanged, and
  every config class can be bound only once (tracked in `_BOUND`).

  Args:
    config_cls: the config class to attach the builder to.

  Returns:
    A decorator accepting a class or function builder.

  Raises:
    ValueError: if `config_cls` is not a class, if it has already been bound,
      or if the builder is neither a class nor a function.
  """
  if not inspect.isclass(config_cls):
    raise ValueError('The bind decorator is supposed to apply on the class '
                     f'attribute. Received {config_cls}, not a class.')

  def decorator(builder):
    if config_cls in _BOUND:
      raise ValueError('Inside a program, we should not bind the config with a'
                       ' class twice.')
    builder_is_class = inspect.isclass(builder)
    if not builder_is_class and not inspect.isfunction(builder):
      raise ValueError(f'The `BUILDER` type is not supported: {builder}')

    if builder_is_class:
      config_cls._BUILDER = builder  # pylint: disable=protected-access
    else:
      def _call_builder(self, *args, **kwargs):  # pylint: disable=unused-argument
        # `self` is dropped so the bound builder sees only the caller's args.
        return builder(*args, **kwargs)

      config_cls._BUILDER = _call_builder  # pylint: disable=protected-access

    _BOUND.add(config_cls)
    return builder

  return decorator

Hongkun Yu's avatar
Hongkun Yu committed
56
57
58

@dataclasses.dataclass
class Config(params_dict.ParamsDict):
  """The base configuration class that supports YAML/JSON based overrides.

  Because of YAML/JSON serialization limitations, some semantics of dataclass
  are not supported:
  * It recursively enforces an allowlist of basic types and container types, so
    it avoids surprises with copy and reuse caused by unanticipated types.
  * Warning: it converts Dict to `Config` even within sequences,
    e.g. for config = Config({'key': [{'a': 42}]}),
         type(config.key[0]) is Config rather than dict.
    If you define/annotate some field as Dict, the field will convert to a
    `Config` instance and lose the dictionary type.
  * Only a single level of sequence is supported: a list/tuple may hold
    immutable values, or dicts, or ParamsDicts, but not nested sequences.
  """

  # The class or method to bind with the params class.
  _BUILDER = None
  # It's safe to add bytes and other immutable types here.
  IMMUTABLE_TYPES = (str, int, float, bool, type(None))
  # It's safe to add set, frozenset and other collections here.
  SEQUENCE_TYPES = (list, tuple)

  default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
  restrictions: dataclasses.InitVar[Optional[List[str]]] = None

  def __post_init__(self, default_params, restrictions):
    super().__init__(
        default_params=default_params,
        restrictions=restrictions)

  @property
  def BUILDER(self):
    """Returns the class-level `_BUILDER` (set by `bind`), or None if unset."""
    return self._BUILDER

  @staticmethod
  def _issubclass_safe(obj, classinfo) -> bool:
    """Like `issubclass`, but returns False instead of raising for non-classes.

    Annotations such as `Optional[X]`, TypeVars, forward-reference strings and
    (on some Python versions) `list[int]` are not accepted by `issubclass`.
    """
    try:
      return isinstance(obj, type) and issubclass(obj, classinfo)
    except TypeError:
      return False

  @classmethod
  def _isvalidsequence(cls, v):
    """Check if the input values are valid sequences.

    Args:
      v: Input sequence.

    Returns:
      True if the sequence is valid. Valid sequence includes the sequence
      type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
      is dict or ParamsDict.
    """
    if not isinstance(v, cls.SEQUENCE_TYPES):
      return False
    return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
            all(isinstance(e, dict) for e in v) or
            all(isinstance(e, params_dict.ParamsDict) for e in v))

  @classmethod
  def _import_config(cls, v, subconfig_type):
    """Returns v with dicts converted to Configs, recursively.

    Args:
      v: the value to import.
      subconfig_type: the ParamsDict subclass used to wrap dict values.

    Returns:
      Immutable values unchanged; sequences rebuilt with the same sequence
      type; ParamsDict values deep-copied; dicts wrapped in `subconfig_type`.

    Raises:
      TypeError: if `subconfig_type` is not a ParamsDict subclass, if a
        sequence is not a valid single-level sequence, or if `v` has an
        unsupported type.
    """
    if not issubclass(subconfig_type, params_dict.ParamsDict):
      raise TypeError(
          'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
              subconfig_type))
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      # Only support one layer of sequence.
      if not cls._isvalidsequence(v):
        raise TypeError(
            'Invalid sequence: only supports single level {!r} of {!r} or '
            'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
                                                    cls.IMMUTABLE_TYPES, v))
      import_fn = functools.partial(
          cls._import_config, subconfig_type=subconfig_type)
      return type(v)(map(import_fn, v))
    elif isinstance(v, params_dict.ParamsDict):
      # Deepcopy here is a temporary solution for preserving type in nested
      # Config object.
      return copy.deepcopy(v)
    elif isinstance(v, dict):
      return subconfig_type(v)
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _export_config(cls, v):
    """Returns v with Configs converted to dicts, recursively."""
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      return type(v)(map(cls._export_config, v))
    elif isinstance(v, params_dict.ParamsDict):
      return v.as_dict()
    elif isinstance(v, dict):
      raise TypeError('dict value not supported in converting.')
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _get_annotations(cls) -> Mapping[str, Any]:
    """Returns the field annotations of `cls`, including inherited ones.

    On Python >= 3.10 `cls.__annotations__` only holds the annotations declared
    directly on `cls`, so fields inherited from a parent config class would be
    invisible. Annotations are gathered from the most basic class to `cls`:
    parent fields come first (the dataclass field order), and a subclass that
    re-annotates a field overrides its parent. The init-only `default_params`
    and `restrictions` declared on `Config` itself are excluded.
    """
    all_annotations = {}
    for base in reversed(cls.__mro__):
      all_annotations.update(getattr(base, '__annotations__', None) or {})
    for reserved in getattr(Config, '__annotations__', None) or {}:
      all_annotations.pop(reserved, None)
    return all_annotations

  @classmethod
  def _get_subconfig_type(cls, k) -> Type[params_dict.ParamsDict]:
    """Get element type by the field name.

    Args:
      k: the key/name of the field.

    Returns:
      Config as default. If a type annotation is found for `k` (on `cls` or
      any parent class),
      1) returns the type of the annotation if it is subtype of Config;
      2) returns the element type if the annotation of `k` is List[SubType]
         or Tuple[SubType] and SubType is a subtype of ParamsDict.
    """
    subconfig_type = Config
    annotations = cls._get_annotations()
    if k not in annotations:
      return subconfig_type

    type_annotation = annotations[k]
    # Directly Config subtype.
    if cls._issubclass_safe(type_annotation, Config):
      return type_annotation

    # Check if the field is a sequence of subtypes.
    field_type = getattr(type_annotation, '__origin__', type(None))
    if cls._issubclass_safe(field_type, cls.SEQUENCE_TYPES):
      # `__args__` may be missing (bare `List`) or empty (`Tuple[()]`).
      element_types = (
          getattr(type_annotation, '__args__', None) or (type(None),))
      element_type = element_types[0]
      if cls._issubclass_safe(element_type, params_dict.ParamsDict):
        return element_type
    return subconfig_type

  def _set(self, k, v):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods.

    Args:
      k: key to set.
      v: value.

    Raises:
      TypeError: if `v` (or one of its elements) has a type that
        `_import_config` does not support.
    """
    subconfig_type = self._get_subconfig_type(k)

    def is_null(key):
      # Missing keys and falsy current values are both treated as "unset".
      return key not in self.__dict__ or not self.__dict__[key]

    if isinstance(v, dict):
      if is_null(k):
        # If the key does not exist or its value is empty/None, a new
        # Config-family object should be created for the key.
        self.__dict__[k] = subconfig_type(v)
      else:
        self.__dict__[k].override(v)
    elif (not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and
          all(not isinstance(e, self.IMMUTABLE_TYPES) for e in v)):
      current = self.__dict__[k]
      # Element-wise override keeps each existing element's (possibly more
      # specific) config type; it is only possible when every current element
      # is a ParamsDict and the lengths match.
      if (isinstance(current, self.SEQUENCE_TYPES) and
          len(current) == len(v) and
          all(isinstance(e, params_dict.ParamsDict) for e in current)):
        for current_element, new_element in zip(current, v):
          current_element.override(new_element)
      else:
        if v:
          logging.warning(
              "The list/tuple don't match the value dictionaries provided. "
              'Thus, the list/tuple is determined by the type annotation and '
              'values provided. This is error-prone.')
        self.__dict__[k] = self._import_config(v, subconfig_type)
    else:
      self.__dict__[k] = self._import_config(v, subconfig_type)

  def __setattr__(self, k, v):
    """Routes attribute assignment through `_set`, honoring the lock.

    Raises:
      AttributeError: when assigning `BUILDER` or `_BUILDER` on an instance.
      ValueError: when assigning a non-reserved attribute on a locked config.
    """
    if k in ('BUILDER', '_BUILDER'):
      raise AttributeError('`BUILDER` is a property and `_BUILDER` is the '
                           'reserved class attribute. We should only assign '
                           '`_BUILDER` at the class level.')
    if k not in self.RESERVED_ATTR and getattr(self, '_locked', False):
      raise ValueError('The Config has been locked. No change is allowed.')
    self._set(k, v)

  def _override(self, override_dict, is_strict=True):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods.

    Args:
      override_dict: dictionary of key/value pairs to apply to this config.
      is_strict: If True, raise instead of adding keys that do not exist yet.

    Raises:
      KeyError: overriding reserved keys or keys not exist (is_strict=True).
    """
    for k, v in sorted(override_dict.items()):
      if k in self.RESERVED_ATTR:
        raise KeyError('The key {!r} is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key {!r} does not exist in {!r}. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(
                             k, type(self)))
        else:
          self._set(k, v)
      else:
        if isinstance(v, dict) and self.__dict__[k]:
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self._set(k, v)

  def as_dict(self):
    """Returns a dict representation of params_dict.ParamsDict.

    For the nested params_dict.ParamsDict, a nested dict will be returned.
    """
    return {
        k: self._export_config(v)
        for k, v in self.__dict__.items()
        if k not in self.RESERVED_ATTR
    }

  def replace(self, **kwargs):
    """Returns an unlocked, overridden copy; the current config is unchanged.

    Overrides are applied strictly, so unknown keys raise KeyError.
    """
    # pylint: disable=protected-access
    params = copy.deepcopy(self)
    params._locked = False
    params._override(kwargs, is_strict=True)
    # pylint: enable=protected-access
    return params

  @classmethod
  def from_yaml(cls, file_path: str):
    """Builds a `cls` instance from its defaults overridden by a YAML file.

    Note: This only works if the Config has all default values, because `cls`
    is instantiated without arguments.

    Args:
      file_path: path readable by `tf.io.gfile.GFile`.

    Returns:
      The overridden config. An empty YAML document leaves the defaults as-is.
    """
    with tf.io.gfile.GFile(file_path, 'r') as f:
      # NOTE(review): PyYAML recommends `yaml.safe_load` for untrusted input;
      # FullLoader is kept for compatibility and assumes trusted config files.
      loaded = yaml.load(f, Loader=yaml.FullLoader)
    config = cls()
    # An empty YAML document loads as None; there is nothing to override.
    if loaded is not None:
      config.override(loaded)
    return config

  @classmethod
  def from_json(cls, file_path: str):
    """Wrapper for `from_yaml`."""
    return cls.from_yaml(file_path)

  @classmethod
  def from_args(cls, *args, **kwargs):
    """Builds a config from the given list of arguments.

    Positional `args` are matched, in order, against the annotated fields of
    `cls` (inherited fields first). Positional args beyond the number of fields
    are ignored. `kwargs` take precedence over positional values.
    """
    attributes = list(cls._get_annotations())
    default_params = dict(zip(attributes, args))
    default_params.update(kwargs)
    return cls(default_params=default_params)