struct_of_array.py 7.54 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class decorator to represent (nested) struct of arrays."""

import dataclasses

import jax


def get_item(instance, key):
  sliced = {}
  for field in get_array_fields(instance):
    num_trailing_dims = field.metadata.get('num_trailing_dims', 0)
    this_key = key
    if isinstance(key, tuple) and Ellipsis in this_key:
      this_key += (slice(None),) * num_trailing_dims
    sliced[field.name] = getattr(instance, field.name)[this_key]
  return dataclasses.replace(instance, **sliced)


@property
def get_shape(instance):
  """Returns Shape for given instance of dataclass."""
  first_field = dataclasses.fields(instance)[0]
  num_trailing_dims = first_field.metadata.get('num_trailing_dims', None)
  value = getattr(instance, first_field.name)
  if num_trailing_dims:
    return value.shape[:-num_trailing_dims]
  else:
    return value.shape


def get_len(instance):
  """Returns length for given instance of dataclass."""
  shape = instance.shape
  if shape:
    return shape[0]
  else:
    raise TypeError('len() of unsized object')  # Match jax.numpy behavior.


@property
def get_dtype(instance):
  """Returns Dtype for given instance of dataclass."""
  fields = dataclasses.fields(instance)
  sets_dtype = [
      field.name for field in fields if field.metadata.get('sets_dtype', False)
  ]
  if sets_dtype:
    assert len(sets_dtype) == 1, 'at most field can set dtype'
    field_value = getattr(instance, sets_dtype[0])
  elif instance.same_dtype:
    field_value = getattr(instance, fields[0].name)
  else:
    # Should this be Value Error?
    raise AttributeError('Trying to access Dtype on Struct of Array without'
                         'either "same_dtype" or field setting dtype')

  if hasattr(field_value, 'dtype'):
    return field_value.dtype
  else:
    # Should this be Value Error?
    raise AttributeError(f'field_value {field_value} does not have dtype')


def replace(instance, **kwargs):
  return dataclasses.replace(instance, **kwargs)


def post_init(instance):
  """Validate instance has same shapes & dtypes."""
  array_fields = get_array_fields(instance)
  arrays = list(get_array_fields(instance, return_values=True).values())
  first_field = array_fields[0]
  # These slightly weird constructions about checking whether the leaves are
  # actual arrays is since e.g. vmap internally relies on being able to
  # construct pytree's with object() as leaves, this would break the checking
  # as such we are only validating the object when the entries in the dataclass
  # Are arrays or other dataclasses of arrays.
  try:
    dtype = instance.dtype
  except AttributeError:
    dtype = None
  if dtype is not None:
    first_shape = instance.shape
    for array, field in zip(arrays, array_fields):
      field_shape = array.shape
      num_trailing_dims = field.metadata.get('num_trailing_dims', None)
      if num_trailing_dims:
        array_shape = array.shape
        field_shape = array_shape[:-num_trailing_dims]
        msg = (f'field {field} should have number of trailing dims'
               ' {num_trailing_dims}')
        assert len(array_shape) == len(first_shape) + num_trailing_dims, msg
      else:
        field_shape = array.shape

      shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't "
                   f"match shape {first_shape} of field {first_field}")
      assert field_shape == first_shape, shape_msg

      field_dtype = array.dtype

      allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', [])
      if allowed_metadata_dtypes:
        msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}'
        assert field_dtype in allowed_metadata_dtypes, msg

      if 'dtype' in field.metadata:
        target_dtype = field.metadata['dtype']
      else:
        target_dtype = dtype

      msg = f'Dtype is {field_dtype} but must be {target_dtype}'
      assert field_dtype == target_dtype, msg


def flatten(instance):
  """Flatten Struct of Array instance."""
  array_likes = list(get_array_fields(instance, return_values=True).values())
  flat_array_likes = []
  inner_treedefs = []
  num_arrays = []
  for array_like in array_likes:
    flat_array_like, inner_treedef = jax.tree_flatten(array_like)
    inner_treedefs.append(inner_treedef)
    flat_array_likes += flat_array_like
    num_arrays.append(len(flat_array_like))
  metadata = get_metadata_fields(instance, return_values=True)
  metadata = type(instance).metadata_cls(**metadata)
  return flat_array_likes, (inner_treedefs, metadata, num_arrays)


def make_metadata_class(cls):
  metadata_fields = get_fields(cls,
                               lambda x: x.metadata.get('is_metadata', False))
  metadata_cls = dataclasses.make_dataclass(
      cls_name='Meta' + cls.__name__,
      fields=[(field.name, field.type, field) for field in metadata_fields],
      frozen=True,
      eq=True)
  return metadata_cls


def get_fields(cls_or_instance, filterfn, return_values=False):
  fields = dataclasses.fields(cls_or_instance)
  fields = [field for field in fields if filterfn(field)]
  if return_values:
    return {
        field.name: getattr(cls_or_instance, field.name) for field in fields
    }
  else:
    return fields


def get_array_fields(cls, return_values=False):
  return get_fields(
      cls,
      lambda x: not x.metadata.get('is_metadata', False),
      return_values=return_values)


def get_metadata_fields(cls, return_values=False):
  return get_fields(
      cls,
      lambda x: x.metadata.get('is_metadata', False),
      return_values=return_values)


class StructOfArray:
  """Class Decorator for Struct Of Arrays."""

  def __init__(self, same_dtype=True):
    self.same_dtype = same_dtype

  def __call__(self, cls):
    cls.__array_ufunc__ = None
    cls.replace = replace
    cls.same_dtype = self.same_dtype
    cls.dtype = get_dtype
    cls.shape = get_shape
    cls.__len__ = get_len
    cls.__getitem__ = get_item
    cls.__post_init__ = post_init
    new_cls = dataclasses.dataclass(cls, frozen=True, eq=False)  # pytype: disable=wrong-keyword-args
    # pytree claims to require metadata to be hashable, not sure why,
    # But making derived dataclass that can just hold metadata
    new_cls.metadata_cls = make_metadata_class(new_cls)

    def unflatten(aux, data):
      inner_treedefs, metadata, num_arrays = aux
      array_fields = [field.name for field in get_array_fields(new_cls)]
      value_dict = {}
      array_start = 0
      for num_array, inner_treedef, array_field in zip(num_arrays,
                                                       inner_treedefs,
                                                       array_fields):
        value_dict[array_field] = jax.tree_unflatten(
            inner_treedef, data[array_start:array_start + num_array])
        array_start += num_array
      metadata_fields = get_metadata_fields(new_cls)
      for field in metadata_fields:
        value_dict[field.name] = getattr(metadata, field.name)

      return new_cls(**value_dict)

    jax.tree_util.register_pytree_node(
        nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten)
    return new_cls