# einops_exts.py
# From https://github.com/arogozhnikov/einops/blob/master/einops/layers/oneflow.py

import re
from functools import wraps

import oneflow as flow
from einops import rearrange, reduce, repeat
from einops._backends import AbstractBackend
from einops.layers import RearrangeMixin
from oneflow import nn


class Rearrange(RearrangeMixin, nn.Module):
    def forward(self, input):
        return self._apply_recipe(input)
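
# A minimal usage sketch (layer and pattern are illustrative): Rearrange acts
# as a drop-in nn.Module, e.g.
#
#   model = nn.Sequential(
#       nn.Conv2d(3, 16, kernel_size=3, padding=1),
#       Rearrange("b c h w -> b (c h w)"),
#   )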


class OneFlowBackend(AbstractBackend):
    framework_name = "oneflow"

    def __init__(self):
        import oneflow as flow

        self.flow = flow

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.flow.Tensor)

    def from_numpy(self, x):
        variable = self.flow.from_numpy(x)
        if self.is_float_type(variable):
            # attach grad only to floating types
            variable.requires_grad = True
        return variable

    def to_numpy(self, x):
        return x.detach().cpu().numpy()

    def arange(self, start, stop):
        return self.flow.arange(start, stop, dtype=self.flow.int64)

    def reduce(self, x, operation, reduced_axes):
        for axis in sorted(reduced_axes, reverse=True):
            if operation == "min":
                x, _ = x.min(dim=axis)
            elif operation == "max":
                x, _ = x.max(dim=axis)
            elif operation in ["sum", "mean", "prod"]:
                x = getattr(x, operation)(dim=axis)
            else:
                raise NotImplementedError(f"Unknown reduction: {operation}")
        return x

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.flow.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(*repeats)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def add_axis(self, x, new_position):
        return self.flow.unsqueeze(x, new_position)

    def is_float_type(self, x):
        return x.dtype in [self.flow.float16, self.flow.float32, self.flow.float64]

    def einsum(self, pattern, *x):
        return self.flow.einsum(pattern, *x)
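
# einops discovers backends by scanning AbstractBackend subclasses at dispatch
# time, so importing this module should be enough for rearrange / reduce /
# repeat to accept flow.Tensor inputs, e.g. (illustrative):
#
#   x = flow.randn(2, 3, 4)
#   y = rearrange(x, "b c d -> b (c d)")  # -> shape (2, 12)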


# From https://github.com/lucidrains/einops-exts/tree/main/einops_exts


class EinopsToAndFrom(nn.Module):
    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x, **kwargs):
        shape = x.shape
        # map each axis name in the source pattern to its size so the reverse
        # rearrange below can reconstitute the original shape
        reconstitute_kwargs = dict(zip(self.from_einops.split(" "), shape))
        x = rearrange(x, f"{self.from_einops} -> {self.to_einops}")
        x = self.fn(x, **kwargs)
        x = rearrange(x, f"{self.to_einops} -> {self.from_einops}", **reconstitute_kwargs)
        return x
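
# Usage sketch (Attention here is a hypothetical module expecting
# (batch, seq, dim) inputs): run fn on a rearranged view, then restore the
# original layout afterwards:
#
#   attn = EinopsToAndFrom("b c h w", "b (h w) c", Attention(dim=256))
#   out = attn(feature_map)  # out has the same shape as feature_map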


# checking shape
# @nils-werner
# https://github.com/arogozhnikov/einops/issues/168#issuecomment-1042933838


def check_shape(tensor, pattern, **kwargs):
    return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)
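
# e.g. (illustrative): check_shape(t, "b c h w", c=3) raises if t is not
# rank-4 with c == 3, and otherwise returns t unchanged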


# apply the same einops operation to each tensor in a list (returns a generator)


def _many(fn):
    @wraps(fn)
    def inner(tensors, pattern, **kwargs):
        return (fn(tensor, pattern, **kwargs) for tensor in tensors)

    return inner


# do einops operations with unflattening of anonymously named dimensions:
# "...name" in a pattern expands to "name0 name1 ...", with the sizes taken
# from the dimension list passed as keyword argument `name`


def _with_anon_dims(fn):
    @wraps(fn)
    def inner(tensor, pattern, **kwargs):
        regex = r"(\.\.\.[a-zA-Z]+)"
        matches = re.findall(regex, pattern)

        def get_anon_dim_name(t):
            return t.lstrip("...")

        dim_prefixes = tuple(map(get_anon_dim_name, set(matches)))

        update_kwargs_dict = dict()

        for prefix in dim_prefixes:
            assert prefix in kwargs, f'dimension list "{prefix}" was not passed in'
            dim_list = kwargs[prefix]
            assert isinstance(
                dim_list, (list, tuple)
            ), f'dimension list "{prefix}" needs to be a tuple or list of dimensions'
            dim_names = list(map(lambda ind: f"{prefix}{ind}", range(len(dim_list))))
            update_kwargs_dict[prefix] = dict(zip(dim_names, dim_list))

        def sub_with_anonymous_dims(t):
            dim_name_prefix = get_anon_dim_name(t.groups()[0])
            return " ".join(update_kwargs_dict[dim_name_prefix].keys())

        pattern_new = re.sub(regex, sub_with_anonymous_dims, pattern)

        for prefix, update_dict in update_kwargs_dict.items():
            del kwargs[prefix]
            kwargs.update(update_dict)

        return fn(tensor, pattern_new, **kwargs)

    return inner


# generate all helper functions

rearrange_many = _many(rearrange)
repeat_many = _many(repeat)
reduce_many = _many(reduce)
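
# e.g. (illustrative): the *_many helpers return a generator, so unpack it:
#
#   q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=8)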

rearrange_with_anon_dims = _with_anon_dims(rearrange)
repeat_with_anon_dims = _with_anon_dims(repeat)
reduce_with_anon_dims = _with_anon_dims(reduce)
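
# e.g. (illustrative): "...rest" expands to "rest0 rest1 rest2", with sizes
# taken from the `rest` keyword:
#
#   x = flow.randn(2, 3, 4, 5, 6)
#   y = rearrange_with_anon_dims(x, "b ...rest d -> b (...rest) d", rest=(3, 4, 5))
#   # y.shape == (2, 60, 6)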