Unverified Commit 3efb5d8e authored by Minjie Wang, committed by GitHub

[NN] Add HeteroGraphConv module for cleaner module definition (#1385)

* Add HeteroGraphConv

* add custom aggregator; some docstring

* debugging

* rm print

* fix some acc bugs

* fix initialization problem in weight basis

* passed tests

* lint

* fix graphconv flag; add error message

* add mxnet heteroconv

* more fix for mx

* lint

* fix torch cuda test

* fix mx test_nn

* add exhaust test for graphconv

* add tf heteroconv

* fix comment
parent bbfff8ce
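For context on the "cleaner module definition" this commit is after, here is a minimal sketch of the usage pattern that HeteroGraphConv enables; the toy graph, the feature sizes and the choice of GraphConv sub-modules are illustrative assumptions, not code from this commit.

import torch
import dgl
import dgl.nn.pytorch as dglnn

# Toy heterograph; the edge lists are made-up placeholders.
g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'): ([0, 1, 2], [0, 0, 1]),
    ('store', 'sells', 'game'): ([0], [1])})

# One sub-module per relation; HeteroGraphConv runs each on its relation
# graph and sums results that land on the same destination node type.
conv = dglnn.HeteroGraphConv({
    'follows': dglnn.GraphConv(5, 8),
    'plays': dglnn.GraphConv(5, 8),
    'sells': dglnn.GraphConv(5, 8)},
    aggregate='sum')

feats = {'user': torch.randn(3, 5), 'store': torch.randn(1, 5)}
out = conv(g, feats)                                   # keyed by destination node type
print({ntype: h.shape for ntype, h in out.items()})   # 'user' and 'game'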
@@ -8,232 +8,9 @@ import time
import torch as th
import torch.nn as nn
import torch.nn.functional as F
-from functools import partial
-import dgl.function as fn
from dgl.data.rdf import AIFB, MUTAG, BGS, AM
+from model import EntityClassify
class RelGraphConvHetero(nn.Module):
r"""Relational graph convolution layer.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
rel_names : list[str]
Relation names.
regularizer : str
Which weight regularizer to use "basis" or "bdd"
num_bases : int, optional
Number of bases. If None, use the number of relations. Default: None.
bias : bool, optional
True if bias is added. Default: True
activation : callable, optional
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
use_weight : bool, optional
If True, multiply the input node feature with a learnable weight matrix
before message passing.
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
rel_names,
regularizer="basis",
num_bases=None,
bias=True,
activation=None,
self_loop=False,
use_weight=True,
dropout=0.0):
super(RelGraphConvHetero, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.rel_names = rel_names
self.num_rels = len(rel_names)
self.regularizer = regularizer
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0:
self.num_bases = self.num_rels
self.bias = bias
self.activation = activation
self.self_loop = self_loop
self.use_weight = use_weight
if use_weight:
if regularizer == "basis":
# add basis weights
self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
else:
raise ValueError("Only basis regularizer is supported.")
# bias
if self.bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
self.dropout = nn.Dropout(dropout)
def basis_weight(self):
"""Message function for basis regularizer"""
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.view(self.num_bases,
self.in_feat * self.out_feat)
weight = th.matmul(self.w_comp, weight).view(
self.num_rels, self.in_feat, self.out_feat)
else:
weight = self.weight
return {self.rel_names[i] : w.squeeze(0) for i, w in enumerate(th.split(weight, 1, dim=0))}
def forward(self, g, xs):
"""Forward computation
Parameters
----------
g : DGLHeteroGraph
Input graph.
xs : dict[str, torch.Tensor]
Node feature for each node type.
Returns
-------
list of torch.Tensor
New node features for each node type.
"""
g = g.local_var()
for ntype in g.ntypes:
g.nodes[ntype].data['x'] = xs[ntype]
if self.use_weight:
ws = self.basis_weight()
funcs = {}
for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
g.nodes[srctype].data['h%d' % i] = th.matmul(
g.nodes[srctype].data['x'], ws[etype])
funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
else:
funcs = {}
for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
g.nodes[srctype].data['h%d' % i] = g.nodes[srctype].data['x']
funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
# message passing
g.multi_update_all(funcs, 'sum')
hs = {ntype : g.nodes[ntype].data['h'] for ntype in g.ntypes}
new_hs = {}
for ntype, h in hs.items():
# apply bias and activation
if self.self_loop:
h = h + th.matmul(xs[ntype], self.loop_weight)
if self.bias:
h = h + self.h_bias
if self.activation:
h = self.activation(h)
h = self.dropout(h)
new_hs[ntype] = h
return new_hs
class RelGraphEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""
def __init__(self,
g,
embed_size,
embed_name='embed',
activation=None,
dropout=0.0):
super(RelGraphEmbed, self).__init__()
self.g = g
self.embed_size = embed_size
self.embed_name = embed_name
self.activation = activation
self.dropout = nn.Dropout(dropout)
# create weight embeddings for each node for each relation
self.embeds = nn.ParameterDict()
for ntype in g.ntypes:
embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size))
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
self.embeds[ntype] = embed
def forward(self, block=None):
"""Forward computation
Parameters
----------
block : DGLHeteroGraph, optional
If not specified, directly return the full graph with embeddings stored in
:attr:`embed_name`. Otherwise, extract and store the embeddings to the block
graph and return.
Returns
-------
DGLHeteroGraph
The block graph fed with embeddings.
"""
return self.embeds
class EntityClassify(nn.Module):
def __init__(self,
g,
h_dim, out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
super(EntityClassify, self).__init__()
self.g = g
self.h_dim = h_dim
self.out_dim = out_dim
self.rel_names = list(set(g.etypes))
self.rel_names.sort()
self.num_bases = None if num_bases < 0 else num_bases
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.use_self_loop = use_self_loop
self.embed_layer = RelGraphEmbed(g, self.h_dim)
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvHetero(
self.h_dim, self.h_dim, self.rel_names, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout, use_weight=False))
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvHetero(
self.h_dim, self.h_dim, self.rel_names, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout))
# h2o
self.layers.append(RelGraphConvHetero(
self.h_dim, self.out_dim, self.rel_names, "basis",
self.num_bases, activation=None,
self_loop=self.use_self_loop))
def forward(self):
h = self.embed_layer()
for layer in self.layers:
h = layer(self.g, h)
return h
def main(args):
# load graph data
...
@@ -13,231 +13,8 @@ from torch.utils.data import DataLoader
from functools import partial
import dgl
-import dgl.function as fn
from dgl.data.rdf import AIFB, MUTAG, BGS, AM
+from model import EntityClassify, RelGraphEmbed
class RelGraphConvHetero(nn.Module):
r"""Relational graph convolution layer.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
rel_names : list[str]
Relation names.
regularizer : str
Which weight regularizer to use "basis" or "bdd"
num_bases : int, optional
Number of bases. If None, use the number of relations. Default: None.
bias : bool, optional
True if bias is added. Default: True
activation : callable, optional
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
use_weight : bool, optional
If True, multiply the input node feature with a learnable weight matrix
before message passing.
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
rel_names,
regularizer="basis",
num_bases=None,
bias=True,
activation=None,
self_loop=False,
use_weight=True,
dropout=0.0):
super(RelGraphConvHetero, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.rel_names = rel_names
self.num_rels = len(rel_names)
self.regularizer = regularizer
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0:
self.num_bases = self.num_rels
self.bias = bias
self.activation = activation
self.self_loop = self_loop
self.use_weight = use_weight
if use_weight:
if regularizer == "basis":
# add basis weights
self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
else:
raise ValueError("Only basis regularizer is supported.")
# bias
if self.bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
self.dropout = nn.Dropout(dropout)
def basis_weight(self):
"""Message function for basis regularizer"""
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.view(self.num_bases,
self.in_feat * self.out_feat)
weight = th.matmul(self.w_comp, weight).view(
self.num_rels, self.in_feat, self.out_feat)
else:
weight = self.weight
return {self.rel_names[i] : w.squeeze(0) for i, w in enumerate(th.split(weight, 1, dim=0))}
def forward(self, g, xs):
"""Forward computation
Parameters
----------
g : DGLHeteroGraph
Input block graph.
xs : dict[str, torch.Tensor]
Node feature for each node type.
Returns
-------
list of torch.Tensor
New node features for each node type.
"""
g = g.local_var()
for ntype, x in xs.items():
g.srcnodes[ntype].data['x'] = x
if self.use_weight:
ws = self.basis_weight()
funcs = {}
for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
if srctype not in xs:
continue
g.srcnodes[srctype].data['h%d' % i] = th.matmul(
g.srcnodes[srctype].data['x'], ws[etype])
funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
else:
funcs = {}
for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
if srctype not in xs:
continue
g.srcnodes[srctype].data['h%d' % i] = g.srcnodes[srctype].data['x']
funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
# message passing
g.multi_update_all(funcs, 'sum')
hs = {}
for ntype in g.dsttypes:
if 'h' in g.dstnodes[ntype].data:
hs[ntype] = g.dstnodes[ntype].data['h']
def _apply(ntype, h):
# apply bias and activation
if self.self_loop:
h = h + th.matmul(xs[ntype][:h.shape[0]], self.loop_weight)
if self.activation:
h = self.activation(h)
h = self.dropout(h)
return h
hs = {ntype : _apply(ntype, h) for ntype, h in hs.items()}
return hs
class RelGraphEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""
def __init__(self,
g,
embed_size,
activation=None,
dropout=0.0):
super(RelGraphEmbed, self).__init__()
self.g = g
self.embed_size = embed_size
self.activation = activation
self.dropout = nn.Dropout(dropout)
# create weight embeddings for each node for each relation
self.embeds = nn.ParameterDict()
for ntype in g.ntypes:
embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size))
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
self.embeds[ntype] = embed
def forward(self, block=None):
"""Forward computation
Parameters
----------
block : DGLHeteroGraph, optional
If not specified, directly return the full graph with embeddings stored in
:attr:`embed_name`. Otherwise, extract and store the embeddings to the block
graph and return.
Returns
-------
DGLHeteroGraph
The block graph fed with embeddings.
"""
return self.embeds
class EntityClassify(nn.Module):
def __init__(self,
g,
h_dim, out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
super(EntityClassify, self).__init__()
self.g = g
self.h_dim = h_dim
self.out_dim = out_dim
self.rel_names = list(set(g.etypes))
self.rel_names.sort()
self.num_bases = None if num_bases < 0 else num_bases
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.use_self_loop = use_self_loop
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvHetero(
self.h_dim, self.h_dim, self.rel_names, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout, use_weight=False))
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvHetero(
self.h_dim, self.h_dim, self.rel_names, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout))
# h2o
self.layers.append(RelGraphConvHetero(
self.h_dim, self.out_dim, self.rel_names, "basis",
self.num_bases, activation=None,
self_loop=self.use_self_loop))
def forward(self, h, blocks):
for layer, block in zip(self.layers, blocks):
h = layer(block, h)
return h
class HeteroNeighborSampler:
"""Neighbor sampler on heterogeneous graphs
...
"""RGCN layer implementation"""
from collections import defaultdict
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn.pytorch as dglnn
class RelGraphConvLayer(nn.Module):
r"""Relational graph convolution layer.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
rel_names : list[str]
Relation names.
num_bases : int, optional
Number of bases. If None, use the number of relations. Default: None.
weight : bool, optional
True if a linear layer is applied after message passing. Default: True
bias : bool, optional
True if bias is added. Default: True
activation : callable, optional
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
rel_names,
num_bases,
*,
weight=True,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
super(RelGraphConvLayer, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.rel_names = rel_names
self.num_bases = num_bases
self.bias = bias
self.activation = activation
self.self_loop = self_loop
self.conv = dglnn.HeteroGraphConv({
rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False)
for rel in rel_names
})
self.use_weight = weight
self.use_basis = num_bases < len(self.rel_names) and weight
if self.use_weight:
if self.use_basis:
self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names))
else:
self.weight = nn.Parameter(th.Tensor(len(self.rel_names), in_feat, out_feat))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
# bias
if bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
self.dropout = nn.Dropout(dropout)
def forward(self, g, inputs):
"""Forward computation
Parameters
----------
g : DGLHeteroGraph
Input graph.
inputs : dict[str, torch.Tensor]
Node feature for each node type.
Returns
-------
dict[str, torch.Tensor]
New node features for each node type.
"""
g = g.local_var()
if self.use_weight:
weight = self.basis() if self.use_basis else self.weight
wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)}
for i, w in enumerate(th.split(weight, 1, dim=0))}
else:
wdict = {}
hs = self.conv(g, inputs, mod_kwargs=wdict)
def _apply(ntype, h):
if self.self_loop:
h = h + th.matmul(inputs[ntype], self.loop_weight)
if self.bias:
h = h + self.h_bias
if self.activation:
h = self.activation(h)
return self.dropout(h)
return {ntype : _apply(ntype, h) for ntype, h in hs.items()}
class RelGraphEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""
def __init__(self,
g,
embed_size,
embed_name='embed',
activation=None,
dropout=0.0):
super(RelGraphEmbed, self).__init__()
self.g = g
self.embed_size = embed_size
self.embed_name = embed_name
self.activation = activation
self.dropout = nn.Dropout(dropout)
# create weight embeddings for each node for each relation
self.embeds = nn.ParameterDict()
for ntype in g.ntypes:
embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size))
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
self.embeds[ntype] = embed
def forward(self, block=None):
"""Forward computation
Parameters
----------
block : DGLHeteroGraph, optional
If not specified, directly return the full graph with embeddings stored in
:attr:`embed_name`. Otherwise, extract and store the embeddings to the block
graph and return.
Returns
-------
dict[str, torch.Tensor]
The embeddings of each node type.
"""
return self.embeds
class EntityClassify(nn.Module):
def __init__(self,
g,
h_dim, out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
super(EntityClassify, self).__init__()
self.g = g
self.h_dim = h_dim
self.out_dim = out_dim
self.rel_names = list(set(g.etypes))
self.rel_names.sort()
if num_bases < 0 or num_bases > len(self.rel_names):
self.num_bases = len(self.rel_names)
else:
self.num_bases = num_bases
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.use_self_loop = use_self_loop
self.embed_layer = RelGraphEmbed(g, self.h_dim)
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvLayer(
self.h_dim, self.h_dim, self.rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout, weight=False))
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvLayer(
self.h_dim, self.h_dim, self.rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout))
# h2o
self.layers.append(RelGraphConvLayer(
self.h_dim, self.out_dim, self.rel_names,
self.num_bases, activation=None,
self_loop=self.use_self_loop))
def forward(self, h=None, blocks=None):
if blocks is None:
# full graph training
blocks = [self.g] * len(self.layers)
if h is None:
# full graph training
h = self.embed_layer()
for layer, block in zip(self.layers, blocks):
h = layer(block, h)
return h
@@ -20,6 +20,13 @@ from .utils import download, extract_archive, get_download_dir, _get_dgl_url
__all__ = ['AIFB', 'MUTAG', 'BGS', 'AM']
# Dictionary for renaming reserved node/edge type names to the ones
# that are allowed by nn.Module.
RENAME_DICT = {
'type' : 'rdftype',
'rev-type' : 'rev-rdftype',
}
class Entity:
"""Class for entities
@@ -215,6 +222,14 @@ class RDFGraphDataset:
print('Total #nodes:', g.number_of_nodes())
print('Total #edges:', g.number_of_edges())
# rename names such as 'type' so that they can be used as keys
# to nn.ModuleDict
etypes = [RENAME_DICT.get(ty, ty) for ty in etypes]
mg_edges = mg.edges(keys=True)
mg = nx.MultiDiGraph()
for sty, dty, ety in mg_edges:
mg.add_edge(sty, dty, key=RENAME_DICT.get(ety, ety))
# convert to heterograph
print('Convert to heterograph ...')
hg = dgl.to_hetero(g,
...
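The renaming is needed, as far as I can tell, because nn.ModuleDict registers each key as an attribute of the module, and a key like 'type' collides with an attribute nn.Module already has. A quick illustration (not part of this diff):

import torch.nn as nn

# 'type' is already an attribute (a method) of nn.Module, so registering a
# sub-module under that key raises a KeyError.
try:
    nn.ModuleDict({'type': nn.Linear(4, 4)})
except KeyError as err:
    print(err)   # attribute 'type' already exists

# After the RENAME_DICT mapping, the key is safe to use.
mods = nn.ModuleDict({'rdftype': nn.Linear(4, 4)})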
@@ -530,7 +530,7 @@ class DGLHeteroGraph(object):
-------
int
"""
-if self.is_unibipartite:
+if self.is_unibipartite and ntype is not None:
# Only check 'SRC/' and 'DST/' prefix when is_unibipartite graph is True.
if ntype.startswith('SRC/'):
return self.get_ntype_id_from_src(ntype[4:])
...
"""Package for mxnet-specific NN modules.""" """Package for mxnet-specific NN modules."""
from .conv import * from .conv import *
from .glob import * from .glob import *
from .hetero import *
from .softmax import * from .softmax import *
from .utils import Sequential from .utils import Sequential
@@ -6,7 +6,7 @@ import mxnet as mx
from mxnet import gluon
from .... import function as fn
from ....base import DGLError
class GraphConv(gluon.Block):
r"""Apply graph convolution over an input signal.
@@ -43,8 +43,14 @@ class GraphConv(gluon.Block):
Number of input features.
out_feats : int
Number of output features.
-norm : bool, optional
-If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``.
+norm : str, optional
+How to apply the normalizer. If `'right'`, divide the aggregated messages
+by each node's in-degrees, which is equivalent to averaging the received messages.
+If `'none'`, no normalization is applied. Default is `'both'`,
+where the :math:`c_{ij}` in the paper is applied.
+weight : bool, optional
+If True, apply a linear layer. Otherwise, aggregate the messages
+without a weight matrix.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
activation: callable activation function/layer or None, optional
@@ -61,17 +67,25 @@ class GraphConv(gluon.Block):
def __init__(self,
in_feats,
out_feats,
-norm=True,
+norm='both',
+weight=True,
bias=True,
activation=None):
super(GraphConv, self).__init__()
+if norm not in ('none', 'both', 'right'):
+raise DGLError('Invalid norm value. Must be either "none", "both" or "right".'
+' But got "{}".'.format(norm))
self._in_feats = in_feats
self._out_feats = out_feats
self._norm = norm
with self.name_scope():
-self.weight = self.params.get('weight', shape=(in_feats, out_feats),
-init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
+if weight:
+self.weight = self.params.get('weight', shape=(in_feats, out_feats),
+init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
+else:
+self.weight = None
if bias:
self.bias = self.params.get('bias', shape=(out_feats,),
init=mx.init.Zero())
@@ -80,15 +94,16 @@ class GraphConv(gluon.Block):
self._activation = activation
-def forward(self, graph, feat):
+def forward(self, graph, feat, weight=None):
r"""Compute graph convolution.
Notes
-----
* Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional
dimensions, :math:`N` is the number of nodes.
* Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are
the same shape as the input.
+* Weight shape: :math:`(\text{in_feats}, \text{out_feats})`.
Parameters
----------
@@ -96,6 +111,8 @@ class GraphConv(gluon.Block):
The graph.
feat : mxnet.NDArray
The input feature
+weight : mxnet.NDArray, optional
+Optional external weight tensor.
Returns
-------
@@ -103,29 +120,49 @@ class GraphConv(gluon.Block):
The output feature
"""
graph = graph.local_var()
-if self._norm:
-degs = graph.in_degrees().astype('float32')
-norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
+if self._norm == 'both':
+degs = graph.out_degrees().as_in_context(feat.context).astype('float32')
+degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
+norm = mx.nd.power(degs, -0.5)
shp = norm.shape + (1,) * (feat.ndim - 1)
-norm = norm.reshape(shp).as_in_context(feat.context)
+norm = norm.reshape(shp)
feat = feat * norm
+if weight is not None:
+if self.weight is not None:
+raise DGLError('External weight is provided while at the same time the'
+' module has defined its own weight parameter. Please'
+' create the module with flag weight=False.')
+else:
+weight = self.weight.data(feat.context)
if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
-feat = mx.nd.dot(feat, self.weight.data(feat.context))
-graph.ndata['h'] = feat
+if weight is not None:
+feat = mx.nd.dot(feat, weight)
+graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
-rst = graph.ndata.pop('h')
+rst = graph.dstdata.pop('h')
else:
# aggregate first then mult W
-graph.ndata['h'] = feat
+graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
-rst = graph.ndata.pop('h')
-rst = mx.nd.dot(rst, self.weight.data(feat.context))
+rst = graph.dstdata.pop('h')
+if weight is not None:
+rst = mx.nd.dot(rst, weight)
-if self._norm:
+if self._norm != 'none':
+degs = graph.in_degrees().as_in_context(feat.context).astype('float32')
+degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
+if self._norm == 'both':
+norm = mx.nd.power(degs, -0.5)
+else:
+norm = 1.0 / degs
+shp = norm.shape + (1,) * (feat.ndim - 1)
+norm = norm.reshape(shp)
rst = rst * norm
if self.bias is not None:
@@ -141,5 +178,5 @@ class GraphConv(gluon.Block):
summary += 'in={:d}, out={:d}, normalization={}, activation={}'.format(
self._in_feats, self._out_feats,
self._norm, self._activation)
-summary += '\n)'
+summary += ')'
return summary
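For reference, my reading of the updated forward pass is that the three norm modes scale each aggregated message by a coefficient :math:`c_{ij}` as below; this summary is mine, not text from the patch.

.. math::

   h_i' = \sigma\left(b + \sum_{j \in \mathcal{N}(i)} \frac{1}{c_{ij}} h_j W\right), \qquad
   c_{ij} =
   \begin{cases}
   \sqrt{d^{out}_j \, d^{in}_i} & \text{norm='both'} \\
   d^{in}_i & \text{norm='right'} \\
   1 & \text{norm='none'}
   \end{cases}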
"""Heterograph NN modules"""
from mxnet import nd
from mxnet.gluon import nn
from ...base import DGLError  # used by get_aggregate_fn below
__all__ = ['HeteroGraphConv']
class HeteroGraphConv(nn.Block):
r"""A generic module for computing convolution on heterogeneous graphs.
The heterograph convolution applies sub-modules to their associated
relation graphs, which read the features from source nodes and write the
updated ones to destination nodes. If multiple relations have the same
destination node types, their results are aggregated by the specified method.
If the relation graph has no edge, the corresponding module will not be called.
Examples
--------
Create a heterograph with three types of relations and nodes.
>>> import dgl
>>> g = dgl.heterograph({
... ('user', 'follows', 'user') : edges1,
... ('user', 'plays', 'game') : edges2,
... ('store', 'sells', 'game') : edges3})
Create a ``HeteroGraphConv`` that applies different convolution modules to
different relations. Note that the modules for ``'follows'`` and ``'plays'``
do not share weights.
>>> import dgl.nn.mxnet as dglnn
>>> conv = dglnn.HeteroGraphConv({
... 'follows' : dglnn.GraphConv(...),
... 'plays' : dglnn.GraphConv(...),
... 'sells' : dglnn.SAGEConv(...)},
... aggregate='sum')
Call forward with some ``'user'`` features. This computes new features for both
``'user'`` and ``'game'`` nodes.
>>> import mxnet.ndarray as nd
>>> h1 = {'user' : nd.random.randn(g.number_of_nodes('user'), 5)}
>>> h2 = conv(g, h1)
>>> print(h2.keys())
dict_keys(['user', 'game'])
Call forward with both ``'user'`` and ``'store'`` features. Because both the
``'plays'`` and ``'sells'`` relations will update the ``'game'`` features,
their results are aggregated by the specified method (i.e., summation here).
>>> f1 = {'user' : ..., 'store' : ...}
>>> f2 = conv(g, f1)
>>> print(f2.keys())
dict_keys(['user', 'game'])
Call forward with some ``'store'`` features. This only computes new features
for ``'game'`` nodes.
>>> g1 = {'store' : ...}
>>> g2 = conv(g, g1)
>>> print(g2.keys())
dict_keys(['game'])
Calling forward with a pair of inputs is also allowed, and each submodule
will then be invoked with a pair of inputs.
>>> x_src = {'user' : ..., 'store' : ...}
>>> x_dst = {'user' : ..., 'game' : ...}
>>> y_dst = conv(g, (x_src, x_dst))
>>> print(y_dst.keys())
dict_keys(['user', 'game'])
Parameters
----------
mods : dict[str, nn.Module]
Modules associated with every edge type. The forward function of each
module must have a `DGLHeteroGraph` object as the first argument, and
its second argument is either a tensor object representing the node
features or a pair of tensor objects representing the source and destination
node features.
aggregate : str, callable, optional
Method for aggregating node features generated by different relations.
Allowed string values are 'sum', 'max', 'min', 'mean', 'stack'.
The 'stack' aggregation is performed along the second dimension, whose order
is deterministic.
User can also customize the aggregator by providing a callable instance.
For example, aggregation by summation is equivalent to the following:
.. code::
def my_agg_func(tensors, dsttype):
# tensors: is a list of tensors to aggregate
# dsttype: string name of the destination node type for which the
# aggregation is performed
stacked = mx.nd.stack(*tensors, axis=0)
return mx.nd.sum(stacked, axis=0)
Attributes
----------
mods : dict[str, nn.Module]
Modules associated with every edge type.
"""
def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__()
with self.name_scope():
for name, mod in mods.items():
self.register_child(mod, name)
self.mods = mods
if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate)
else:
self.agg_fn = aggregate
def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
"""Forward computation
Invoke the forward function with each module and aggregate their results.
Parameters
----------
g : DGLHeteroGraph
Graph data.
inputs : dict[str, Tensor] or pair of dict[str, Tensor]
Input node features.
mod_args : dict[str, tuple[any]], optional
Extra positional arguments for the sub-modules.
mod_kwargs : dict[str, dict[str, any]], optional
Extra key-word arguments for the sub-modules.
Returns
-------
dict[str, Tensor]
Output representations for every type of node.
"""
if mod_args is None:
mod_args = {}
if mod_kwargs is None:
mod_kwargs = {}
outputs = {nty : [] for nty in g.dsttypes}
if isinstance(inputs, tuple):
src_inputs, dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs:
continue
dstdata = self.mods[etype](
rel_graph,
(src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
else:
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in inputs:
continue
dstdata = self.mods[etype](
rel_graph,
inputs[stype],
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
rsts = {}
for nty, alist in outputs.items():
if len(alist) != 0:
rsts[nty] = self.agg_fn(alist, nty)
return rsts
def __repr__(self):
summary = 'HeteroGraphConv({\n'
for name, mod in self.mods.items():
summary += ' {} : {},\n'.format(name, mod)
summary += '\n})'
return summary
def get_aggregate_fn(agg):
"""Internal function to get the aggregation function for node data
generated from different relations.
Parameters
----------
agg : str
Method for aggregating node features generated by different relations.
Allowed values are 'sum', 'max', 'min', 'mean', 'stack'.
Returns
-------
callable
Aggregator function that takes a list of tensors to aggregate
and returns one aggregated tensor.
"""
if agg == 'sum':
fn = nd.sum
elif agg == 'max':
fn = nd.max
elif agg == 'min':
fn = nd.min
elif agg == 'mean':
fn = nd.mean
elif agg == 'stack':
fn = None # will not be called
else:
raise DGLError('Invalid cross type aggregator. Must be one of '
'"sum", "max", "min", "mean" or "stack". But got "%s"' % agg)
if agg == 'stack':
def stack_agg(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
return nd.stack(*inputs, axis=1)
return stack_agg
else:
def aggfn(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
stacked = nd.stack(*inputs, axis=0)
return fn(stacked, axis=0)
return aggfn
@@ -3,4 +3,5 @@ from .conv import *
from .glob import *
from .softmax import *
from .factory import *
-from .utils import Sequential
+from .hetero import *
+from .utils import Sequential, WeightBasis
@@ -5,6 +5,7 @@ from torch import nn
from torch.nn import init
from .... import function as fn
from ....base import DGLError
# pylint: disable=W0235
class GraphConv(nn.Module):
@@ -42,8 +43,14 @@ class GraphConv(nn.Module):
Input feature size.
out_feats : int
Output feature size.
-norm : bool, optional
-If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``.
+norm : str, optional
+How to apply the normalizer. If `'right'`, divide the aggregated messages
+by each node's in-degrees, which is equivalent to averaging the received messages.
+If `'none'`, no normalization is applied. Default is `'both'`,
+where the :math:`c_{ij}` in the paper is applied.
+weight : bool, optional
+If True, apply a linear layer. Otherwise, aggregate the messages
+without a weight matrix.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
activation: callable activation function/layer or None, optional
@@ -60,30 +67,40 @@ class GraphConv(nn.Module):
def __init__(self,
in_feats,
out_feats,
-norm=True,
+norm='both',
+weight=True,
bias=True,
activation=None):
super(GraphConv, self).__init__()
+if norm not in ('none', 'both', 'right'):
+raise DGLError('Invalid norm value. Must be either "none", "both" or "right".'
+' But got "{}".'.format(norm))
self._in_feats = in_feats
self._out_feats = out_feats
self._norm = norm
-self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
+if weight:
+self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
+else:
+self.register_parameter('weight', None)
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self._activation = activation
def reset_parameters(self):
"""Reinitialize learnable parameters."""
-init.xavier_uniform_(self.weight)
+if self.weight is not None:
+init.xavier_uniform_(self.weight)
if self.bias is not None:
init.zeros_(self.bias)
-def forward(self, graph, feat):
+def forward(self, graph, feat, weight=None):
r"""Compute graph convolution.
Notes
@@ -92,6 +109,7 @@ class GraphConv(nn.Module):
dimensions, :math:`N` is the number of nodes.
* Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are
the same shape as the input.
+* Weight shape: :math:`(\text{in_feats}, \text{out_feats})`.
Parameters
----------
@@ -99,6 +117,8 @@ class GraphConv(nn.Module):
The graph.
feat : torch.Tensor
The input feature
+weight : torch.Tensor, optional
+Optional external weight tensor.
Returns
-------
@@ -106,28 +126,47 @@ class GraphConv(nn.Module):
The output feature
"""
graph = graph.local_var()
-if self._norm:
-norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
+if self._norm == 'both':
+degs = graph.out_degrees().to(feat.device).float().clamp(min=1)
+norm = th.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
-norm = th.reshape(norm, shp).to(feat.device)
+norm = th.reshape(norm, shp)
feat = feat * norm
+if weight is not None:
+if self.weight is not None:
+raise DGLError('External weight is provided while at the same time the'
+' module has defined its own weight parameter. Please'
+' create the module with flag weight=False.')
+else:
+weight = self.weight
if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
-feat = th.matmul(feat, self.weight)
-graph.ndata['h'] = feat
+if weight is not None:
+feat = th.matmul(feat, weight)
+graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
-rst = graph.ndata['h']
+rst = graph.dstdata['h']
else:
# aggregate first then mult W
-graph.ndata['h'] = feat
+graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
-rst = graph.ndata['h']
-rst = th.matmul(rst, self.weight)
+rst = graph.dstdata['h']
+if weight is not None:
+rst = th.matmul(rst, weight)
-if self._norm:
+if self._norm != 'none':
+degs = graph.in_degrees().to(feat.device).float().clamp(min=1)
+if self._norm == 'both':
+norm = th.pow(degs, -0.5)
+else:
+norm = 1.0 / degs
+shp = norm.shape + (1,) * (feat.dim() - 1)
+norm = th.reshape(norm, shp)
rst = rst * norm
if self.bias is not None:
...
"""Heterograph NN modules"""
import torch as th
import torch.nn as nn
from ...base import DGLError  # used by get_aggregate_fn below
__all__ = ['HeteroGraphConv']
class HeteroGraphConv(nn.Module):
r"""A generic module for computing convolution on heterogeneous graphs.
The heterograph convolution applies sub-modules to their associated
relation graphs, which read the features from source nodes and write the
updated ones to destination nodes. If multiple relations have the same
destination node types, their results are aggregated by the specified method.
If the relation graph has no edge, the corresponding module will not be called.
Examples
--------
Create a heterograph with three types of relations and nodes.
>>> import dgl
>>> g = dgl.heterograph({
... ('user', 'follows', 'user') : edges1,
... ('user', 'plays', 'game') : edges2,
... ('store', 'sells', 'game') : edges3})
Create a ``HeteroGraphConv`` that applies different convolution modules to
different relations. Note that the modules for ``'follows'`` and ``'plays'``
do not share weights.
>>> import dgl.nn.pytorch as dglnn
>>> conv = dglnn.HeteroGraphConv({
... 'follows' : dglnn.GraphConv(...),
... 'plays' : dglnn.GraphConv(...),
... 'sells' : dglnn.SAGEConv(...)},
... aggregate='sum')
Call forward with some ``'user'`` features. This computes new features for both
``'user'`` and ``'game'`` nodes.
>>> import torch as th
>>> h1 = {'user' : th.randn((g.number_of_nodes('user'), 5))}
>>> h2 = conv(g, h1)
>>> print(h2.keys())
dict_keys(['user', 'game'])
Call forward with both ``'user'`` and ``'store'`` features. Because both the
``'plays'`` and ``'sells'`` relations will update the ``'game'`` features,
their results are aggregated by the specified method (i.e., summation here).
>>> f1 = {'user' : ..., 'store' : ...}
>>> f2 = conv(g, f1)
>>> print(f2.keys())
dict_keys(['user', 'game'])
Call forward with some ``'store'`` features. This only computes new features
for ``'game'`` nodes.
>>> g1 = {'store' : ...}
>>> g2 = conv(g, g1)
>>> print(g2.keys())
dict_keys(['game'])
Calling forward with a pair of inputs is also allowed, and each submodule
will then be invoked with a pair of inputs.
>>> x_src = {'user' : ..., 'store' : ...}
>>> x_dst = {'user' : ..., 'game' : ...}
>>> y_dst = conv(g, (x_src, x_dst))
>>> print(y_dst.keys())
dict_keys(['user', 'game'])
Parameters
----------
mods : dict[str, nn.Module]
Modules associated with every edge type. The forward function of each
module must have a `DGLHeteroGraph` object as the first argument, and
its second argument is either a tensor object representing the node
features or a pair of tensor objects representing the source and destination
node features.
aggregate : str, callable, optional
Method for aggregating node features generated by different relations.
Allowed string values are 'sum', 'max', 'min', 'mean', 'stack'.
The 'stack' aggregation is performed along the second dimension, whose order
is deterministic.
User can also customize the aggregator by providing a callable instance.
For example, aggregation by summation is equivalent to the following:
.. code::
def my_agg_func(tensors, dsttype):
# tensors: is a list of tensors to aggregate
# dsttype: string name of the destination node type for which the
# aggregation is performed
stacked = torch.stack(tensors, dim=0)
return torch.sum(stacked, dim=0)
Attributes
----------
mods : dict[str, nn.Module]
Modules associated with every edge type.
"""
def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__()
self.mods = nn.ModuleDict(mods)
if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate)
else:
self.agg_fn = aggregate
def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
"""Forward computation
Invoke the forward function with each module and aggregate their results.
Parameters
----------
g : DGLHeteroGraph
Graph data.
inputs : dict[str, Tensor] or pair of dict[str, Tensor]
Input node features.
mod_args : dict[str, tuple[any]], optional
Extra positional arguments for the sub-modules.
mod_kwargs : dict[str, dict[str, any]], optional
Extra key-word arguments for the sub-modules.
Returns
-------
dict[str, Tensor]
Output representations for every type of node.
"""
if mod_args is None:
mod_args = {}
if mod_kwargs is None:
mod_kwargs = {}
outputs = {nty : [] for nty in g.dsttypes}
if isinstance(inputs, tuple):
src_inputs, dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs:
continue
dstdata = self.mods[etype](
rel_graph,
(src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
else:
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in inputs:
continue
dstdata = self.mods[etype](
rel_graph,
inputs[stype],
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
rsts = {}
for nty, alist in outputs.items():
if len(alist) != 0:
rsts[nty] = self.agg_fn(alist, nty)
return rsts
def get_aggregate_fn(agg):
"""Internal function to get the aggregation function for node data
generated from different relations.
Parameters
----------
agg : str
Method for aggregating node features generated by different relations.
Allowed values are 'sum', 'max', 'min', 'mean', 'stack'.
Returns
-------
callable
Aggregator function that takes a list of tensors to aggregate
and returns one aggregated tensor.
"""
if agg == 'sum':
fn = th.sum
elif agg == 'max':
fn = lambda inputs, dim: th.max(inputs, dim=dim)[0]
elif agg == 'min':
fn = lambda inputs, dim: th.min(inputs, dim=dim)[0]
elif agg == 'mean':
fn = th.mean
elif agg == 'stack':
fn = None # will not be called
else:
raise DGLError('Invalid cross type aggregator. Must be one of '
'"sum", "max", "min", "mean" or "stack". But got "%s"' % agg)
if agg == 'stack':
def stack_agg(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
return th.stack(inputs, dim=1)
return stack_agg
else:
def aggfn(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
stacked = th.stack(inputs, dim=0)
return fn(stacked, dim=0)
return aggfn
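The mod_args / mod_kwargs hooks are what RelGraphConvLayer in the new model.py uses to feed per-relation basis weights into otherwise weight-less GraphConv sub-modules. A small sketch of that pattern follows; the relation names, feature sizes and random weights are assumptions for illustration.

import torch as th
import dgl.nn.pytorch as dglnn

rel_names = ['follows', 'plays', 'sells']
conv = dglnn.HeteroGraphConv({
    rel: dglnn.GraphConv(5, 8, norm='right', weight=False, bias=False)
    for rel in rel_names})

# One externally managed weight per relation, routed through mod_kwargs so
# that each GraphConv call receives its own 'weight' keyword argument.
wdict = {rel: {'weight': th.randn(5, 8)} for rel in rel_names}
# h = conv(g, node_feats, mod_kwargs=wdict)   # g / node_feats as in the docstring examples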
@@ -4,7 +4,7 @@
import torch as th
from torch import nn
from ... import DGLGraph
from ...base import dgl_warning
def matmul_maybe_select(A, B):
"""Perform Matrix multiplication C = A * B but A could be an integer id vector.
@@ -216,3 +216,60 @@ class Sequential(nn.Sequential):
raise TypeError('The first argument of forward must be a DGLGraph'
' or a list of DGLGraph s')
return feats
class WeightBasis(nn.Module):
r"""Basis decomposition module.
Basis decomposition is introduced in "`Modeling Relational Data with Graph
Convolutional Networks <https://arxiv.org/abs/1703.06103>`__"
and can be described as below:
.. math::
W_o = \sum_{b=1}^B a_{ob} V_b
Each weight output :math:`W_o` is essentially a linear combination of basis
transformations :math:`V_b` with coefficients :math:`a_{ob}`.
It is useful as a form of regularization on a large parameter matrix. Thus,
the number of weight outputs is usually larger than the number of bases.
Parameters
----------
shape : tuple[int]
Shape of the basis parameter.
num_bases : int
Number of bases.
num_outputs : int
Number of outputs.
"""
def __init__(self,
shape,
num_bases,
num_outputs):
super(WeightBasis, self).__init__()
self.shape = shape
self.num_bases = num_bases
self.num_outputs = num_outputs
if num_outputs <= num_bases:
dgl_warning('The number of weight outputs should be larger than the number'
' of bases.')
self.weight = nn.Parameter(th.Tensor(self.num_bases, *shape))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_outputs, self.num_bases))
nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu'))
def forward(self):
r"""Forward computation
Returns
-------
weight : torch.Tensor
Composed weight tensor of shape ``(num_outputs,) + shape``
"""
# generate all weights from bases
weight = th.matmul(self.w_comp, self.weight.view(self.num_bases, -1))
return weight.view(self.num_outputs, *self.shape)
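A quick sketch of how WeightBasis composes per-output weights; the shapes and relation names are arbitrary, and the split mirrors what RelGraphConvLayer does with the result.

import torch as th
from dgl.nn.pytorch import WeightBasis

# Four relation-specific (5 x 8) weights expressed through two shared bases.
basis = WeightBasis((5, 8), num_bases=2, num_outputs=4)
weight = basis()                       # shape: (4, 5, 8)
per_rel = {name: w.squeeze(0)          # one (5, 8) matrix per relation
           for name, w in zip(['a', 'b', 'c', 'd'],
                              th.split(weight, 1, dim=0))}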
@@ -3,3 +3,4 @@ from .conv import *
from .softmax import *
from .utils import *
from .glob import *
from .hetero import *
@@ -44,8 +44,14 @@ class GraphConv(layers.Layer):
Input feature size.
out_feats : int
Output feature size.
-norm : bool, optional
-If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``.
+norm : str, optional
+How to apply the normalizer. If `'right'`, divide the aggregated messages
+by each node's in-degrees, which is equivalent to averaging the received messages.
+If `'none'`, no normalization is applied. Default is `'both'`,
+where the :math:`c_{ij}` in the paper is applied.
+weight : bool, optional
+If True, apply a linear layer. Otherwise, aggregate the messages
+without a weight matrix.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
activation: callable activation function/layer or None, optional
@@ -63,26 +69,35 @@ class GraphConv(layers.Layer):
def __init__(self,
in_feats,
out_feats,
-norm=True,
+norm='both',
+weight=True,
bias=True,
activation=None):
super(GraphConv, self).__init__()
+if norm not in ('none', 'both', 'right'):
+raise DGLError('Invalid norm value. Must be either "none", "both" or "right".'
+' But got "{}".'.format(norm))
self._in_feats = in_feats
self._out_feats = out_feats
self._norm = norm
-xinit = tf.keras.initializers.glorot_uniform()
-self.weight = tf.Variable(initial_value=xinit(
-shape=(in_feats, out_feats), dtype='float32'), trainable=True)
+if weight:
+xinit = tf.keras.initializers.glorot_uniform()
+self.weight = tf.Variable(initial_value=xinit(
+shape=(in_feats, out_feats), dtype='float32'), trainable=True)
+else:
+self.weight = None
if bias:
zeroinit = tf.keras.initializers.zeros()
self.bias = tf.Variable(initial_value=zeroinit(
shape=(out_feats), dtype='float32'), trainable=True)
+else:
+self.bias = None
self._activation = activation
-def call(self, graph, feat):
+def call(self, graph, feat, weight=None):
r"""Compute graph convolution.
Notes
@@ -91,6 +106,7 @@ class GraphConv(layers.Layer):
dimensions, :math:`N` is the number of nodes.
* Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are
the same shape as the input.
+* Weight shape: :math:`(\text{in_feats}, \text{out_feats})`.
Parameters
----------
@@ -98,6 +114,8 @@ class GraphConv(layers.Layer):
The graph.
feat : tf.Tensor
The input feature
+weight : tf.Tensor, optional
+Optional external weight tensor.
Returns
-------
@@ -105,30 +123,51 @@ class GraphConv(layers.Layer):
The output feature
"""
graph = graph.local_var()
-if self._norm:
-in_degree = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32), clip_value_min=1,
-clip_value_max=np.inf)
-norm = tf.pow(in_degree, -0.5)
+if self._norm == 'both':
+degs = tf.clip_by_value(tf.cast(graph.out_degrees(), tf.float32),
+clip_value_min=1,
+clip_value_max=np.inf)
+norm = tf.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat.ndim - 1)
norm = tf.reshape(norm, shp)
feat = feat * norm
+if weight is not None:
+if self.weight is not None:
+raise DGLError('External weight is provided while at the same time the'
+' module has defined its own weight parameter. Please'
+' create the module with flag weight=False.')
+else:
+weight = self.weight
if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
-feat = tf.matmul(feat, self.weight)
-graph.ndata['h'] = feat
+if weight is not None:
+feat = tf.matmul(feat, weight)
+graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
-rst = graph.ndata['h']
+rst = graph.dstdata['h']
else:
# aggregate first then mult W
-graph.ndata['h'] = feat
+graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
-rst = graph.ndata['h']
-rst = tf.matmul(rst, self.weight)
+rst = graph.dstdata['h']
+if weight is not None:
+rst = tf.matmul(rst, weight)
-if self._norm:
+if self._norm != 'none':
+degs = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32),
+clip_value_min=1,
+clip_value_max=np.inf)
+if self._norm == 'both':
+norm = tf.pow(degs, -0.5)
+else:
+norm = 1.0 / degs
+shp = norm.shape + (1,) * (feat.ndim - 1)
+norm = tf.reshape(norm, shp)
rst = rst * norm
if self.bias is not None:
...
"""Heterograph NN modules"""
import tensorflow as tf
from tensorflow.keras import layers
from ...base import DGLError  # used by get_aggregate_fn below
__all__ = ['HeteroGraphConv']
class HeteroGraphConv(layers.Layer):
r"""A generic module for computing convolution on heterogeneous graphs.
The heterograph convolution applies sub-modules to their associated
relation graphs, which read the features from source nodes and write the
updated ones to destination nodes. If multiple relations have the same
destination node types, their results are aggregated by the specified method.
If the relation graph has no edge, the corresponding module will not be called.
Examples
--------
Create a heterograph with three types of relations and nodes.
>>> import dgl
>>> g = dgl.heterograph({
... ('user', 'follows', 'user') : edges1,
... ('user', 'plays', 'game') : edges2,
... ('store', 'sells', 'game') : edges3})
Create a ``HeteroGraphConv`` that applies different convolution modules to
different relations. Note that the modules for ``'follows'`` and ``'plays'``
do not share weights.
>>> import dgl.nn.tensorflow as dglnn
>>> conv = dglnn.HeteroGraphConv({
... 'follows' : dglnn.GraphConv(...),
... 'plays' : dglnn.GraphConv(...),
... 'sells' : dglnn.SAGEConv(...)},
... aggregate='sum')
Call forward with some ``'user'`` features. This computes new features for both
``'user'`` and ``'game'`` nodes.
>>> import tensorflow as tf
>>> h1 = {'user' : tf.random.normal((g.number_of_nodes('user'), 5))}
>>> h2 = conv(g, h1)
>>> print(h2.keys())
dict_keys(['user', 'game'])
Call forward with both ``'user'`` and ``'store'`` features. Because both the
``'plays'`` and ``'sells'`` relations will update the ``'game'`` features,
their results are aggregated by the specified method (i.e., summation here).
>>> f1 = {'user' : ..., 'store' : ...}
>>> f2 = conv(g, f1)
>>> print(f2.keys())
dict_keys(['user', 'game'])
Call forward with some ``'store'`` features. This only computes new features
for ``'game'`` nodes.
>>> g1 = {'store' : ...}
>>> g2 = conv(g, g1)
>>> print(g2.keys())
dict_keys(['game'])
Calling forward with a pair of inputs is also allowed, and each submodule
will then be invoked with a pair of inputs.
>>> x_src = {'user' : ..., 'store' : ...}
>>> x_dst = {'user' : ..., 'game' : ...}
>>> y_dst = conv(g, (x_src, x_dst))
>>> print(y_dst.keys())
dict_keys(['user', 'game'])
Parameters
----------
mods : dict[str, nn.Module]
Modules associated with each edge type. The forward function of each
module must take a `DGLHeteroGraph` object as its first argument, and
its second argument must be either a tensor representing the node
features or a pair of tensors representing the source and destination
node features.
aggregate : str, callable, optional
Method for aggregating node features generated by different relations.
Allowed string values are 'sum', 'max', 'min', 'mean', 'stack'.
The 'stack' aggregation is performed along the second dimension, whose order
is deterministic.
Users can also customize the aggregator by providing a callable.
For example, aggregation by summation is equivalent to the following:
.. code::
def my_agg_func(tensors, dsttype):
# tensors: is a list of tensors to aggregate
# dsttype: string name of the destination node type for which the
# aggregation is performed
stacked = tf.stack(tensors, axis=0)
return tf.reduce_sum(stacked, axis=0)
Attributes
----------
mods : dict[str, nn.Module]
Modules associated with each edge type.
"""
def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__()
self.mods = mods
if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate)
else:
self.agg_fn = aggregate
def call(self, g, inputs, mod_args=None, mod_kwargs=None):
"""Forward computation
Invoke the forward function with each module and aggregate their results.
Parameters
----------
g : DGLHeteroGraph
Graph data.
inputs : dict[str, Tensor] or pair of dict[str, Tensor]
Input node features.
mod_args : dict[str, tuple[any]], optional
Extra positional arguments for the sub-modules.
mod_kwargs : dict[str, dict[str, any]], optional
Extra keyword arguments for the sub-modules.
Returns
-------
dict[str, Tensor]
Output representations for every node type.
"""
if mod_args is None:
mod_args = {}
if mod_kwargs is None:
mod_kwargs = {}
outputs = {nty : [] for nty in g.dsttypes}
if isinstance(inputs, tuple):
src_inputs, dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs:
continue
dstdata = self.mods[etype](
rel_graph,
(src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
else:
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.number_of_edges() == 0:
continue
if stype not in inputs:
continue
dstdata = self.mods[etype](
rel_graph,
inputs[stype],
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
rsts = {}
for nty, alist in outputs.items():
if len(alist) != 0:
rsts[nty] = self.agg_fn(alist, nty)
return rsts
def get_aggregate_fn(agg):
"""Internal function to get the aggregation function for node data
generated from different relations.
Parameters
----------
agg : str
Method for aggregating node features generated by different relations.
Allowed values are 'sum', 'max', 'min', 'mean', 'stack'.
Returns
-------
callable
Aggregator function that takes a list of tensors to aggregate
and returns one aggregated tensor.
"""
if agg == 'sum':
fn = tf.reduce_sum
elif agg == 'max':
fn = tf.reduce_max
elif agg == 'min':
fn = tf.reduce_min
elif agg == 'mean':
fn = tf.reduce_mean
elif agg == 'stack':
fn = None # will not be called
else:
raise DGLError('Invalid cross type aggregator. Must be one of '
'"sum", "max", "min", "mean" or "stack". But got "%s"' % agg)
if agg == 'stack':
def stack_agg(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
return tf.stack(inputs, axis=1)
return stack_agg
else:
def aggfn(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
stacked = tf.stack(inputs, axis=0)
return fn(stacked, axis=0)
return aggfn
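To make the custom-aggregator hook above concrete, here is a hedged usage sketch: a callable that averages the per-relation results is passed as ``aggregate``, matching the ``(tensors, dsttype)`` signature the module expects. The toy heterograph, feature sizes, and the ``mean_agg`` name are assumptions made for illustration only.

import tensorflow as tf
import dgl
import dgl.nn.tensorflow as dglnn

def mean_agg(tensors, dsttype):
    # Average the per-relation outputs for one destination node type.
    return tf.reduce_mean(tf.stack(tensors, axis=0), axis=0)

# Toy heterograph with two relations (edge lists are illustrative).
g = dgl.heterograph({
    ('user', 'follows', 'user'): [(0, 1), (1, 2)],
    ('user', 'plays', 'game'): [(0, 0), (2, 1)]})

conv = dglnn.HeteroGraphConv({
    'follows': dglnn.GraphConv(5, 8),
    'plays': dglnn.GraphConv(5, 8)},
    aggregate=mean_agg)

feats = {'user': tf.random.normal((g.number_of_nodes('user'), 5))}
out = conv(g, feats)  # dict with 'user' and 'game' entries of width 8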
...@@ -2,10 +2,12 @@ import mxnet as mx
import networkx as nx
import numpy as np
import scipy as sp
import pytest
import dgl
import dgl.nn.mxnet as nn
import dgl.function as fn
import backend as F
from test_utils.graph_cases import get_cases
from mxnet import autograd, gluon, nd
def check_close(a, b):
...@@ -21,7 +23,7 @@ def test_graph_conv():
ctx = F.ctx()
adj = g.adjacency_matrix(ctx=ctx)
conv = nn.GraphConv(5, 2, norm='none', bias=True)
conv.initialize(ctx=ctx)
# test#1: basic
h0 = F.ones((3, 5))
...@@ -73,6 +75,23 @@ def test_graph_conv():
assert "h" in g.ndata
check_close(g.ndata['h'], 2 * F.ones((3, 1)))
@pytest.mark.parametrize('g', get_cases(['path', 'bipartite', 'small'], exclude=['zero-degree']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [False])
def test_graph_conv2(g, norm, weight, bias):
conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
conv.initialize(ctx=F.ctx())
ext_w = F.randn((5, 2)).as_in_context(F.ctx())
nsrc = g.number_of_nodes() if isinstance(g, dgl.DGLGraph) else g.number_of_src_nodes()
ndst = g.number_of_nodes() if isinstance(g, dgl.DGLGraph) else g.number_of_dst_nodes()
h = F.randn((nsrc, 5)).as_in_context(F.ctx())
if weight:
h = conv(g, h)
else:
h = conv(g, h, ext_w)
assert h.shape == (ndst, 2)
def _S2AXWb(A, N, X, W, b):
X1 = X * N
X1 = mx.nd.dot(A, X1.reshape(X1.shape[0], -1))
...@@ -231,7 +250,7 @@ def test_dense_graph_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.3), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).tostype('default')
conv = nn.GraphConv(5, 2, norm='none', bias=True)
dense_conv = nn.DenseGraphConv(5, 2, norm=False, bias=True)
conv.initialize(ctx=ctx)
dense_conv.initialize(ctx=ctx)
...@@ -576,6 +595,107 @@ def test_sequential():
n_feat = net([g1, g2, g3], n_feat)
assert n_feat.shape == (4, 4)
def myagg(alist, dsttype):
rst = alist[0]
for i in range(1, len(alist)):
rst = rst + (i + 1) * alist[i]
return rst
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg):
g = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]})
conv = nn.HeteroGraphConv({
'follows': nn.GraphConv(2, 3),
'plays': nn.GraphConv(2, 4),
'sells': nn.GraphConv(3, 4)},
agg)
conv.initialize(ctx=F.ctx())
print(conv)
uf = F.randn((4, 2))
gf = F.randn((4, 4))
sf = F.randn((2, 3))
uf_dst = F.randn((4, 3))
gf_dst = F.randn((4, 4))
h = conv(g, {'user': uf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
h = conv(g, {'user': uf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4)
h = conv(g, {'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
conv.initialize(ctx=F.ctx())
h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with mod args
class MyMod(mx.gluon.nn.Block):
def __init__(self, s1, s2):
super(MyMod, self).__init__()
self.carg1 = 0
self.s1 = s1
self.s2 = s2
def forward(self, g, h, arg1=None): # mxnet does not support kwargs
if arg1 is not None:
self.carg1 += 1
return F.zeros((g.number_of_dst_nodes(), self.s2))
mod1 = MyMod(2, 3)
mod2 = MyMod(2, 4)
mod3 = MyMod(3, 4)
conv = nn.HeteroGraphConv({
'follows': mod1,
'plays': mod2,
'sells': mod3},
agg)
conv.initialize(ctx=F.ctx())
mod_args = {'follows' : (1,), 'plays' : (1,)}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args)
assert mod1.carg1 == 1
assert mod2.carg1 == 1
assert mod3.carg1 == 0
if __name__ == '__main__':
test_graph_conv()
test_gat_conv()
......
...@@ -4,6 +4,8 @@ import dgl
import dgl.nn.pytorch as nn
import dgl.function as fn
import backend as F
import pytest
from test_utils.graph_cases import get_cases
from copy import deepcopy
import numpy as np
...@@ -19,7 +21,7 @@ def test_graph_conv():
ctx = F.ctx()
adj = g.adjacency_matrix(ctx=ctx)
conv = nn.GraphConv(5, 2, norm='none', bias=True)
conv = conv.to(ctx)
print(conv)
# test#1: basic
...@@ -67,6 +69,22 @@ def test_graph_conv():
new_weight = conv.weight.data
assert not F.allclose(old_weight, new_weight)
@pytest.mark.parametrize('g', get_cases(['path', 'bipartite', 'small'], exclude=['zero-degree']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv2(g, norm, weight, bias):
conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(F.ctx())
ext_w = F.randn((5, 2)).to(F.ctx())
nsrc = g.number_of_nodes() if isinstance(g, dgl.DGLGraph) else g.number_of_src_nodes()
ndst = g.number_of_nodes() if isinstance(g, dgl.DGLGraph) else g.number_of_dst_nodes()
h = F.randn((nsrc, 5)).to(F.ctx())
if weight:
h = conv(g, h)
else:
h = conv(g, h, weight=ext_w)
assert h.shape == (ndst, 2)
def _S2AXWb(A, N, X, W, b):
X1 = X * N
X1 = th.matmul(A, X1.view(X1.shape[0], -1))
...@@ -514,7 +532,7 @@ def test_dense_graph_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).to_dense()
conv = nn.GraphConv(5, 2, norm='none', bias=True)
dense_conv = nn.DenseGraphConv(5, 2, norm=False, bias=True)
dense_conv.weight.data = conv.weight.data
dense_conv.bias.data = conv.bias.data
...@@ -641,6 +659,116 @@ def test_cf_conv():
# currently we only do a shape check
assert h.shape[-1] == 3
def myagg(alist, dsttype):
rst = alist[0]
for i in range(1, len(alist)):
rst = rst + (i + 1) * alist[i]
return rst
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg):
g = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]})
conv = nn.HeteroGraphConv({
'follows': nn.GraphConv(2, 3),
'plays': nn.GraphConv(2, 4),
'sells': nn.GraphConv(3, 4)},
agg)
if F.gpu_ctx():
conv = conv.to(F.ctx())
uf = F.randn((4, 2))
gf = F.randn((4, 4))
sf = F.randn((2, 3))
uf_dst = F.randn((4, 3))
gf_dst = F.randn((4, 4))
h = conv(g, {'user': uf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
h = conv(g, {'user': uf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4)
h = conv(g, {'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
if F.gpu_ctx():
conv = conv.to(F.ctx())
h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with mod args
class MyMod(th.nn.Module):
def __init__(self, s1, s2):
super(MyMod, self).__init__()
self.carg1 = 0
self.carg2 = 0
self.s1 = s1
self.s2 = s2
def forward(self, g, h, arg1=None, *, arg2=None):
if arg1 is not None:
self.carg1 += 1
if arg2 is not None:
self.carg2 += 1
return th.zeros((g.number_of_dst_nodes(), self.s2))
mod1 = MyMod(2, 3)
mod2 = MyMod(2, 4)
mod3 = MyMod(3, 4)
conv = nn.HeteroGraphConv({
'follows': mod1,
'plays': mod2,
'sells': mod3},
agg)
if F.gpu_ctx():
conv = conv.to(F.ctx())
mod_args = {'follows' : (1,), 'plays' : (1,)}
mod_kwargs = {'sells' : {'arg2' : 'abc'}}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
assert mod1.carg1 == 1
assert mod1.carg2 == 0
assert mod2.carg1 == 1
assert mod2.carg2 == 0
assert mod3.carg1 == 0
assert mod3.carg2 == 1
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
......
import tensorflow as tf
from tensorflow.keras import layers
import networkx as nx
import pytest
import dgl
import dgl.nn.tensorflow as nn
import dgl.function as fn
import backend as F
from test_utils.graph_cases import get_cases
from copy import deepcopy
import numpy as np
...@@ -20,7 +22,7 @@ def test_graph_conv():
ctx = F.ctx()
adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(ctx=ctx)))
conv = nn.GraphConv(5, 2, norm='none', bias=True)
# conv = conv
print(conv)
# test#1: basic
...@@ -68,17 +70,21 @@ def test_graph_conv():
# new_weight = conv.weight.data
# assert not F.allclose(old_weight, new_weight)
@pytest.mark.parametrize('g', get_cases(['path', 'bipartite', 'small'], exclude=['zero-degree']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv2(g, norm, weight, bias):
conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias)
ext_w = F.randn((5, 2))
nsrc = g.number_of_nodes() if isinstance(g, dgl.DGLGraph) else g.number_of_src_nodes()
ndst = g.number_of_nodes() if isinstance(g, dgl.DGLGraph) else g.number_of_dst_nodes()
h = F.randn((nsrc, 5))
if weight:
h = conv(g, h)
else:
h = conv(g, h, weight=ext_w)
assert h.shape == (ndst, 2)
def test_simple_pool():
ctx = F.ctx()
...@@ -367,6 +373,110 @@ def test_gin_conv():
h = gin(g, feat)
assert h.shape[-1] == 12
def myagg(alist, dsttype):
rst = alist[0]
for i in range(1, len(alist)):
rst = rst + (i + 1) * alist[i]
return rst
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg):
g = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (0, 2), (2, 1), (1, 3)],
('user', 'plays', 'game'): [(0, 0), (0, 2), (0, 3), (1, 0), (2, 2)],
('store', 'sells', 'game'): [(0, 0), (0, 3), (1, 1), (1, 2)]})
conv = nn.HeteroGraphConv({
'follows': nn.GraphConv(2, 3),
'plays': nn.GraphConv(2, 4),
'sells': nn.GraphConv(3, 4)},
agg)
uf = F.randn((4, 2))
gf = F.randn((4, 4))
sf = F.randn((2, 3))
uf_dst = F.randn((4, 3))
gf_dst = F.randn((4, 4))
h = conv(g, {'user': uf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
h = conv(g, {'user': uf, 'store': sf})
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 2, 4)
h = conv(g, {'store': sf})
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
h = conv(g, ({'user': uf}, {'user' : uf, 'game' : gf}))
assert set(h.keys()) == {'user', 'game'}
if agg != 'stack':
assert h['user'].shape == (4, 3)
assert h['game'].shape == (4, 4)
else:
assert h['user'].shape == (4, 1, 3)
assert h['game'].shape == (4, 1, 4)
# pair input requires both src and dst type features to be provided
h = conv(g, ({'user': uf}, {'game' : gf}))
assert set(h.keys()) == {'game'}
if agg != 'stack':
assert h['game'].shape == (4, 4)
else:
assert h['game'].shape == (4, 1, 4)
# test with mod args
class MyMod(tf.keras.layers.Layer):
def __init__(self, s1, s2):
super(MyMod, self).__init__()
self.carg1 = 0
self.carg2 = 0
self.s1 = s1
self.s2 = s2
def call(self, g, h, arg1=None, *, arg2=None):
if arg1 is not None:
self.carg1 += 1
if arg2 is not None:
self.carg2 += 1
return tf.zeros((g.number_of_dst_nodes(), self.s2))
mod1 = MyMod(2, 3)
mod2 = MyMod(2, 4)
mod3 = MyMod(3, 4)
conv = nn.HeteroGraphConv({
'follows': mod1,
'plays': mod2,
'sells': mod3},
agg)
mod_args = {'follows' : (1,), 'plays' : (1,)}
mod_kwargs = {'sells' : {'arg2' : 'abc'}}
h = conv(g, {'user' : uf, 'store' : sf}, mod_args=mod_args, mod_kwargs=mod_kwargs)
assert mod1.carg1 == 1
assert mod1.carg2 == 0
assert mod2.carg1 == 1
assert mod2.carg2 == 0
assert mod3.carg1 == 0
assert mod3.carg2 == 1
if __name__ == '__main__':
test_graph_conv()
......
from collections import defaultdict
import dgl
import networkx as nx
case_registry = defaultdict(list)
def register_case(labels):
def wrapper(fn):
for lbl in labels:
case_registry[lbl].append(fn)
return fn
return wrapper
def get_cases(labels=None, exclude=None):
cases = set()
if labels is None:
# get all the cases
labels = case_registry.keys()
for lbl in labels:
if exclude is not None and lbl in exclude:
continue
cases.update(case_registry[lbl])
return [fn() for fn in cases]
@register_case(['dglgraph', 'path', 'small'])
def dglgraph_path():
return dgl.DGLGraph(nx.path_graph(5))
@register_case(['bipartite', 'small', 'hetero', 'zero-degree'])
def bipartite1():
return dgl.bipartite([(0, 0), (0, 1), (0, 4), (2, 1), (2, 4), (3, 3)])
@register_case(['bipartite', 'small', 'hetero'])
def bipartite_full():
return dgl.bipartite([(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)])
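The new graph-case helper above is a small label-based registry: ``register_case`` files each factory under its labels, and ``get_cases`` instantiates every factory whose label is requested and not excluded, which is how the parametrized ``test_graph_conv2`` tests obtain their graphs. Below is a hedged sketch of how an additional case could be registered and consumed; the ``cycle`` label and the ``dglgraph_cycle`` factory are hypothetical and not part of the patch.

# Hypothetical extra case, shown only to illustrate the registry pattern.
@register_case(['dglgraph', 'cycle', 'small'])
def dglgraph_cycle():
    return dgl.DGLGraph(nx.cycle_graph(4))

# In a test module it would then be picked up by, e.g.:
#   @pytest.mark.parametrize('g', get_cases(['small'], exclude=['zero-degree']))
#   def test_something(g):
#       ...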