Unverified Commit 4af02022 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Multiple fixes (#1374)

* multiple fixes

* lint

* lint x2
parent 0a51dc54
......@@ -51,6 +51,30 @@ class DGLBaseGraph(object):
"""
return self._graph.number_of_nodes()
def number_of_src_nodes(self):
    """Return the number of nodes in the graph.

    Provided so that homogeneous graphs expose the same interface as
    heterographs, where SRC nodes form a distinct node category.

    Returns
    -------
    int
        The number of nodes.
    """
    # A homogeneous graph has a single node set, so the SRC nodes are
    # simply all of its nodes.
    return self._graph.number_of_nodes()
def number_of_dst_nodes(self):
    """Return the number of nodes in the graph.

    Provided so that homogeneous graphs expose the same interface as
    heterographs, where DST nodes form a distinct node category.

    Returns
    -------
    int
        The number of nodes.
    """
    # A homogeneous graph has a single node set, so the DST nodes are
    # simply all of its nodes.
    return self._graph.number_of_nodes()
def __len__(self):
    """Return the number of nodes, so ``len(g)`` is an alias for
    ``g.number_of_nodes()``."""
    node_count = self.number_of_nodes()
    return node_count
......
......@@ -716,8 +716,12 @@ class DGLHeteroGraph(object):
def srcdata(self):
"""Return the data view of all nodes in the SRC category.
**Only works if the graph is uni-bipartite and has one node type in the
SRC category.**
Only works if the graph is either
* Uni-bipartite and has one node type in the SRC category.
* Non-uni-bipartite and has only one node type (in this case identical to
:any:`DGLHeteroGraph.ndata`)
Examples
--------
......@@ -750,8 +754,10 @@ class DGLHeteroGraph(object):
--------
nodes
"""
assert self.is_unibipartite, 'srcdata is only allowed for uni-bipartite graph.'
assert len(self.srctypes) == 1, 'srcdata is only allowed when there is only one SRC type.'
err_msg = (
'srcdata is only allowed when there is only one %s type.' %
('SRC' if self.is_unibipartite else 'node'))
assert len(self.srctypes) == 1, err_msg
ntype = self.srctypes[0]
ntid = self.get_ntype_id_from_src(ntype)
return HeteroNodeDataView(self, ntype, ntid, ALL)
......@@ -760,8 +766,12 @@ class DGLHeteroGraph(object):
def dstdata(self):
"""Return the data view of all destination nodes.
**Only works if the graph is uni-bipartite and has one node type in the
DST category.**
Only works if the graph is either
* Uni-bipartite and has one node type in the DST category.
* Non-uni-bipartite and has only one node type (in this case identical to
:any:`DGLHeteroGraph.ndata`)
Examples
--------
......@@ -794,8 +804,10 @@ class DGLHeteroGraph(object):
--------
nodes
"""
assert self.is_unibipartite, 'dstdata is only allowed for uni-bipartite graph.'
assert len(self.dsttypes) == 1, 'dstdata is only allowed when there is only one DST type.'
err_msg = (
'dstdata is only allowed when there is only one %s type.' %
('DST' if self.is_unibipartite else 'node'))
assert len(self.dsttypes) == 1, err_msg
ntype = self.dsttypes[0]
ntid = self.get_ntype_id_from_dst(ntype)
return HeteroNodeDataView(self, ntype, ntid, ALL)
......
"""MXNet Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
from numbers import Integral
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
......@@ -24,6 +25,14 @@ class SAGEConv(nn.Block):
----------
in_feats : int
Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
If aggregator type is ``gcn``, the feature size of source and destination nodes
are required to be the same.
out_feats : int
Output feature size.
feat_drop : float
......@@ -47,7 +56,15 @@ class SAGEConv(nn.Block):
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_feats = in_feats
if isinstance(in_feats, tuple):
self._in_src_feats = in_feats[0]
self._in_dst_feats = in_feats[1]
elif isinstance(in_feats, Integral):
self._in_src_feats = self._in_dst_feats = in_feats
else:
raise TypeError('in_feats must be either int or pair of ints')
self._out_feats = out_feats
self._aggre_type = aggregator_type
with self.name_scope():
......@@ -55,18 +72,18 @@ class SAGEConv(nn.Block):
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
if aggregator_type == 'pool':
self.fc_pool = nn.Dense(in_feats, use_bias=bias,
self.fc_pool = nn.Dense(self._in_src_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
in_units=self._in_src_feats)
if aggregator_type == 'lstm':
raise NotImplementedError
if aggregator_type != 'gcn':
self.fc_self = nn.Dense(out_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
in_units=self._in_dst_feats)
self.fc_neigh = nn.Dense(out_feats, use_bias=bias,
weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
in_units=in_feats)
in_units=self._in_src_feats)
def forward(self, graph, feat):
r"""Compute GraphSAGE layer.
......@@ -86,23 +103,31 @@ class SAGEConv(nn.Block):
is size of output feature.
"""
graph = graph.local_var()
feat = self.feat_drop(feat)
h_self = feat
if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
h_self = feat_dst
if self._aggre_type == 'mean':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
# divide in degrees
degs = graph.in_degrees().astype(feat.dtype)
degs = degs.as_in_context(feat.context)
h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.expand_dims(-1) + 1)
degs = graph.in_degrees().astype(feat_dst.dtype)
degs = degs.as_in_context(feat_dst.context)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.expand_dims(-1) + 1)
elif self._aggre_type == 'pool':
graph.ndata['h'] = nd.relu(self.fc_pool(feat))
graph.srcdata['h'] = nd.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'lstm':
raise NotImplementedError
else:
......
"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from numbers import Integral
import torch
from torch import nn
from torch.nn import functional as F
......@@ -124,11 +123,11 @@ class SAGEConv(nn.Module):
"""
graph = graph.local_var()
if torch.is_tensor(feat):
feat_src = feat_dst = self.feat_drop(feat)
else:
if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
h_self = feat_dst
......@@ -141,8 +140,7 @@ class SAGEConv(nn.Module):
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# divide in_degrees
degs = graph.in_degrees().float()
degs = degs.to(feat_dst.device)
degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
......
......@@ -49,7 +49,7 @@ class EdgeSoftmax(th.autograd.Function):
if not is_all(eids):
g = g.edge_subgraph(eids.long())
n_nodes = g.number_of_nodes()
n_nodes = g.number_of_dst_nodes()
n_edges = g.number_of_edges()
# TODO(BarclayII): this is a temporary fix of memory leakage in PyTorch
......
"""Tensorflow Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from numbers import Integral
import tensorflow as tf
from tensorflow.keras import layers
......@@ -21,8 +22,16 @@ class SAGEConv(layers.Layer):
Parameters
----------
in_feats : int
in_feats : int, or pair of ints
Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
If aggregator type is ``gcn``, the feature size of source and destination nodes
are required to be the same.
out_feats : int
Output feature size.
feat_drop : float
......@@ -47,7 +56,15 @@ class SAGEConv(layers.Layer):
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_feats = in_feats
if isinstance(in_feats, tuple):
self._in_src_feats = in_feats[0]
self._in_dst_feats = in_feats[1]
elif isinstance(in_feats, Integral):
self._in_src_feats = self._in_dst_feats = in_feats
else:
raise TypeError('in_feats must be either int or pair of ints')
self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
......@@ -55,9 +72,9 @@ class SAGEConv(layers.Layer):
self.activation = activation
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
self.fc_pool = layers.Dense(in_feats)
self.fc_pool = layers.Dense(self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = layers.LSTM(units=in_feats)
self.lstm = layers.LSTM(units=self._in_src_feats)
if aggregator_type != 'gcn':
self.fc_self = layers.Dense(out_feats, use_bias=bias)
self.fc_neigh = layers.Dense(out_feats, use_bias=bias)
......@@ -89,27 +106,35 @@ class SAGEConv(layers.Layer):
is size of output feature.
"""
graph = graph.local_var()
feat = self.feat_drop(feat)
h_self = feat
if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
h_self = feat_dst
if self._aggre_type == 'mean':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# divide in_degrees
degs = tf.cast(graph.in_degrees(), tf.float32)
h_neigh = (graph.ndata['neigh'] + graph.ndata['h']
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']
) / (tf.expand_dims(degs, -1) + 1)
elif self._aggre_type == 'pool':
graph.ndata['h'] = tf.nn.relu(self.fc_pool(feat))
graph.srcdata['h'] = tf.nn.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'lstm':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
else:
raise KeyError(
'Aggregator type {} not recognized.'.format(self._aggre_type))
......
......@@ -127,17 +127,30 @@ def test_gat_conv():
assert h1.shape == (20, 5, 20)
def test_sage_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
for aggre_type in ['mean', 'pool', 'gcn']:
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
sage.initialize(ctx=ctx)
h = sage(g, feat)
assert h.shape[-1] == 10
graphsage = nn.SAGEConv(10, 20)
graphsage.initialize(ctx=ctx)
print(graphsage)
g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
sage.initialize(ctx=ctx)
h = sage(g, feat)
assert h.shape[-1] == 10
# test#1: basic
h0 = F.randn((20, 10))
h1 = graphsage(g, h0)
assert h1.shape == (20, 20)
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
dst_dim = 5 if aggre_type != 'gcn' else 10
sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
sage.initialize(ctx=ctx)
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 200
def test_gg_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
......
......@@ -290,11 +290,15 @@ def test_edge_softmax():
print(score.grad[:10], grad_score[:10])
# Test 2
def generate_rand_graph(n):
arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
return dgl.DGLGraph(arr, readonly=True)
g = generate_rand_graph(50)
def generate_rand_graph(n, m=None, ctor=dgl.DGLGraph):
if m is None:
m = n
arr = (sp.sparse.random(m, n, density=0.1, format='coo') != 0).astype(np.int64)
return ctor(arr, readonly=True)
for g in [generate_rand_graph(50),
generate_rand_graph(50, ctor=dgl.graph),
generate_rand_graph(100, 50, ctor=dgl.bipartite)]:
a1 = F.randn((g.number_of_edges(), 1)).requires_grad_()
a2 = a1.clone().detach().requires_grad_()
g.edata['s'] = a1
......@@ -304,7 +308,8 @@ def test_edge_softmax():
builtin_sm = nn.edge_softmax(g, a2)
builtin_sm.sum().backward()
print(a1.grad - a2.grad)
assert len(g.ndata) == 0
assert len(g.srcdata) == 0
assert len(g.dstdata) == 0
assert len(g.edata) == 2
assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend
......@@ -402,6 +407,13 @@ def test_sage_conv():
h = sage(g, feat)
assert h.shape[-1] == 10
g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
sage = sage.to(ctx)
h = sage(g, feat)
assert h.shape[-1] == 10
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
dst_dim = 5 if aggre_type != 'gcn' else 10
sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
......
......@@ -309,12 +309,27 @@ def test_gat_conv():
def test_sage_conv():
for aggre_type in ['mean', 'pool', 'gcn', 'lstm']:
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
h = sage(g, feat)
assert h.shape[-1] == 10
g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
sage = nn.SAGEConv(5, 10, aggre_type)
feat = F.randn((100, 5))
h = sage(g, feat)
assert h.shape[-1] == 10
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
dst_dim = 5 if aggre_type != 'gcn' else 10
sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 200
def test_sgc_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment