# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serialize and deserialize trees."""

import dataclasses
import io
import types
from typing import Any, BinaryIO, Optional, TypeVar

import numpy as np

_T = TypeVar("_T")


def dump(dest: BinaryIO, value: Any) -> None:
  """Dump a tree of dicts/dataclasses to a file object.

  Args:
    dest: a file object to write to.
    value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and
      other basic types. Unions are not supported, other than Optional/None,
      which is only supported in dataclasses, not in dicts, lists or tuples.
      All leaves must be coercible to a numpy array, and recoverable by
      passing that array as the single argument to the leaf's type.
  """
  buffer = io.BytesIO()  # In case the destination doesn't support seeking.
  np.savez(buffer, **_flatten(value))
  dest.write(buffer.getvalue())
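
# A usage sketch for `dump` (illustrative only; `ExampleState` is a
# hypothetical user-defined dataclass, not part of this module):
#
#   @dataclasses.dataclass
#   class ExampleState:
#     weights: np.ndarray
#     step: int
#     note: Optional[str] = None
#
#   state = ExampleState(weights=np.zeros((2, 3)), step=100)
#   with open("/tmp/checkpoint.npz", "wb") as f:
#     dump(f, state)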


def load(source: BinaryIO, typ: type[_T]) -> _T:
  """Load from a file object and convert it to the specified type.

  Args:
    source: a file object to read from.
    typ: a type object that acts as a schema for deserialization. It must match
      what was serialized. If a type is Any, the value is returned exactly as
      numpy deserialized it, which is what you want for a tree of numpy arrays.

  Returns:
    the deserialized value as the specified type.
  """
  return _convert_types(typ, _unflatten(np.load(source)))
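
# Continuing the hypothetical `ExampleState` sketch above, loading takes the
# same dataclass as the schema; fields that were None at dump time (and were
# therefore omitted from the archive) come back as None:
#
#   with open("/tmp/checkpoint.npz", "rb") as f:
#     restored = load(f, ExampleState)
#   # restored.step == 100 and restored.note is None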


_SEP = ":"


def _flatten(tree: Any) -> dict[str, Any]:
  """Flatten a tree of dicts/dataclasses/lists/tuples to a single dict."""
  if dataclasses.is_dataclass(tree):
    # Don't use dataclasses.asdict: it recurses into nested dataclasses
    # itself, which would bypass the per-level dropping of None fields below.
    tree = {f.name: v for f in dataclasses.fields(tree)
            if (v := getattr(tree, f.name)) is not None}
  elif isinstance(tree, (list, tuple)):
    tree = dict(enumerate(tree))

  assert isinstance(tree, dict)

  flat = {}
  for k, v in tree.items():
    k = str(k)
    assert _SEP not in k
    if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)):
      for a, b in _flatten(v).items():
        flat[f"{k}{_SEP}{a}"] = b
    else:
      assert v is not None
      flat[k] = v
  return flat
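
# For example (values shown informally; list/tuple indices become string keys):
#   _flatten({"a": {"b": np.zeros(3), "c": 1}, "d": [4, 5]})
#   == {"a:b": np.zeros(3), "a:c": 1, "d:0": 4, "d:1": 5}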


def _unflatten(flat: dict[str, Any]) -> dict[str, Any]:
  """Unflatten a dict to a tree of dicts."""
  tree = {}
  for flat_key, v in flat.items():
    node = tree
    keys = flat_key.split(_SEP)
    for k in keys[:-1]:
      if k not in node:
        node[k] = {}
      node = node[k]
    node[keys[-1]] = v
  return tree
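
# For example, numeric (list/tuple) indices remain string-keyed dicts here;
# _convert_types below turns them back into lists/tuples if the schema asks:
#   _unflatten({"a:b": 1, "a:c": 2, "d:0": 4, "d:1": 5})
#   == {"a": {"b": 1, "c": 2}, "d": {"0": 4, "1": 5}}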


def _convert_types(typ: type[_T], value: Any) -> _T:
  """Convert some structure into the given type. The structures must match."""
  if typ in (Any, ...):
    return value

  if typ in (int, float, str, bool):
    return typ(value)

  if typ is np.ndarray:
    assert isinstance(value, np.ndarray)
    return value

  if dataclasses.is_dataclass(typ):
    kwargs = {}
    for f in dataclasses.fields(typ):
      # Only support Optional for dataclasses, as numpy can't serialize None
      # directly (without pickle), and dataclasses are the only case where we
      # know the full set of field names and types, so a missing key can only
      # mean None.
      if isinstance(f.type, (types.UnionType, type(Optional[int]))):
        constructors = [t for t in f.type.__args__ if t is not types.NoneType]
        if len(constructors) != 1:
          raise TypeError(
              "Optional is supported; a Union with anything other than None "
              "is not.")
        if f.name not in value:
          kwargs[f.name] = None
          continue
        constructor = constructors[0]
      else:
        constructor = f.type

      if f.name in value:
        kwargs[f.name] = _convert_types(constructor, value[f.name])
      else:
        raise ValueError(f"Missing value: {f.name}")
    return typ(**kwargs)

  base_type = getattr(typ, "__origin__", None)

  if base_type is dict:
    assert len(typ.__args__) == 2
    key_type, value_type = typ.__args__
    return {_convert_types(key_type, k): _convert_types(value_type, v)
            for k, v in value.items()}

  if base_type is list:
    assert len(typ.__args__) == 1
    value_type = typ.__args__[0]
    return [_convert_types(value_type, v)
            for _, v in sorted(value.items(), key=lambda x: int(x[0]))]

  if base_type is tuple:
    if len(typ.__args__) == 2 and typ.__args__[1] == ...:
      # An arbitrary length tuple of a single type, eg: tuple[int, ...]
      value_type = typ.__args__[0]
      return tuple(_convert_types(value_type, v)
                   for _, v in sorted(value.items(), key=lambda x: int(x[0])))
    else:
      # A fixed length tuple of arbitrary types, eg: tuple[int, str, float]
      assert len(typ.__args__) == len(value)
      return tuple(
          _convert_types(t, v)
          for t, (_, v) in zip(
              typ.__args__, sorted(value.items(), key=lambda x: int(x[0]))))

  # This is probably unreachable with reasonable serializable inputs.
  try:
    return typ(value)
  except TypeError as e:
    raise TypeError(
        "_convert_types expects the type argument to be a dataclass defined "
        "with types that are valid constructors (eg tuple is fine, Tuple "
        "isn't), and accept a numpy array as the sole argument.") from e