# Copyright (c) OpenMMLab. All rights reserved.
import functools
from inspect import getfullargspec
from typing import Callable, Optional, Tuple, Type, Union

import numpy as np
import torch

# Array-like values accepted as templates/inputs by ``ArrayConverter`` and the
# ``array_converter`` decorator: NumPy arrays, PyTorch tensors, Python
# sequences (converted through ``np.array``) and plain Python numbers.
TemplateArrayType = Union[np.ndarray, torch.Tensor, list, tuple, int, float]


def array_converter(to_torch: bool = True,
                    apply_to: Tuple[str, ...] = tuple(),
                    template_arg_name_: Optional[str] = None,
                    recover: bool = True) -> Callable:
    """Wrapper function for data-type agnostic processing.

    First converts input arrays to PyTorch tensors or NumPy arrays for middle
    calculation, then converts the output to the original data type if
    `recover=True`.

    Args:
        to_torch (bool): Whether to convert to PyTorch tensors for middle
            calculation. Defaults to True.
        apply_to (Tuple[str]): The arguments to which we apply data-type
            conversion. Defaults to an empty tuple.
        template_arg_name_ (str, optional): Argument serving as the template
            (return arrays should have the same dtype and device as the
            template). Defaults to None. If None, we will use the first
            argument in `apply_to` as the template argument.
        recover (bool): Whether or not to recover the wrapped function outputs
            to the `template_arg_name_` type. Defaults to True.

    Raises:
        ValueError: When template_arg_name_ is not among all args, or when
            apply_to contains an arg which is not among all args, a ValueError
            will be raised. When the template argument or an argument to
            convert is a list or tuple, and cannot be converted to a NumPy
            array, a ValueError will be raised.
        TypeError: When the type of the template argument or an argument to
            convert does not belong to the above range, or the contents of such
            a list-or-tuple-type argument do not share the same data type, a
            TypeError will be raised. A TypeError is also raised when the
            wrapped function is called with arguments that do not match its
            signature (missing, duplicated or unexpected arguments), exactly
            as calling the undecorated function would.

    Returns:
        Callable: Wrapped function.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>>
        >>> # Use torch addition for a + b,
        >>> # and convert return values to the type of a
        >>> @array_converter(apply_to=('a', 'b'))
        ... def simple_add(a, b):
        ...     return a + b
        >>>
        >>> a = np.array([1.1])
        >>> b = np.array([2.2])
        >>> simple_add(a, b)
        >>>
        >>> # Use numpy addition for a + b,
        >>> # and convert return values to the type of b
        >>> @array_converter(to_torch=False, apply_to=('a', 'b'),
        ...                  template_arg_name_='b')
        ... def simple_add(a, b):
        ...     return a + b
        >>>
        >>> simple_add(a, b)
        >>>
        >>> # Use torch funcs for floor(a) if flag=True else ceil(a),
        >>> # and return the torch tensor
        >>> @array_converter(apply_to=('a',), recover=False)
        ... def floor_or_ceil(a, flag=True):
        ...     return torch.floor(a) if flag else torch.ceil(a)
        >>>
        >>> floor_or_ceil(a, flag=False)
    """
    # Local import: the module header only brings in ``getfullargspec``.
    import inspect

    def array_converter_wrapper(func):
        """Outer wrapper for the function."""
        if apply_to:
            # The signature is fixed per function, so inspect it once at
            # decoration time instead of on every call. ``__wrapped__`` is
            # not followed, matching ``getfullargspec``.
            signature = inspect.signature(func, follow_wrapped=False)
            # Only named parameters may be converted or act as the template;
            # ``*args`` / ``**kwargs`` collectors are excluded.
            named_params = [
                name for name, param in signature.parameters.items()
                if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
            ]
        else:
            signature, named_params = None, []

        @functools.wraps(func)
        def new_func(*args, **kwargs):
            """Inner wrapper for the arguments."""
            if not apply_to:
                return func(*args, **kwargs)

            func_name = getattr(func, '__name__', repr(func))

            # template argument data type is used for all array-like arguments
            if template_arg_name_ is None:
                template_arg_name = apply_to[0]
            else:
                template_arg_name = template_arg_name_

            if template_arg_name not in named_params:
                raise ValueError(f'{template_arg_name} is not among the '
                                 f'argument list of function {func_name}')

            # inspect apply_to
            for arg_to_apply in apply_to:
                if arg_to_apply not in named_params:
                    raise ValueError(
                        f'{arg_to_apply} is not an argument of {func_name}')

            # Bind exactly like a normal call would: raises TypeError for
            # missing/duplicated/unexpected arguments, keeps extra keywords
            # destined for ``**kwargs``, and fills in default values.
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()

            # Raw (unconverted) value, possibly a default, defines the type
            # outputs are recovered to.
            template_arg_value = bound.arguments[template_arg_name]

            converter = ArrayConverter()
            target_type = torch.Tensor if to_torch else np.ndarray

            # convert in parameter order (same order as the call site)
            for name in named_params:
                if name in apply_to:
                    bound.arguments[name] = converter.convert(
                        input_array=bound.arguments[name],
                        target_type=target_type)

            # Validate the template before running ``func`` so an unsupported
            # template fails fast instead of after the computation.
            converter.set_template(template_arg_value)

            return_values = func(*bound.args, **bound.kwargs)

            def recursive_recover(input_data):
                """Recover every array inside nested lists/tuples/dicts."""
                if isinstance(input_data, (tuple, list)):
                    items = [recursive_recover(item) for item in input_data]
                    return tuple(items) if isinstance(input_data,
                                                      tuple) else items
                if isinstance(input_data, dict):
                    return {
                        k: recursive_recover(v)
                        for k, v in input_data.items()
                    }
                if isinstance(input_data, (torch.Tensor, np.ndarray)):
                    return converter.recover(input_data)
                return input_data

            if recover:
                return recursive_recover(return_values)
            return return_values

        return new_func

    return array_converter_wrapper


