# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base configurations to standardize experiments."""

import copy
import dataclasses
import functools
from typing import Any, List, Mapping, Optional, Type

from absl import logging
import tensorflow as tf
import yaml

from official.modeling.hyperparams import params_dict


@dataclasses.dataclass
class Config(params_dict.ParamsDict):
  """The base configuration class that supports YAML/JSON based overrides.

  * It recursively enforces an allowlist of basic types and container types,
    so it avoids surprises with copy and reuse caused by unanticipated types.
  * It converts dict to Config even within (single-level) sequences,
    e.g. for config = Config({'key': [{'a': 42}]}),
         type(config.key[0]) is Config rather than dict.
  """

  # It's safe to add bytes and other immutable types here.
  IMMUTABLE_TYPES = (str, int, float, bool, type(None))
  # It's safe to add set, frozenset and other collections here.
  SEQUENCE_TYPES = (list, tuple)

  # Init-only arguments (dataclasses.InitVar) consumed by __post_init__ and
  # forwarded to ParamsDict.__init__.
  default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
  restrictions: dataclasses.InitVar[Optional[List[str]]] = None

  @classmethod
  def _isvalidsequence(cls, v):
    """Check if the input values are valid sequences.

    Args:
      v: Input sequence.

    Returns:
      True if the sequence is valid. Valid sequence includes the sequence
      type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
      is dict or ParamsDict.
    """
    if not isinstance(v, cls.SEQUENCE_TYPES):
      return False
    return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
            all(isinstance(e, dict) for e in v) or
            all(isinstance(e, params_dict.ParamsDict) for e in v))

  @classmethod
  def _import_config(cls, v, subconfig_type):
    """Returns v with dicts converted to Configs, recursively.

    Args:
      v: value to import: an immutable, a single-level sequence, a dict or a
        ParamsDict.
      subconfig_type: ParamsDict subclass used to wrap dict values.

    Returns:
      Immutables unchanged; sequences rebuilt with the same sequence type and
      imported elements; ParamsDicts deep-copied; dicts wrapped in
      `subconfig_type`.

    Raises:
      TypeError: if `subconfig_type` is not a ParamsDict subclass, if a
        sequence is rejected by `_isvalidsequence`, or if `v` has any other
        type.
    """
    if not issubclass(subconfig_type, params_dict.ParamsDict):
      raise TypeError(
          'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
              subconfig_type))
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      # Only support one layer of sequence.
      if not cls._isvalidsequence(v):
        raise TypeError(
            'Invalid sequence: only supports single level {!r} of {!r} or '
            'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
                                                    cls.IMMUTABLE_TYPES, v))
      import_fn = functools.partial(
          cls._import_config, subconfig_type=subconfig_type)
      return type(v)(map(import_fn, v))
    elif isinstance(v, params_dict.ParamsDict):
      # Deepcopy here is a temporary solution for preserving type in nested
      # Config object.
      return copy.deepcopy(v)
    elif isinstance(v, dict):
      return subconfig_type(v)
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _export_config(cls, v):
    """Returns v with Configs converted to dicts, recursively.

    Raises:
      TypeError: if `v` (or a sequence element) is a plain dict or of an
        unsupported type.
    """
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      return type(v)(map(cls._export_config, v))
    elif isinstance(v, params_dict.ParamsDict):
      return v.as_dict()
    elif isinstance(v, dict):
      raise TypeError('dict value not supported in converting.')
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _get_subconfig_type(cls, k) -> Type[params_dict.ParamsDict]:
    """Get element type by the field name.

    Args:
      k: the key/name of the field.

    Returns:
      Config as default. If a type annotation is found for `k`,
      1) returns the type of the annotation if it is subtype of Config;
      2) returns the element type if the annotation of `k` is List[SubType]
         or Tuple[SubType] and SubType is a ParamsDict subclass.
      Annotations that are typing constructs rather than classes (e.g.
      List[Dict[str, int]], List[Optional[X]]) fall back to Config.
    """

    def is_class_subtype(candidate, base):
      # typing constructs (Dict[...], Optional[...], TypeVars) are not classes
      # and issubclass() raises TypeError on them; treat them as a non-match.
      try:
        return isinstance(candidate, type) and issubclass(candidate, base)
      except TypeError:
        return False

    subconfig_type = Config
    if k in cls.__annotations__:
      type_annotation = cls.__annotations__[k]  # pytype: disable=invalid-annotation
      if is_class_subtype(type_annotation, Config):
        # Directly a Config subtype.
        subconfig_type = type_annotation
      else:
        # Check if the field is a sequence of subtypes, e.g. List[SubConfig].
        field_type = getattr(type_annotation, '__origin__', type(None))
        if is_class_subtype(field_type, cls.SEQUENCE_TYPES):
          element_args = getattr(type_annotation, '__args__', None)
          # Empty/missing args (e.g. Tuple[()]) mean no element type.
          element_type = element_args[0] if element_args else type(None)
          if is_class_subtype(element_type, params_dict.ParamsDict):
            subconfig_type = element_type
    return subconfig_type

  def __post_init__(self, default_params, restrictions, *args, **kwargs):
    """Forwards the dataclass InitVars to ParamsDict.__init__."""
    super().__init__(
        *args,
        default_params=default_params,
        restrictions=restrictions,
        **kwargs)

  def _set(self, k, v):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods. A dict value creates a new sub-config
    of the annotated type (or is merged into the existing one). A sequence of
    non-immutable values is merged element-wise into an existing sequence of
    ParamsDicts of the same length; otherwise the sequence is re-imported.

    Args:
      k: key to set.
      v: value.

    Raises:
      TypeError: if `v` (or one of its elements) has an unsupported type.
        Errors raised by `override` when merging into an existing sub-config
        also propagate.
    """
    subconfig_type = self._get_subconfig_type(k)

    def is_null(k):
      # Missing keys and falsy values (None, empty containers) count as unset.
      return k not in self.__dict__ or not self.__dict__[k]

    if isinstance(v, dict):
      if is_null(k):
        # If the key does not exist or the value is None, a new Config-family
        # object should be created for the key.
        self.__dict__[k] = subconfig_type(v)
      else:
        self.__dict__[k].override(v)
    elif not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and all(
        not isinstance(e, self.IMMUTABLE_TYPES) for e in v):
      existing = self.__dict__[k]
      if (len(existing) == len(v) and
          all(isinstance(e, params_dict.ParamsDict) for e in existing)):
        # Same shape: merge each provided value into the existing element.
        for old, new in zip(existing, v):
          old.override(new)
      else:
        # Shapes differ (or existing elements cannot be merged into), so the
        # sequence is rebuilt from `v`; warn unless `v` is simply empty.
        if v:
          logging.warning(
              "The list/tuple don't match the value dictionaries provided. Thus, "
              'the list/tuple is determined by the type annotation and '
              'values provided. This is error-prone.')
        self.__dict__[k] = self._import_config(v, subconfig_type)
    else:
      self.__dict__[k] = self._import_config(v, subconfig_type)

  def __setattr__(self, k, v):
    """Sets attribute `k` through `_set`.

    Raises:
      ValueError: if the config is locked and `k` is not in RESERVED_ATTR.
    """
    # Reserved attributes (e.g. the lock flag itself) bypass the lock check.
    if k not in self.RESERVED_ATTR:
      if getattr(self, '_locked', False):
        raise ValueError('The Config has been locked. No change is allowed.')
    self._set(k, v)

  def _override(self, override_dict, is_strict=True):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods.

    Args:
      override_dict: dictionary of key/value pairs to write.
      is_strict: If True, adding new keys is not allowed.

    Raises:
      KeyError: overriding reserved keys or keys not exist (is_strict=True).
    """
    for k, v in sorted(override_dict.items()):
      if k in self.RESERVED_ATTR:
        raise KeyError('The key {!r} is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key {!r} does not exist in {!r}. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(
                             k, type(self)))
        else:
          self._set(k, v)
      else:
        # Merge nested values into an existing non-empty sub-config so the
        # same strictness applies recursively.
        if isinstance(v, dict) and self.__dict__[k]:
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self._set(k, v)

  def as_dict(self):
    """Returns a dict representation of params_dict.ParamsDict.

    For the nested params_dict.ParamsDict, a nested dict will be returned.
    Reserved attributes are excluded.
    """
    return {
        k: self._export_config(v)
        for k, v in self.__dict__.items()
        if k not in self.RESERVED_ATTR
    }

  def replace(self, **kwargs):
    """Returns an unlocked, overridden copy; the current config is unchanged.

    Raises:
      KeyError: if a keyword is reserved or not an existing key (the override
        is strict).
    """
    # pylint: disable=protected-access
    params = copy.deepcopy(self)
    params._locked = False
    params._override(kwargs, is_strict=True)
    # pylint: enable=protected-access
    return params

  @classmethod
  def from_yaml(cls, file_path: str):
    """Builds a config by overriding the defaults with a YAML file.

    Note: This only works if the Config has all default values, since the
    config is first constructed as `cls()` and then overridden.

    Args:
      file_path: path of the YAML file, opened with `tf.io.gfile.GFile`.

    Returns:
      An instance of `cls`. An empty file yields the default config.
    """
    with tf.io.gfile.GFile(file_path, 'r') as f:
      # NOTE(review): FullLoader can construct some Python-specific objects;
      # only load trusted files. Consider yaml.SafeLoader for untrusted input.
      loaded = yaml.load(f, Loader=yaml.FullLoader)
    config = cls()
    # An empty document loads as None, which `_override` cannot iterate.
    if loaded is not None:
      config.override(loaded)
    return config

  @classmethod
  def from_json(cls, file_path: str):
    """Wrapper for `from_yaml`; the YAML loader also parses JSON documents."""
    return cls.from_yaml(file_path)

  @classmethod
  def from_args(cls, *args, **kwargs):
    """Builds a config from the given list of arguments.

    Positional args are paired, in order, with the names in
    `cls.__annotations__` (extra positional args are dropped by `zip`);
    keyword args are applied afterwards and win on conflicts.
    """
    attributes = list(cls.__annotations__.keys())
    default_params = {a: p for a, p in zip(attributes, args)}
    default_params.update(kwargs)
    return cls(default_params)