"git@developer.sourcefind.cn:change/sglang.git" did not exist on "9fdc6d6abc69b2d976f2a2b1d11eb456d33369be"
array.py 5.39 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
try:
    from collections.abc import MutableMapping
except ImportError:  # Python 2 only provides the ABC in `collections`
    from collections import MutableMapping

import dgl.backend as F

class DGLArray(MutableMapping):
    """Abstract row-indexed array exposing a mutable-mapping interface.

    Apart from ``__init__``, every method here raises NotImplementedError;
    concrete subclasses provide the storage and override them.
    """

    def __init__(self):
        pass

    def __delitem__(self, key, value=None):
        # ``del arr[key]`` calls ``__delitem__(key)`` with a single argument,
        # so ``value`` must be optional; it is kept (unused) only so that any
        # existing two-argument caller keeps working.
        raise NotImplementedError()

    def __getitem__(self, key):
        """Return the selected element(s).

        If the key is a DGLArray of identical length, this performs a
        logical filter: it selects all the elements of this array where the
        corresponding value in the key array evaluates to true.
        If the key is an integer, this returns a single row of the DGLArray.
        If the key is a slice, this returns a DGLArray with the sliced rows.
        """
        raise NotImplementedError()

    def __iter__(self):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()

    def __setitem__(self, key, value):
        raise NotImplementedError()

class DGLDenseArray(DGLArray):
    """Dense, row-indexed array backed by a backend tensor.

    Rows are addressed along the first dimension of ``data``. Alongside the
    data the array keeps ``applicable``, a 1-D boolean tensor with one entry
    per row; ``dropna`` selects rows with ``F.index_by_bool(data, applicable)``.
    """

    def __init__(self, data, applicable=None):
        """
        Parameters
        ----------
        data : list or tensor
            Row data. Only backend tensors are currently supported; a list
            raises NotImplementedError.
        applicable : tensor, optional
            1-D boolean tensor with one entry per row of ``data``, on the
            same device as ``data``. Defaults to all true.

        Raises
        ------
        NotImplementedError
            If ``data`` is a list.
        TypeError
            If ``data`` is neither a list nor a backend tensor.
        """
        if type(data) is list:
            raise NotImplementedError()
        elif isinstance(data, F.Tensor):
            self._data = data
            if applicable is None:
                # TODO: device -- the default mask is created without a device
                # argument, so it may not live on the same device as ``data``.
                self._applicable = F.ones(F.shape(data)[0], dtype=F.bool)
            else:
                assert isinstance(applicable, F.Tensor)
                assert F.device(applicable) == F.device(data)
                self._check_row_mask(applicable)
                self._applicable = applicable
        else:
            # Without this the instance would lack ``_data``/``_applicable``
            # and fail later with an unrelated AttributeError.
            raise TypeError(
                'unsupported data type: %s' % type(data).__name__)

    def _check_row_mask(self, mask):
        """Assert that ``mask`` is a 1-D boolean tensor with one entry per row."""
        mask_shape = F.shape(mask)
        assert len(mask_shape) == 1
        assert mask_shape[0] == F.shape(self._data)[0]
        assert F.isboolean(mask)

    def __getitem__(self, key):
        """Index the array by a boolean mask array or by an integer row.

        If ``key`` is a DGLDenseArray holding a 1-D boolean tensor with one
        entry per row, return a new DGLDenseArray containing the rows where
        the mask is true. If ``key`` is an int, return that row of the
        underlying tensor. Slices are not implemented yet.
        """
        if type(key) is DGLDenseArray:
            if type(key._data) is list:
                raise NotImplementedError()
            elif isinstance(key._data, F.Tensor):
                if isinstance(self._data, F.Tensor):
                    self._check_row_mask(key._data)
                    data = F.index_by_bool(self._data, key._data)
                    return DGLDenseArray(data)
                else:
                    raise NotImplementedError()
            else:
                raise RuntimeError()
        elif type(key) is int:
            return self._data[key]
        elif type(key) is slice:
            raise NotImplementedError()
        else:
            raise RuntimeError()

    def __iter__(self):
        """Iterate over row indices ``0 .. len(self) - 1`` (the mapping keys)."""
        return iter(range(len(self)))

    def __len__(self):
        """Return the number of rows."""
        if isinstance(self._data, F.Tensor):
            return F.shape(self._data)[0]
        elif type(self._data) is list:
            return len(self._data)
        else:
            raise RuntimeError()

    def __setitem__(self, key, value):
        """Replace a row.

        An int ``key`` (negative values count from the end) replaces that row
        with ``value``, a tensor with the same device and dtype as the data.

        Raises
        ------
        IndexError
            If an int ``key`` is out of range.
        NotImplementedError
            For masked assignment (a DGLDenseArray ``key``) and sparse keys.
        """
        if type(key) is int:
            if type(self._data) is list:
                raise NotImplementedError()
            elif isinstance(self._data, F.Tensor):
                assert isinstance(value, F.Tensor)
                assert F.device(value) == F.device(self._data)
                assert F.dtype(value) == F.dtype(self._data)
                # TODO(gaiyu): shape
                n_rows = F.shape(self._data)[0]
                # Normalize negative indices and reject out-of-range ones;
                # otherwise the splice below silently grows the array.
                index = key + n_rows if key < 0 else key
                if not 0 <= index < n_rows:
                    raise IndexError(
                        'index %d is out of bounds for array of length %d'
                        % (key, n_rows))
                # Rebuild the tensor as rows [0, index) + value + rows after
                # ``index``, rather than writing into the existing tensor.
                parts = []
                if index > 0:
                    parts.append(self._data[:index])
                parts.append(F.expand_dims(value, 0))
                if index < n_rows - 1:
                    parts.append(self._data[index + 1:])
                self._data = F.concatenate(parts)
            else:
                raise RuntimeError()
        elif type(key) is DGLDenseArray:
            self._check_row_mask(key._data)
            # Masked assignment is not implemented yet; raise rather than
            # silently dropping the write.
            raise NotImplementedError()
        elif type(key) is DGLSparseArray:
            raise NotImplementedError()
        else:
            raise RuntimeError()

    def _listize(self):
        raise NotImplementedError()

    def _tensorize(self):
        raise NotImplementedError()

    def append(self, other):
        """Return a new array holding the rows of ``self`` followed by ``other``.

        If either operand has no rows, the other operand is returned as is.
        The trailing (non-row) dimensions of both operands must match.

        NOTE(review): the concatenated result gets a fresh all-true
        ``applicable`` mask; the operands' masks are not carried over.
        """
        assert isinstance(other, DGLDenseArray)
        if len(self) == 0:
            return other
        elif len(other) == 0:
            return self
        self_shape = F.shape(self._data)
        other_shape = F.shape(other._data)
        assert tuple(self_shape[1:]) == tuple(other_shape[1:])
        data = F.concatenate([self._data, other._data])
        return DGLDenseArray(data)

    @property
    def applicable(self):
        """1-D boolean tensor with one flag per row."""
        return self._applicable

    @property
    def data(self):
        """The underlying row data."""
        return self._data

    def dropna(self):
        """Return a new array with the rows selected by ``applicable``.

        The result gets a fresh all-true ``applicable`` mask.
        """
        if type(self._data) is list:
            raise NotImplementedError()
        elif isinstance(self._data, F.Tensor):
            data = F.index_by_bool(self._data, self._applicable)
            return DGLDenseArray(data)
        else:
            raise RuntimeError()

class DGLSparseArray(DGLArray):
    """Placeholder for a sparse DGLArray; construction is not supported yet."""

    def __init__(self):
        # No sparse storage exists yet, so refuse to build an instance.
        raise NotImplementedError