class ArrayConverter:
    """Utility class for data-type agnostic processing.

    Converts array-like inputs to NumPy arrays or PyTorch tensors, and
    recovers results back to the type / dtype / device of a template array.

    Args:
        template_array (np.ndarray or torch.Tensor or list or tuple or int or
            float, optional): Template array. Defaults to None.
    """

    # Scalar types accepted as plain numbers, and the element dtypes a
    # list/tuple input must convert to.
    SUPPORTED_NON_ARRAY_TYPES = (int, float, np.int8, np.int16, np.int32,
                                 np.int64, np.uint8, np.uint16, np.uint32,
                                 np.uint64, np.float16, np.float32, np.float64)

    def __init__(self,
                 template_array: Optional[TemplateArrayType] = None) -> None:
        if template_array is not None:
            self.set_template(template_array)

    @classmethod
    def _sequence_to_numpy(cls, sequence: Union[list, tuple],
                           error_msg: str) -> np.ndarray:
        """Convert a list/tuple to a NumPy array with a supported dtype.

        Args:
            sequence (list or tuple): Data to convert.
            error_msg (str): Prefix of the error message on failure.

        Raises:
            ValueError: If ``np.array`` cannot build an array from the data.
            TypeError: If ``np.array`` raises TypeError, or the resulting
                dtype is not in ``SUPPORTED_NON_ARRAY_TYPES``.

        Returns:
            np.ndarray: The converted array.
        """
        try:
            array = np.array(sequence)
        except ValueError as err:
            raise ValueError(f'{error_msg}:\n{sequence}') from err
        except TypeError as err:
            raise TypeError(f'{error_msg}:\n{sequence}') from err
        if array.dtype not in cls.SUPPORTED_NON_ARRAY_TYPES:
            raise TypeError(f'{error_msg}:\n{sequence}')
        return array

    @staticmethod
    def _tensor_to_numpy(tensor: torch.Tensor, dtype) -> np.ndarray:
        """Copy a tensor to a CPU NumPy array cast to ``dtype``.

        ``detach()`` is required because ``Tensor.numpy()`` refuses tensors
        that require grad.

        Raises:
            TypeError: If ``tensor`` is not a ``torch.Tensor``.
        """
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(f'Input type {type(tensor)} is not supported.')
        return tensor.detach().cpu().numpy().astype(dtype)

    def set_template(self, array: TemplateArrayType) -> None:
        """Set template array.

        Args:
            array (np.ndarray or torch.Tensor or list or tuple or int or
                float): Template array. NumPy scalars of the supported
                numeric types are accepted like Python numbers.

        Raises:
            ValueError: If input is list or tuple and cannot be converted to a
                NumPy array, a ValueError is raised.
            TypeError: If input type does not belong to the above range, or the
                contents of a list or tuple do not share the same data type, a
                TypeError is raised.
        """
        self.array_type = type(array)
        self.is_num = False
        self.device = 'cpu'

        if isinstance(array, np.ndarray):
            # normalise subclasses so ``recover`` returns plain ndarrays as-is
            self.array_type = np.ndarray
            self.dtype = array.dtype
        elif isinstance(array, torch.Tensor):
            # normalise subclasses (e.g. nn.Parameter) to plain tensors
            self.array_type = torch.Tensor
            self.dtype = array.dtype
            self.device = array.device
        elif isinstance(array, (list, tuple)):
            self.dtype = self._sequence_to_numpy(
                array, 'The following list cannot be converted to a numpy '
                'array of supported dtype').dtype
            # sequences are recovered as NumPy arrays of the sequence dtype
            self.array_type = np.ndarray
        elif isinstance(array, self.SUPPORTED_NON_ARRAY_TYPES):
            self.array_type = np.ndarray
            self.is_num = True
            self.dtype = np.dtype(type(array))
        else:
            raise TypeError(
                f'Template type {self.array_type} is not supported.')

    def convert(
        self,
        input_array: TemplateArrayType,
        target_type: Optional[Type] = None,
        target_array: Optional[Union[np.ndarray, torch.Tensor]] = None
    ) -> Union[np.ndarray, torch.Tensor]:
        """Convert input array to target data type.

        Inputs already of the target kind are returned unchanged. Otherwise,
        with ``target_type`` the result is float32 (tensors on CPU); with
        ``target_array`` the result takes that array's dtype (and, for
        tensors, its device via ``new_tensor``).

        Args:
            input_array (np.ndarray or torch.Tensor or list or tuple or int or
                float): Input array.
            target_type (Type, optional): Type to which input array is
                converted. It should be `np.ndarray` or `torch.Tensor`.
                Defaults to None.
            target_array (np.ndarray or torch.Tensor, optional): Template array
                to which input array is converted. Defaults to None.

        Raises:
            ValueError: If input is list or tuple and cannot be converted to a
                NumPy array, a ValueError is raised.
            TypeError: If input type does not belong to the above range, or the
                contents of a list or tuple do not share the same data type, a
                TypeError is raised.

        Returns:
            np.ndarray or torch.Tensor: The converted array.
        """
        if isinstance(input_array, (list, tuple)):
            input_array = self._sequence_to_numpy(
                input_array,
                'The input cannot be converted to a single-type numpy array')
        elif isinstance(input_array, self.SUPPORTED_NON_ARRAY_TYPES):
            input_array = np.array(input_array)
        assert target_type is not None or target_array is not None, \
            'must specify a target'
        if target_type is not None:
            assert target_type in (np.ndarray, torch.Tensor), \
                'invalid target type'
            if isinstance(input_array, target_type):
                return input_array
            if target_type is np.ndarray:
                # default dtype is float32
                return self._tensor_to_numpy(input_array, np.float32)
            # default dtype is float32, device is 'cpu'
            return torch.tensor(input_array, dtype=torch.float32)

        assert isinstance(target_array, (np.ndarray, torch.Tensor)), \
            'invalid target array type'
        if isinstance(target_array, np.ndarray):
            if isinstance(input_array, np.ndarray):
                return input_array
            return self._tensor_to_numpy(input_array, target_array.dtype)
        if isinstance(input_array, torch.Tensor):
            return input_array
        return target_array.new_tensor(input_array)

    def recover(
        self, input_array: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor, int, float]:
        """Recover input type to original array type.

        Arrays already of the template's array type are returned unchanged;
        otherwise they are converted to the template's dtype (and device, for
        tensors). For number templates, converted results are turned into a
        Python scalar with ``.item()``.

        Args:
            input_array (np.ndarray or torch.Tensor): Input array.

        Returns:
            np.ndarray or torch.Tensor or int or float: Converted array.
        """
        assert isinstance(input_array, (np.ndarray, torch.Tensor)), \
            'invalid input array type'
        if isinstance(input_array, self.array_type):
            return input_array
        if isinstance(input_array, torch.Tensor):
            converted_array = self._tensor_to_numpy(input_array, self.dtype)
        else:
            converted_array = torch.tensor(
                input_array, dtype=self.dtype, device=self.device)
        if self.is_num:
            converted_array = converted_array.item()
        return converted_array