".github/workflows/test-linux-cpu.yml" did not exist on "0ed5d81196868c0f87b7e3e89d28c998b759fc60"
dense_aggregate.py 2.21 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from dgl.array import DGLArray, DGLDenseArray, DGLSparseArray
import dgl.backend as F

def _gridize(frame, key_column_names, src_column_name):
    if type(key_column_names) is str:
        key_column = frame[key_column_names]
        assert F.prod(key_column.applicable)
        if type(key_column) is DGLDenseArray:
            row = key_column.data
            if type(row) is F.Tensor:
                assert F.isinteger(row) and len(F.shape(row)) == 1
                col = F.unique(row)
                xy = (F.expand_dims(row, 1) == F.expand_dims(col, 0))
                if src_column_name:
                    src_column = frame[src_column_name]
                    if type(src_column) is DGLDenseArray:
                        z = src_column.data
                        if type(z) is F.Tensor:
                            z = F.expand_dims(z, 1)
                            for i in range(2, len(F.shape(z))):
                                xy = F.expand_dims(xy, i)
                            xy = F.astype(xy, F.dtype(z))
                            return col, xy * z
                        elif type(z) is list:
                            raise NotImplementedError()
                        else:
                            raise RuntimeError()
                else:
                    return col, xy
            elif type(row) is list:
                raise NotImplementedError()
            else:
                raise RuntimeError()
        else:
            raise NotImplementedError()
    elif type(key_column_names) is list:
        raise NotImplementedError()
    else:
        raise RuntimeError()

def aggregator(src_column_name=''):
    """Create a decorator that lifts a grid reducer into a group-by aggregate.

    The decorated reducer ``a`` receives the grid produced by ``_gridize``
    and returns the aggregated values.  The resulting function takes
    ``(frame, key_column_names)`` and returns a dict with two entries: the
    unique keys under ``key_column_names`` and the aggregated values under
    ``src_column_name + a.__name__``, both wrapped in ``DGLDenseArray``.

    Args:
        src_column_name: Name of the column to aggregate; the default ``''``
            makes ``_gridize`` return the bare membership grid.
    """
    def decorator(a):
        def decorated(frame, key_column_names):
            unique_keys, grid = _gridize(
                frame, key_column_names, src_column_name)
            # The reducer's own name is the suffix of the output column.
            out_name = src_column_name + a.__name__
            result = {key_column_names: DGLDenseArray(unique_keys)}
            result[out_name] = DGLDenseArray(a(grid))
            return result

        return decorated

    return decorator

def COUNT():
    """Return an aggregate that sums the key-membership grid over axis 0.

    The returned function takes ``(frame, key_column_names)``; its output
    column is named ``'count'`` (empty source name + the reducer's name).
    """
    def count(x):
        # No source column is passed to `aggregator`, so `x` is the bare
        # membership grid; reduce it along axis 0.
        return F.sum(x, 0)

    return aggregator()(count)

def SUM(src_column_name):
    """Return an aggregate that sums ``src_column_name`` over axis 0 per key.

    The returned function takes ``(frame, key_column_names)``; its output
    column is named ``src_column_name + 'sum'``.

    Args:
        src_column_name: Name of the column whose values are summed.
    """
    # The reducer must be called `sum`: `aggregator` uses its `__name__` as
    # the output column suffix.  It shadows the builtin only inside SUM.
    def sum(x):
        return F.sum(x, 0)

    make_aggregate = aggregator(src_column_name)
    return make_aggregate(sum)