# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, List

import oneflow as flow

from libai.utils import distributed as dist


@dataclass
class DistTensorData:
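    """
    A wrapper that carries a local tensor together with the information needed
    to turn it into a global (distributed) tensor:

    * ``tensor``: the local ``flow.Tensor``.
    * ``sbp_list``: string descriptions of the nd-sbp signature, e.g.
      ``"split_0"`` (split along dim 0) or ``"broadcast"``; parsed in
      :meth:`to_global`.
    * ``placement_idx``: index of the pipeline stage whose placement the
      tensor should finally live on.
    """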
    tensor: flow.Tensor
    sbp_list: list = field(default_factory=lambda: ["split_0", "broadcast"])
    placement_idx: int = 0

    # Tensor-like methods
    def to_global(self, sbp=None, placement=None, device_type="cuda"):
        if sbp is not None:
            self.sbp = sbp
        else:
            # Parse the string descriptions in ``sbp_list`` (e.g. "split_0",
            # "broadcast") into flow.sbp objects.
            sbp_list = []
            for sbp_str in self.sbp_list:
                parts = sbp_str.split("_")
                if len(parts) > 1:
                    # "split_<dim>": split the tensor along the given dimension
                    assert parts[0] == "split"
                    split_dim = int(parts[1])
                    sbp_list.append(flow.sbp.split(split_dim))
                else:
                    # a plain sbp name such as "broadcast"
                    sbp_list.append(getattr(flow.sbp, parts[0]))
            self.sbp = dist.get_nd_sbp(sbp_list)

        if placement is not None:
            self.tensor = self.tensor.to_global(sbp=self.sbp, placement=placement)
        else:
            # Convert the local tensor to a global tensor with the default
            # placement when the placement argument is not provided.
            # When pipeline parallel training is enabled, the devices are grouped
            # into several device groups and the model is split into stages, with
            # each stage placed on its corresponding device group.
            # For tensors to be used in a later stage, we first convert them to
            # global tensors that retain only the data on device group 0, and then
            # transfer the result to the target stage.
            # This ensures that all tensors used by the model are generated by the
            # first device group, in case each device group applies random
            # augmentations to the tensors without setting the same global seed.
            main_placement = dist.get_layer_placement(0, device_type)
            self.tensor = self.tensor.to_global(sbp=self.sbp, placement=main_placement)
            if self.placement_idx != 0:
                self.tensor = self.tensor.to_global(
                    placement=dist.get_layer_placement(self.placement_idx, device_type)
                )

    @staticmethod
    def stack(distTensor_lists: List["DistTensorData"]) -> "DistTensorData":
        # Check for emptiness before indexing into the list.
        assert len(distTensor_lists) > 0
        if not isinstance(distTensor_lists[0].tensor, flow.Tensor):
            raise TypeError(
                "DistTensorData.tensor must be a flow.Tensor, but got {}. "
                "Please check the return values of `__getitem__` in the dataset.".format(
                    type(distTensor_lists[0].tensor)
                )
            )

        if len(distTensor_lists) == 1:
            # TODO(l1aoxingyu): add inplace unsqueeze
            # distTensor_lists[0].tensor.unsqueeze_(0)  # add batch dim
            distTensor_lists[0].tensor = distTensor_lists[0].tensor.unsqueeze(0)  # add batch dim
            return distTensor_lists[0]

        tensor_size = distTensor_lists[0].tensor.size()
        sbp_list = distTensor_lists[0].sbp_list
        placement_idx = distTensor_lists[0].placement_idx
        tensors = []
        for data in distTensor_lists:
            assert (
                data.tensor.size() == tensor_size
            ), f"tensor shape is not equal, {data.tensor.size()} != {tensor_size}"
            assert (
                data.sbp_list == sbp_list
            ), f"sbp_list is not equal, {data.sbp_list} != {sbp_list}!"
            assert (
                data.placement_idx == placement_idx
            ), f"placement_idx is not equal, {data.placement_idx} != {placement_idx}"
            tensors.append(data.tensor)
        tensors = flow.stack(tensors, dim=0)
        ret = DistTensorData(tensors, sbp_list=sbp_list, placement_idx=placement_idx)
        return ret
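

# A minimal usage sketch for DistTensorData (hypothetical shapes; assumes the
# global distributed config has been set up so that dist.get_nd_sbp and
# dist.get_layer_placement return sensible values):
#
#   samples = [DistTensorData(flow.randn(8, 16)) for _ in range(4)]
#   batch = DistTensorData.stack(samples)  # local tensor of shape (4, 8, 16)
#   batch.to_global()                      # now a global tensor on the stage-0 placement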


class Instance:
    """
    This class represents an instance with metadata stored as attributes.
    It stores the attributes of an instance (e.g., image, tokens) as "fields".

    All other (non-field) attributes of this class are considered private:
    they must start with '_' and are not modifiable by the user.

    Some basic usage:

    1. Set/get/check a field:

        .. code-block:: python

            instance.tokens = Metadata(...)
            instance.mask = Metadata(...)
            print(instance.tokens)
            print(instance.has("mask")) # True

    2. ``len(instance)`` returns the number of fields
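
    3. Batch a list of instances via ``Instance.stack`` (an illustrative
       sketch; ``inst_a`` and ``inst_b`` are hypothetical instances with
       the same set of fields):

        .. code-block:: python

            batch = Instance.stack([inst_a, inst_b])
            print(len(batch) == len(inst_a))  # True, same set of fields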
    """

    def __init__(self, **kwargs):
        self._fields = OrderedDict()
        for k, v in kwargs.items():
            self.set(k, v)

    def __setattr__(self, name: str, val: Any) -> None:
        if name.startswith("_"):
            super().__setattr__(name, val)
        else:
            self.set(name, val)

    def __getattr__(self, name: str):
        if name == "_fields" or name not in self._fields:
            raise AttributeError(f"Cannot find field '{name}' in the given Instance!")
        return self._fields[name]

    def set(self, name: str, value: Any):
        """
        Set the field named `name` to `value`.
        """
        self._fields[name] = value

    def has(self, name: str):
        return name in self._fields

    def remove(self, name: str):
        del self._fields[name]

    def get(self, name: str):
        return self._fields[name]

    def get_fields(self):
        return self._fields

    def __len__(self):
        return len(self._fields.keys())

    def __iter__(self):
        raise NotImplementedError("`Instance` object is not iterable!")

    @staticmethod
    def stack(instance_lists: List["Instance"]) -> "Instance":
        assert all(isinstance(i, Instance) for i in instance_lists)
        assert len(instance_lists) > 0

        ret = Instance()
        for k in instance_lists[0]._fields.keys():
            values = [i.get(k) for i in instance_lists]
            v0 = values[0]
            if isinstance(v0, flow.Tensor):
                values = flow.stack(values, dim=0)
            elif isinstance(v0, list):
                pass
            elif hasattr(type(v0), "stack"):
                values = type(v0).stack(values)
            else:
                raise ValueError("Unsupported type {} for stack.".format(type(v0)))
            ret.set(k, values)
        return ret

    def __str__(self):
        s = self.__class__.__name__ + "("
        s += "fields=[{}]".format(", ".join((f"{k}: {v}" for k, v in self._fields.items())))
        return s

    __repr__ = __str__
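

# End-to-end sketch of how the two classes compose (hypothetical field names;
# assumes a dataset whose __getitem__ returns an Instance whose fields are
# DistTensorData objects, which Instance.stack dispatches to
# DistTensorData.stack per field):
#
#   sample = Instance(
#       tokens=DistTensorData(flow.zeros(16, dtype=flow.long)),
#       labels=DistTensorData(flow.zeros(16, dtype=flow.long)),
#   )
#   batch = Instance.stack([sample, sample])  # each field stacked along a new batch dim
#   batch.tokens.to_global()                  # convert each field on demand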