"""Built-in reducer function."""
from __future__ import absolute_import

import dgl.backend as F

__all__ = ["ReduceFunction", "sum", "max"]

class ReduceFunction(object):
    """Base class for reduce functions.

    A reduce function aggregates the messages received by a node.
    Subclasses must override :meth:`__call__`, :meth:`name` and
    :meth:`is_spmv_supported`; the base implementations only raise
    ``NotImplementedError``.
    """

    def __call__(self, node, msgs):
        """Reduce ``msgs`` for ``node``. Must be overridden."""
        raise NotImplementedError

    def name(self):
        """Return the name of this reducer. Must be overridden."""
        raise NotImplementedError

    def is_spmv_supported(self):
        """Return whether this reducer supports the SpMV execution path.

        Presumably "SpMV" means lowering to a sparse matrix-vector product
        (TODO confirm against the scheduler). Must be overridden.
        """
        raise NotImplementedError


class BundledReduceFunction(ReduceFunction):
    """Apply several reduce functions to the same messages and merge results.

    Parameters
    ----------
    fn_list : callable or list/tuple of callables
        A single reducer, or a sequence of reducers. Each one is called as
        ``fn(node, msgs)``. When there is more than one, every result must
        be a dict so the results can be merged with ``dict.update``.

    Raises
    ------
    RuntimeError
        From ``__init__``, if more than one function is given and a
        :class:`ReduceFunction` among them has ``out_field`` set to None.
        From ``__call__``, if the partial results cannot be merged.
    """

    def __init__(self, fn_list):
        if not isinstance(fn_list, (list, tuple)):
            fn_list = [fn_list]
        elif len(fn_list) > 1:
            # Results are merged by key, so with several reducers each
            # built-in one needs an explicit output field. A one-element
            # list is not ambiguous and behaves like a bare function.
            for fn in fn_list:
                if isinstance(fn, ReduceFunction) and fn.out_field is None:
                    raise RuntimeError("Not specifying out field for multiple reduce is ambiguous")
        self.fn_list = fn_list

    def is_spmv_supported(self):
        """Return True only if every bundled function is a ReduceFunction
        that itself supports SpMV."""
        return all(isinstance(fn, ReduceFunction) and fn.is_spmv_supported()
                   for fn in self.fn_list)

    def __call__(self, node, msgs):
        """Call every bundled function and merge their results.

        The first function's result is returned, updated in place with the
        later results. With an empty ``fn_list`` this returns None.
        """
        ret = None
        for fn in self.fn_list:
            rpr = fn(node, msgs)
            if ret is None:
                ret = rpr
                continue
            try:
                # ret and rpr must both be dicts keyed by output field.
                ret.update(rpr)
            except (AttributeError, TypeError, ValueError):
                # AttributeError: ret is not a dict (no ``update``).
                # TypeError/ValueError: rpr is not a mapping of fields.
                raise RuntimeError("Must specify out field for multiple reduce")
        return ret

    def name(self):
        """Return the fixed name ``"bundled"``."""
        return "bundled"


class ReducerFunctionTemplate(ReduceFunction):
    """Built-in reducer that aggregates messages with a fixed operator.

    The calling convention is chosen in :meth:`__call__` by the type of
    ``msgs``:

    * a Python ``list`` of messages is reduced with ``nonbatch_op``;
    * anything else is treated as batched messages and reduced with
      ``batch_op(msgs, 1)``.

    Parameters
    ----------
    name : str
        Reducer name, e.g. ``"sum"`` or ``"max"``; exposed as :attr:`name`.
    batch_op : callable
        Called as ``batch_op(batched_msgs, 1)``.
    nonbatch_op : callable
        Called with a Python list of messages.
    msg_field : optional
        If not None, each message (list case) or the batched container
        (batched case) is indexed with this key before reducing.
    out_field : optional
        If not None, the result is returned as ``{out_field: result}``;
        otherwise the raw result is returned.
    """

    def __init__(self, name, batch_op, nonbatch_op, msg_field=None, out_field=None):
        # Stored privately: assigning ``self.name`` directly would shadow the
        # ``name`` accessor with a plain string on the instance.
        self._name = name
        self.batch_op = batch_op
        self.nonbatch_op = nonbatch_op
        self.msg_field = msg_field
        self.out_field = out_field

    @property
    def name(self):
        """Name of this reducer, as passed to the constructor."""
        return self._name

    def is_spmv_supported(self):
        """Return True only for the ``"sum"`` reducer."""
        # TODO: support max
        return self._name == "sum"

    def __call__(self, node, msgs):
        """Reduce ``msgs``. ``node`` is accepted for interface compatibility
        but not used."""
        if isinstance(msgs, list):
            # Non-batched: a plain list with one entry per message.
            if self.msg_field is not None:
                msgs = [msg[self.msg_field] for msg in msgs]
            ret = self.nonbatch_op(msgs)
        else:
            # Batched: reduce along dim 1, presumably the per-node message
            # axis. TODO confirm against the message batching code.
            if self.msg_field is not None:
                msgs = msgs[self.msg_field]
            ret = self.batch_op(msgs, 1)
        if self.out_field is None:
            return ret
        return {self.out_field: ret}


# Keep a handle on the builtin before the ``sum`` defined below shadows it.
_python_sum = sum
def sum(msgs=None, out=None):
    """Create a built-in reducer that sums messages.

    Batched messages are reduced with ``F.sum`` and message lists with the
    builtin ``sum``.

    Parameters
    ----------
    msgs : optional
        Message field to reduce; None means the messages are used as-is.
    out : optional
        Output field; None means the reduced value is returned unwrapped.
    """
    return ReducerFunctionTemplate(
        "sum", F.sum, _python_sum, msg_field=msgs, out_field=out)


# Keep a handle on the builtin before the ``max`` defined below shadows it.
_python_max = max
def max(msgs=None, out=None):
    """Create a built-in reducer that takes the maximum of messages.

    Batched messages are reduced with ``F.max`` and message lists with the
    builtin ``max``.

    Parameters
    ----------
    msgs : optional
        Message field to reduce; None means the messages are used as-is.
    out : optional
        Output field; None means the reduced value is returned unwrapped.
    """
    return ReducerFunctionTemplate(
        "max", F.max, _python_max, msg_field=msgs, out_field=out)