Commit 68fb5f7e authored by Lingfan Yu, committed by Minjie Wang

Reduce API (#15)

* add reduce_msg related api to dgl graph

* add reduce_sum, switch backend from numpy to pytorch

* update gat gcn to use reduce msg api

* remove reduce_sum

* add built-in reduce functions
parent b9073209
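In this commit, message handling is split into three registered stages: a message function produces per-edge messages, a reduce function combines a node's incoming messages into a single reduced representation (built-in 'sum'/'max' or any callable), and the update function now receives that reduced representation instead of a raw message list. A minimal sketch of the wiring, assuming a DGLGraph g whose nodes already carry an 'h' tensor (graph construction, feature loading, and the update callable are hypothetical here):

# hypothetical update callable: signature (node, reduced messages) -> new node repr
def update_func(node, msgs_reduced_repr):
    return {'h': node['h'] + msgs_reduced_repr}

g.register_message_func(lambda src, dst, edge: src['h'])  # message = sender's 'h'
g.register_reduce_func('sum')        # built-in reducer; 'max' or a callable also work
g.register_update_func(update_func)
g.update_all()                       # send, reduce per node, then update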
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
"""
import networkx as nx
from dgl.graph import DGLGraph
import torch
@@ -7,44 +13,33 @@ import argparse
from dataset import load_data, preprocess_features
import numpy as np
class NodeReduceModule(nn.Module):
    def __init__(self, input_dim, num_hidden, num_heads=3, input_dropout=None,
                 attention_dropout=None):
        super(NodeReduceModule, self).__init__()
        self.num_heads = num_heads
        self.input_dropout = input_dropout
        self.attention_dropout = attention_dropout
        self.fc = nn.ModuleList(
            [nn.Linear(input_dim, num_hidden, bias=False)
             for _ in range(num_heads)])
        self.attention = nn.ModuleList(
            [nn.Linear(num_hidden * 2, 1, bias=False) for _ in range(num_heads)])

    def forward(self, msgs):
        src, dst = zip(*msgs)
        hu = torch.cat(src, dim=0)  # neighbor repr
        hv = torch.cat(dst, dim=0)
        msgs_repr = []
        # iterate for each head
        for i in range(self.num_heads):
            # calc W*hself and W*hneigh
            hvv = self.fc[i](hv)
            huu = self.fc[i](hu)
            # calculate W*hself||W*hneigh
            h = torch.cat((hvv, huu), dim=1)
            a = F.leaky_relu(self.attention[i](h))
            a = F.softmax(a, dim=0)
            if self.attention_dropout is not None:
@@ -52,25 +47,41 @@ class NodeUpdateModule(nn.Module):
            if self.input_dropout is not None:
                hvv = F.dropout(hvv, self.input_dropout)
            h = torch.sum(a * hvv, 0, keepdim=True)
            msgs_repr.append(h)
        return msgs_repr


class NodeUpdateModule(nn.Module):
    def __init__(self, residual, fc, act, aggregator):
        super(NodeUpdateModule, self).__init__()
        self.residual = residual
        self.fc = fc
        self.act = act
        self.aggregator = aggregator

    def forward(self, node, msgs_repr):
        # apply residual connection and activation for each head
        for i in range(len(msgs_repr)):
            if self.residual:
                h = self.fc[i](node['h'])
                msgs_repr[i] = msgs_repr[i] + h
            if self.act is not None:
                msgs_repr[i] = self.act(msgs_repr[i])
        # aggregate multi-head results
        h = self.aggregator(msgs_repr)
        return {'h': h}
class GAT(nn.Module):
    def __init__(self, num_layers, in_dim, num_hidden, num_classes, num_heads,
                 activation, input_dropout, attention_dropout, use_residual=False):
        super(GAT, self).__init__()
        self.input_dropout = input_dropout
        self.reduce_layers = nn.ModuleList()
        self.update_layers = nn.ModuleList()
        # hidden layers
        for i in range(num_layers):
            if i == 0:
                last_dim = in_dim
@@ -78,19 +89,30 @@ class GAT(nn.Module):
            else:
                last_dim = num_hidden * num_heads  # because of concat heads
            residual = use_residual
            self.reduce_layers.append(
                NodeReduceModule(last_dim, num_hidden, num_heads, input_dropout,
                                 attention_dropout))
            self.update_layers.append(
                NodeUpdateModule(residual, self.reduce_layers[-1].fc, activation,
                                 lambda x: torch.cat(x, 1)))
        # projection
        self.reduce_layers.append(
            NodeReduceModule(num_hidden * num_heads, num_classes, 1, input_dropout,
                             attention_dropout))
        self.update_layers.append(
            NodeUpdateModule(False, self.reduce_layers[-1].fc, None, sum))

    def forward(self, g):
        g.register_message_func(lambda src, dst, edge: (src['h'], dst['h']))
        for reduce_func, update_func in zip(self.reduce_layers, self.update_layers):
            # apply dropout
            if self.input_dropout is not None:
                # TODO (lingfan): use batched dropout once we have better api
                # for global manipulation
                for n in g.nodes():
                    g.node[n]['h'] = F.dropout(g.node[n]['h'], p=self.input_dropout)
            g.register_reduce_func(reduce_func)
            g.register_update_func(update_func)
            g.update_all()
        logits = [g.node[n]['h'] for n in g.nodes()]
        logits = torch.cat(logits, dim=0)
@@ -116,8 +138,8 @@ def main(args):
                y_train.shape[1],
                args.num_heads,
                F.elu,
                input_dropout,
                attention_dropout,
                args.residual)
    # use optimizer
...
"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
"""
import networkx as nx
from dgl.graph import DGLGraph
import torch
@@ -14,24 +20,20 @@ class NodeUpdateModule(nn.Module):
        self.act = act
        self.p = p

    def forward(self, node, msgs_repr):
        h = node['h']
        # aggregate messages
        h = h + msgs_repr
        h = self.linear(h)
        if self.act is not None:
            h = self.act(h)
        return {'h': h}


class GCN(nn.Module):
    def __init__(self, input_dim, num_hidden, num_classes, num_layers, activation, dropout=None, output_projection=True):
        super(GCN, self).__init__()
        self.dropout = dropout
        self.layers = nn.ModuleList()
        # hidden layers
        last_dim = input_dim
@@ -43,9 +45,17 @@ class GCN(nn.Module):
        if output_projection:
            self.layers.append(NodeUpdateModule(num_hidden, num_classes, p=dropout))

    def forward(self, g):
        g.register_message_func(lambda src, dst, edge: src['h'])
        g.register_reduce_func('sum')
        for layer in self.layers:
            # apply dropout
            if self.dropout is not None:
                # TODO (lingfan): use batched dropout once we have better api
                # for global manipulation
                for n in g.nodes():
                    g.node[n]['h'] = F.dropout(g.node[n]['h'], p=self.dropout)
            g.register_update_func(layer)
            g.update_all()
        logits = [g.node[n]['h'] for n in g.nodes()]
...
import os

__backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower()
if __backend__ == 'numpy':
    from dgl.backend.numpy import *
elif __backend__ == 'pytorch':
    from dgl.backend.pytorch import *
else:
    raise Exception("Unsupported backend %s" % __backend__)
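For reference, the backend is now chosen from the DGLBACKEND environment variable the first time dgl.backend is imported, defaulting to pytorch. A hedged sketch of switching backends from Python (the F alias mirrors the one used in dgl/graph.py; the variable must be set before the first import):

import os
os.environ['DGLBACKEND'] = 'numpy'   # or 'pytorch' (the new default)

import dgl.backend as F
print(F.__backend__)                 # -> 'numpy'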
from __future__ import absolute_import

import torch
import scipy.sparse

Tensor = torch.Tensor
SparseTensor = scipy.sparse.spmatrix


def asnumpy(a):
    return a.cpu().numpy()


def reduce_sum(a):
    return sum(a)


def reduce_max(a):
    a = torch.cat(a, 0)
    a, _ = torch.max(a, 0, keepdim=True)
    return a
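As a quick, illustrative check of the two reducers above (module path assumed from this commit's layout): reduce_sum folds a Python list of message tensors with element-wise addition, while reduce_max concatenates them along dimension 0 and takes a column-wise max with keepdim=True, so both return a 1 x D tensor for D-dimensional messages:

import torch
from dgl.backend.pytorch import reduce_sum, reduce_max  # assumed module path

msgs = [torch.Tensor([[1., 2.]]), torch.Tensor([[3., 4.]])]  # two 1x2 messages
print(reduce_sum(msgs))   # -> [[4., 6.]]
print(reduce_max(msgs))   # -> [[3., 4.]]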
@@ -15,6 +15,7 @@ __N_REPR__ = "__n_repr__"
__MFUNC__ = "__mfunc__"
__EFUNC__ = "__efunc__"
__UFUNC__ = "__ufunc__"
__RFUNC__ = "__rfunc__"
class DGLGraph(DiGraph):
    """Base graph class specialized for neural networks on graphs.
@@ -138,12 +139,65 @@ class DGLGraph(DiGraph):
        for e in edges:
            self.edges[e][__EFUNC__] = edge_func
    def register_reduce_func(self, reduce_func, nodes='all', batchable=False):
        """Register a message reduce function on incoming edges.

        The reduce function should be compatible with the following signature:

            edge_reprs -> reduced_edge_repr

        It computes the reduced edge representation using the representations
        of the in-coming edges (the same concept as messages).

        The reduce function can also be one of the pre-defined functions:
        'sum', 'max'. If a built-in function is used, the computation will be
        performed efficiently (using generic-SPMV kernels).

        Parameters
        ----------
        reduce_func : str or callable
            Reduce function on incoming edges.
        nodes : str, node, container or tensor
            The nodes for which the reduce function is registered. The default
            is to register for all nodes. Registering for multiple nodes is
            supported.
        batchable : bool
            Whether the provided reduce function allows batch computing.

        Examples
        --------
        Register for all nodes.

        >>> g.register_reduce_func(rfunc)

        Register for a specific node.

        >>> g.register_reduce_func(rfunc, u)  # TODO: not implemented

        Register for multiple nodes.

        >>> u = [u1, u2, u3, ...]
        >>> g.register_reduce_func(rfunc, u)
        """
        if isinstance(reduce_func, str):
            # built-in reduce func
            if reduce_func == 'sum':
                reduce_func = F.reduce_sum
            elif reduce_func == 'max':
                reduce_func = F.reduce_max
            else:
                raise NotImplementedError(
                        "Built-in function %s not implemented" % reduce_func)
        if nodes == 'all':
            self.r_func = reduce_func
        else:
            for n in nodes:
                self.nodes[n][__RFUNC__] = reduce_func

    def register_update_func(self, update_func, nodes='all', batchable=False):
        """Register computation on nodes.

        The update function should be compatible with the following signature:

            (node_reprs, reduced_edge_repr) -> node_reprs

        It computes the new node representations using the representations
        of the in-coming edges (the same concept as messages) and the node
@@ -289,10 +343,14 @@ class DGLGraph(DiGraph):
            v = preds
            # TODO(minjie): tensorize the message batching
            m = [self.edges[vv, uu][__MSG__] for vv in v]
            f_reduce = self.nodes[uu].get(__RFUNC__, self.r_func)
            assert f_reduce is not None, \
                "Reduce function not registered for node %s" % uu
            msgs_reduced_repr = f_reduce(m)
            f_update = self.nodes[uu].get(__UFUNC__, self.u_func)
            assert f_update is not None, \
                "Update function not registered for node %s" % uu
            self.node[uu].update(f_update(self.nodes[uu], msgs_reduced_repr))

    def update_by_edge(self, u, v):
        """Trigger the message function on u->v and update v.
...
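Beyond the built-ins, register_reduce_func above accepts any callable with the edge_reprs -> reduced_edge_repr signature. A hedged sketch of plugging in a custom mean reducer (the graph g, its node features, and its message/update functions are assumed to be set up as in the examples earlier in this commit):

import torch

# custom reducer: average the incoming messages instead of summing them
def reduce_mean(msgs):
    return torch.cat(msgs, 0).mean(0, keepdim=True)

g.register_reduce_func(reduce_mean)   # register for all nodes
g.update_all()                        # per-node messages are now averaged before update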