toy.py 2.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
###############################################################################
# A toy example
# -------------
#
# Let’s begin with the simplest graph possible with two nodes, and set
# the node representations:

import torch as th
import dgl

# Build a graph with two nodes (0 and 1) and a single directed edge 1 -> 0.
g = dgl.DGLGraph()
g.add_nodes(2)
g.add_edge(1, 0)

# One feature row per node: node 0 starts at zeros, node 1 holds [1., 2.].
x = th.tensor([[0.0, 0.0], [1.0, 2.0]])
# Store the tensor as node feature 'x' on every node (the `[:]` slice
# selects all nodes).
g.nodes[:].data['x'] = x

###############################################################################
# Syntactic sugar for accessing the feature data of all nodes:
# Prints the (2, 2) node-feature tensor 'x' assigned above.
print(g.ndata['x'])

###############################################################################
# What we want to do is simply copy the representation from node #1 to
# node #0, but through a message-passing interface. We do this much as we
# would over a pair of sockets, with a send and a recv interface. Two
# user-defined functions (UDFs) specify the actions: the message function
# deposits the value into an internal key-value store under the key 'msg',
# and the reduce function retrieves it. Note that a node may have multiple
# incoming edges; the receiving end aggregates their messages.

def send_source(edges):  # edges: dgl.EdgeBatch
    """Message function: forward each edge's source-node feature 'x' as 'msg'."""
    source_feature = edges.src['x']
    return dict(msg=source_feature)

def simple_reduce(nodes):  # nodes: dgl.NodeBatch
    """Reduce function: replace 'x' with the sum of the node's incoming 'msg' values.

    Summation runs over dim 1 of the mailbox tensor; presumably that axis
    indexes the incoming messages of each node (confirm against the DGL docs).
    """
    aggregated = th.sum(nodes.mailbox['msg'], dim=1)
    return dict(x=aggregated)

# Send a message along edge (1, 0) with send_source, then have node 0
# aggregate its mailbox with simple_reduce. Node 0's 'x' should now equal
# node 1's [1., 2.].
# NOTE(review): send/recv look like the older (pre-0.5) DGL message-passing
# API; confirm against the DGL version this tutorial targets.
g.send((1, 0), message_func=send_source)
g.recv(0, reduce_func=simple_reduce)
print(g.ndata)

###############################################################################
# Sometimes the computation may involve representations on the edges.
# Let’s say we want to “amplify” the message:

# Edge feature 'w': one weight value for the graph's single edge (1 -> 0).
w = th.tensor([2.0])
g.edata['w'] = w

def send_source_with_edge_weight(edges):  # edges: dgl.EdgeBatch
    """Message function: scale each edge's source feature 'x' by its weight 'w'.

    Args:
        edges: batch of edges exposing ``src['x']`` (source-node features)
            and ``data['w']`` (per-edge weights).

    Returns:
        ``{'msg': weighted_features}``, one weighted feature row per edge.
    """
    src_feat = edges.src['x']
    weight = edges.data['w']
    # A per-edge weight of shape (E,) would otherwise broadcast against the
    # LAST (feature) axis of src_feat. That is correct only when E == 1 and
    # silently wrong when E equals the feature size. Append trailing
    # singleton dims so each weight scales its own edge's entire feature row.
    missing_dims = src_feat.dim() - weight.dim()
    if missing_dims > 0:
        weight = weight.reshape(tuple(weight.shape) + (1,) * missing_dims)
    return {'msg': src_feat * weight}

# Same edge, now with the weighted message function. simple_reduce replaces
# node 0's 'x' with node 1's x scaled by w, i.e. [2., 4.].
g.send((1, 0), message_func=send_source_with_edge_weight)
g.recv(0, reduce_func=simple_reduce)
print(g.ndata)

###############################################################################
# Or we may need to involve the destination’s representation; here is
# one version:

def simple_reduce_addup(nodes):  # nodes: dgl.NodeBatch
    """Reduce function: add the summed incoming 'msg' values to the node's own 'x'.

    Unlike simple_reduce, the destination's current representation is kept
    and the aggregated messages are added on top of it.
    """
    incoming_total = th.sum(nodes.mailbox['msg'], dim=1)
    current = nodes.data['x']
    return dict(x=current + incoming_total)

# Weighted message again, but reduced with simple_reduce_addup. Node 0 keeps
# its current x ([2., 4.] after the previous step) and adds the incoming
# message (node 1's [1., 2.] scaled by 2, i.e. [2., 4.]), giving [4., 8.].
g.send((1, 0), message_func=send_source_with_edge_weight)
g.recv(0, reduce_func=simple_reduce_addup)
print(g.ndata)

# Drop the node feature 'x' and the edge feature 'w' from the graph.
del g.ndata['x']
del g.edata['w']