"vscode:/vscode.git/clone" did not exist on "9a86f1fa4c47e2289c97c68ef6b231af7e79017d"
Unverified commit af61e2fb authored by Zihao Ye, committed by GitHub

[Feature] Support nn modules for bipartite graphs. (#1392)



* init gat

* fix

* gin

* 7 nn modules

* rename & lint

* upd

* upd

* fix lint

* upd test

* upd

* lint

* shape check

* upd

* lint

* address comments

* update tensorflow
Co-authored-by: Quan Gan <coin2028@hotmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 67cb7a43
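In a nutshell, this change lets the nn conv modules run on unidirectional bipartite graphs by accepting a (source, destination) feature pair and, where needed, a pair of input feature sizes. A minimal sketch of the resulting usage (not part of the diff; it mirrors the GATConv test added below, and the graph/feature sizes are arbitrary):

import dgl
import torch as th
import dgl.nn.pytorch as dglnn
import scipy.sparse as sp

# 100 source nodes, 200 destination nodes
g = dgl.bipartite(sp.random(100, 200, density=0.1))
# separate input sizes for source (5) and destination (10) nodes, 4 attention heads
gat = dglnn.GATConv((5, 10), 2, num_heads=4)
feat = (th.randn(100, 5), th.randn(200, 10))
h = gat(g, feat)   # one row per destination node: shape (200, 4, 2)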
...@@ -77,6 +77,8 @@ class GatedGraphConv(nn.Module): ...@@ -77,6 +77,8 @@ class GatedGraphConv(nn.Module):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size. is the output feature size.
""" """
assert graph.is_homograph(), \
"not a homograph; convert it with to_homo and pass in the edge type as argument"
graph = graph.local_var() graph = graph.local_var()
zero_pad = feat.new_zeros((feat.shape[0], self._out_feats - feat.shape[1])) zero_pad = feat.new_zeros((feat.shape[0], self._out_feats - feat.shape[1]))
feat = th.cat([feat, zero_pad], -1) feat = th.cat([feat, zero_pad], -1)
......
...@@ -4,6 +4,7 @@ import torch as th ...@@ -4,6 +4,7 @@ import torch as th
from torch import nn from torch import nn
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair
class GINConv(nn.Module): class GINConv(nn.Module):
...@@ -55,10 +56,12 @@ class GINConv(nn.Module): ...@@ -55,10 +56,12 @@ class GINConv(nn.Module):
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor or pair of torch.Tensor
The input feature of shape :math:`(N, D)` where :math:`D` If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
could be any positive integer, :math:`N` is the number :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
of nodes. If ``apply_func`` is not None, :math:`D` should If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
If ``apply_func`` is not None, :math:`D_{in}` should
fit the input dimensionality requirement of ``apply_func``. fit the input dimensionality requirement of ``apply_func``.
Returns Returns
...@@ -70,9 +73,10 @@ class GINConv(nn.Module): ...@@ -70,9 +73,10 @@ class GINConv(nn.Module):
as input dimensionality. as input dimensionality.
""" """
graph = graph.local_var() graph = graph.local_var()
graph.ndata['h'] = feat feat_src, feat_dst = expand_as_pair(feat)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
rst = (1 + self.eps) * feat + graph.ndata['neigh'] rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None: if self.apply_func is not None:
rst = self.apply_func(rst) rst = self.apply_func(rst)
return rst return rst
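A minimal usage sketch of the updated GINConv on a bipartite graph (illustrative only, mirroring the PyTorch test added further below; the linear layer, aggregator and sizes are arbitrary choices):

import dgl
import torch as th
import dgl.nn.pytorch as dglnn
import scipy.sparse as sp

g = dgl.bipartite(sp.random(100, 200, density=0.1))
gin = dglnn.GINConv(th.nn.Linear(5, 12), 'mean')
# source and destination features must share the input dimension expected by apply_func
feat = (th.randn(100, 5), th.randn(200, 5))
h = gin(g, feat)   # shape (200, 12)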
...@@ -6,6 +6,7 @@ from torch.nn import init ...@@ -6,6 +6,7 @@ from torch.nn import init
from .... import function as fn from .... import function as fn
from ..utils import Identity from ..utils import Identity
from ....utils import expand_as_pair
class GMMConv(nn.Module): class GMMConv(nn.Module):
...@@ -45,7 +46,7 @@ class GMMConv(nn.Module): ...@@ -45,7 +46,7 @@ class GMMConv(nn.Module):
residual=False, residual=False,
bias=True): bias=True):
super(GMMConv, self).__init__() super(GMMConv, self).__init__()
self._in_feats = in_feats self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
self._dim = dim self._dim = dim
self._n_kernels = n_kernels self._n_kernels = n_kernels
...@@ -60,10 +61,10 @@ class GMMConv(nn.Module): ...@@ -60,10 +61,10 @@ class GMMConv(nn.Module):
self.mu = nn.Parameter(th.Tensor(n_kernels, dim)) self.mu = nn.Parameter(th.Tensor(n_kernels, dim))
self.inv_sigma = nn.Parameter(th.Tensor(n_kernels, dim)) self.inv_sigma = nn.Parameter(th.Tensor(n_kernels, dim))
self.fc = nn.Linear(in_feats, n_kernels * out_feats, bias=False) self.fc = nn.Linear(self._in_src_feats, n_kernels * out_feats, bias=False)
if residual: if residual:
if in_feats != out_feats: if self._in_dst_feats != out_feats:
self.res_fc = nn.Linear(in_feats, out_feats, bias=False) self.res_fc = nn.Linear(self._in_dst_feats, out_feats, bias=False)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
...@@ -94,9 +95,10 @@ class GMMConv(nn.Module): ...@@ -94,9 +95,10 @@ class GMMConv(nn.Module):
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`N` If a single tensor is given, the input feature of shape :math:`(N, D_{in})` where
is the number of nodes of the graph and :math:`D_{in}` is the :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
input feature size. If a pair of tensors are given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
pseudo : torch.Tensor pseudo : torch.Tensor
The pseudo coordinate tensor of shape :math:`(E, D_{u})` where The pseudo coordinate tensor of shape :math:`(E, D_{u})` where
:math:`E` is the number of edges of the graph and :math:`D_{u}` :math:`E` is the number of edges of the graph and :math:`D_{u}`
...@@ -108,8 +110,9 @@ class GMMConv(nn.Module): ...@@ -108,8 +110,9 @@ class GMMConv(nn.Module):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size. is the output feature size.
""" """
graph = graph.local_var() with graph.local_scope():
graph.ndata['h'] = self.fc(feat).view(-1, self._n_kernels, self._out_feats) feat_src, feat_dst = expand_as_pair(feat)
graph.srcdata['h'] = self.fc(feat_src).view(-1, self._n_kernels, self._out_feats)
E = graph.number_of_edges() E = graph.number_of_edges()
# compute gaussian weight # compute gaussian weight
gaussian = -0.5 * ((pseudo.view(E, 1, self._dim) - gaussian = -0.5 * ((pseudo.view(E, 1, self._dim) -
...@@ -118,10 +121,10 @@ class GMMConv(nn.Module): ...@@ -118,10 +121,10 @@ class GMMConv(nn.Module):
gaussian = th.exp(gaussian.sum(dim=-1, keepdim=True)) # (E, K, 1) gaussian = th.exp(gaussian.sum(dim=-1, keepdim=True)) # (E, K, 1)
graph.edata['w'] = gaussian graph.edata['w'] = gaussian
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h')) graph.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h'))
rst = graph.ndata['h'].sum(1) rst = graph.dstdata['h'].sum(1)
# residual connection # residual connection
if self.res_fc is not None: if self.res_fc is not None:
rst = rst + self.res_fc(feat) rst = rst + self.res_fc(feat_dst)
# bias # bias
if self.bias is not None: if self.bias is not None:
rst = rst + self.bias rst = rst + self.bias
......
...@@ -6,6 +6,7 @@ from torch.nn import init ...@@ -6,6 +6,7 @@ from torch.nn import init
from .... import function as fn from .... import function as fn
from ..utils import Identity from ..utils import Identity
from ....utils import expand_as_pair
class NNConv(nn.Module): class NNConv(nn.Module):
...@@ -20,6 +21,11 @@ class NNConv(nn.Module): ...@@ -20,6 +21,11 @@ class NNConv(nn.Module):
---------- ----------
in_feats : int in_feats : int
Input feature size. Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
out_feats : int out_feats : int
Output feature size. Output feature size.
edge_func : callable activation function/layer edge_func : callable activation function/layer
...@@ -42,7 +48,7 @@ class NNConv(nn.Module): ...@@ -42,7 +48,7 @@ class NNConv(nn.Module):
residual=False, residual=False,
bias=True): bias=True):
super(NNConv, self).__init__() super(NNConv, self).__init__()
self._in_feats = in_feats self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
self.edge_nn = edge_func self.edge_nn = edge_func
if aggregator_type == 'sum': if aggregator_type == 'sum':
...@@ -55,8 +61,8 @@ class NNConv(nn.Module): ...@@ -55,8 +61,8 @@ class NNConv(nn.Module):
raise KeyError('Aggregator type {} not recognized: '.format(aggregator_type)) raise KeyError('Aggregator type {} not recognized: '.format(aggregator_type))
self._aggre_type = aggregator_type self._aggre_type = aggregator_type
if residual: if residual:
if in_feats != out_feats: if self._in_dst_feats != out_feats:
self.res_fc = nn.Linear(in_feats, out_feats, bias=False) self.res_fc = nn.Linear(self._in_dst_feats, out_feats, bias=False)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
...@@ -82,7 +88,7 @@ class NNConv(nn.Module): ...@@ -82,7 +88,7 @@ class NNConv(nn.Module):
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : torch.Tensor feat : torch.Tensor or pair of torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`N` The input feature of shape :math:`(N, D_{in})` where :math:`N`
is the number of nodes of the graph and :math:`D_{in}` is the is the number of nodes of the graph and :math:`D_{in}` is the
input feature size. input feature size.
...@@ -96,17 +102,19 @@ class NNConv(nn.Module): ...@@ -96,17 +102,19 @@ class NNConv(nn.Module):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is the output feature size. is the output feature size.
""" """
graph = graph.local_var() with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat)
# (n, d_in, 1) # (n, d_in, 1)
graph.ndata['h'] = feat.unsqueeze(-1) graph.srcdata['h'] = feat_src.unsqueeze(-1)
# (n, d_in, d_out) # (n, d_in, d_out)
graph.edata['w'] = self.edge_nn(efeat).view(-1, self._in_feats, self._out_feats) graph.edata['w'] = self.edge_nn(efeat).view(-1, self._in_src_feats, self._out_feats)
# (n, d_in, d_out) # (n, d_in, d_out)
graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh')) graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh'))
rst = graph.ndata.pop('neigh').sum(dim=1) # (n, d_out) rst = graph.dstdata['neigh'].sum(dim=1) # (n, d_out)
# residual connection # residual connection
if self.res_fc is not None: if self.res_fc is not None:
rst = rst + self.res_fc(feat) rst = rst + self.res_fc(feat_dst)
# bias # bias
if self.bias is not None: if self.bias is not None:
rst = rst + self.bias rst = rst + self.bias
......
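A minimal sketch of NNConv on a bipartite graph with distinct source/destination input sizes (illustrative; it mirrors the PyTorch test added at the end of this diff, and the edge network and sizes are arbitrary):

import dgl
import torch as th
import dgl.nn.pytorch as dglnn
import scipy.sparse as sp

g = dgl.bipartite(sp.random(50, 100, density=0.1))
edge_func = th.nn.Linear(4, 5 * 10)                 # edge feats (dim 4) -> (in_src * out_feats) weights
conv = dglnn.NNConv((5, 2), 10, edge_func, 'mean')  # src dim 5, dst dim 2, out dim 10
feat = (th.randn(50, 5), th.randn(100, 2))
efeat = th.randn(g.number_of_edges(), 4)
h = conv(g, feat, efeat)                            # shape (100, 10)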
...@@ -172,7 +172,9 @@ class RelGraphConv(nn.Module): ...@@ -172,7 +172,9 @@ class RelGraphConv(nn.Module):
torch.Tensor torch.Tensor
New node features. New node features.
""" """
g = g.local_var() assert g.is_homograph(), \
"not a homograph; convert it with to_homo and pass in the edge type as argument"
with g.local_scope():
g.ndata['h'] = x g.ndata['h'] = x
g.edata['type'] = etypes g.edata['type'] = etypes
if norm is not None: if norm is not None:
......
"""Torch Module for GraphSAGE layer""" """Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
from numbers import Integral
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape
class SAGEConv(nn.Module): class SAGEConv(nn.Module):
...@@ -56,14 +56,7 @@ class SAGEConv(nn.Module): ...@@ -56,14 +56,7 @@ class SAGEConv(nn.Module):
activation=None): activation=None):
super(SAGEConv, self).__init__() super(SAGEConv, self).__init__()
if isinstance(in_feats, tuple): self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._in_src_feats = in_feats[0]
self._in_dst_feats = in_feats[1]
elif isinstance(in_feats, Integral):
self._in_src_feats = self._in_dst_feats = in_feats
else:
raise TypeError('in_feats must be either int or pair of ints')
self._out_feats = out_feats self._out_feats = out_feats
self._aggre_type = aggregator_type self._aggre_type = aggregator_type
self.norm = norm self.norm = norm
...@@ -136,6 +129,7 @@ class SAGEConv(nn.Module): ...@@ -136,6 +129,7 @@ class SAGEConv(nn.Module):
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh')) graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh'] h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn': elif self._aggre_type == 'gcn':
check_eq_shape(feat)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh')) graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
......
...@@ -73,6 +73,7 @@ class TAGConv(nn.Module): ...@@ -73,6 +73,7 @@ class TAGConv(nn.Module):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature. is size of output feature.
""" """
assert graph.is_homograph(), 'Graph is not homogeneous'
graph = graph.local_var() graph = graph.local_var()
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5) norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
......
...@@ -28,8 +28,13 @@ class GATConv(layers.Layer): ...@@ -28,8 +28,13 @@ class GATConv(layers.Layer):
Parameters Parameters
---------- ----------
in_feats : int in_feats : int, or a pair of ints
Input feature size. Input feature size.
If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
out_feats : int out_feats : int
Output feature size. Output feature size.
num_heads : int num_heads : int
...@@ -62,11 +67,16 @@ class GATConv(layers.Layer): ...@@ -62,11 +67,16 @@ class GATConv(layers.Layer):
self._out_feats = out_feats self._out_feats = out_feats
xinit = tf.keras.initializers.VarianceScaling(scale=np.sqrt( xinit = tf.keras.initializers.VarianceScaling(scale=np.sqrt(
2), mode="fan_avg", distribution="untruncated_normal") 2), mode="fan_avg", distribution="untruncated_normal")
if isinstance(in_feats, tuple):
self.fc_src = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit)
self.fc_dst = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit)
else:
self.fc = layers.Dense( self.fc = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit) out_feats * num_heads, use_bias=False, kernel_initializer=xinit)
self.attn_l = tf.Variable(initial_value=xinit( self.attn_l = tf.Variable(initial_value=xinit(
shape=(1, num_heads, out_feats), dtype='float32'), trainable=True) shape=(1, num_heads, out_feats), dtype='float32'), trainable=True)
self.attn_r = tf.Variable(initial_value=xinit( self.attn_r = tf.Variable(initial_value=xinit(
shape=(1, num_heads, out_feats), dtype='float32'), trainable=True) shape=(1, num_heads, out_feats), dtype='float32'), trainable=True)
self.feat_drop = layers.Dropout(rate=feat_drop) self.feat_drop = layers.Dropout(rate=feat_drop)
...@@ -90,9 +100,11 @@ class GATConv(layers.Layer): ...@@ -90,9 +100,11 @@ class GATConv(layers.Layer):
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : tf.Tensor feat : tf.Tensor or pair of tf.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` If a tf.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
is size of input feature, :math:`N` is the number of nodes. :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of tf.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
Returns Returns
------- -------
...@@ -101,8 +113,15 @@ class GATConv(layers.Layer): ...@@ -101,8 +113,15 @@ class GATConv(layers.Layer):
is the number of heads, and :math:`D_{out}` is size of output feature. is the number of heads, and :math:`D_{out}` is size of output feature.
""" """
graph = graph.local_var() graph = graph.local_var()
h = self.feat_drop(feat) if isinstance(feat, tuple):
feat = tf.reshape(self.fc(h), (-1, self._num_heads, self._out_feats)) h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
feat_src = tf.reshape(self.fc_src(h_src), (-1, self._num_heads, self._out_feats))
feat_dst = tf.reshape(self.fc_dst(h_dst), (-1, self._num_heads, self._out_feats))
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = tf.reshape(
self.fc(h_src), (-1, self._num_heads, self._out_feats))
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
...@@ -113,9 +132,10 @@ class GATConv(layers.Layer): ...@@ -113,9 +132,10 @@ class GATConv(layers.Layer):
# save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
# addition could be optimized with DGL's built-in function u_add_v, # addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint. # which further speeds up computation and saves memory footprint.
el = tf.reduce_sum(feat * self.attn_l, axis=-1, keepdims=True) el = tf.reduce_sum(feat_src * self.attn_l, axis=-1, keepdims=True)
er = tf.reduce_sum(feat * self.attn_r, axis=-1, keepdims=True) er = tf.reduce_sum(feat_dst * self.attn_r, axis=-1, keepdims=True)
graph.ndata.update({'ft': feat, 'el': el, 'er': er}) graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e')) graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e')) e = self.leaky_relu(graph.edata.pop('e'))
...@@ -124,11 +144,11 @@ class GATConv(layers.Layer): ...@@ -124,11 +144,11 @@ class GATConv(layers.Layer):
# message passing # message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft')) fn.sum('m', 'ft'))
rst = graph.ndata['ft'] rst = graph.dstdata['ft']
# residual # residual
if self.res_fc is not None: if self.res_fc is not None:
resval = tf.reshape(self.res_fc( resval = tf.reshape(self.res_fc(
h), (h.shape[0], -1, self._out_feats)) h_dst), (h_dst.shape[0], -1, self._out_feats))
rst = rst + resval rst = rst + resval
# activation # activation
if self.activation: if self.activation:
......
...@@ -4,6 +4,7 @@ import tensorflow as tf ...@@ -4,6 +4,7 @@ import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair
class GINConv(layers.Layer): class GINConv(layers.Layer):
...@@ -52,10 +53,13 @@ class GINConv(layers.Layer): ...@@ -52,10 +53,13 @@ class GINConv(layers.Layer):
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : tf.Tensor
The input feature of shape :math:`(N, D)` where :math:`D` feat : tf.Tensor or pair of tf.Tensor
could be any positive integer, :math:`N` is the number If a tf.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
of nodes. If ``apply_func`` is not None, :math:`D` should :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of tf.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
If ``apply_func`` is not None, :math:`D_{in}` should
fit the input dimensionality requirement of ``apply_func``. fit the input dimensionality requirement of ``apply_func``.
Returns Returns
...@@ -67,9 +71,10 @@ class GINConv(layers.Layer): ...@@ -67,9 +71,10 @@ class GINConv(layers.Layer):
as input dimensionality. as input dimensionality.
""" """
graph = graph.local_var() graph = graph.local_var()
graph.ndata['h'] = feat feat_src, feat_dst = expand_as_pair(feat)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
rst = (1 + self.eps) * feat + graph.ndata['neigh'] rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None: if self.apply_func is not None:
rst = self.apply_func(rst) rst = self.apply_func(rst)
return rst return rst
...@@ -176,7 +176,9 @@ class RelGraphConv(layers.Layer): ...@@ -176,7 +176,9 @@ class RelGraphConv(layers.Layer):
tf.Tensor tf.Tensor
New node features. New node features.
""" """
g = g.local_var() assert g.is_homograph(), \
"not a homograph; convert it with to_homo and pass in the edge type as argument"
with g.local_scope():
g.ndata['h'] = x g.ndata['h'] = x
g.edata['type'] = tf.cast(etypes, tf.int64) g.edata['type'] = tf.cast(etypes, tf.int64)
if norm is not None: if norm is not None:
......
"""Tensorflow Module for GraphSAGE layer""" """Tensorflow Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
from numbers import Integral
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape
class SAGEConv(layers.Layer): class SAGEConv(layers.Layer):
...@@ -57,14 +57,7 @@ class SAGEConv(layers.Layer): ...@@ -57,14 +57,7 @@ class SAGEConv(layers.Layer):
activation=None): activation=None):
super(SAGEConv, self).__init__() super(SAGEConv, self).__init__()
if isinstance(in_feats, tuple): self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._in_src_feats = in_feats[0]
self._in_dst_feats = in_feats[1]
elif isinstance(in_feats, Integral):
self._in_src_feats = self._in_dst_feats = in_feats
else:
raise TypeError('in_feats must be either int or pair of ints')
self._out_feats = out_feats self._out_feats = out_feats
self._aggre_type = aggregator_type self._aggre_type = aggregator_type
self.norm = norm self.norm = norm
...@@ -95,9 +88,11 @@ class SAGEConv(layers.Layer): ...@@ -95,9 +88,11 @@ class SAGEConv(layers.Layer):
---------- ----------
graph : DGLGraph graph : DGLGraph
The graph. The graph.
feat : tf.Tensor feat : tf.Tensor or pair of tf.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` If a single tensor is given, the input feature of shape :math:`(N, D_{in})` where
is size of input feature, :math:`N` is the number of nodes. :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of tensors are given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
Returns Returns
------- -------
...@@ -120,6 +115,7 @@ class SAGEConv(layers.Layer): ...@@ -120,6 +115,7 @@ class SAGEConv(layers.Layer):
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh')) graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh'] h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn': elif self._aggre_type == 'gcn':
check_eq_shape(feat)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh')) graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
......
...@@ -519,3 +519,23 @@ def make_invmap(array, use_numpy=True): ...@@ -519,3 +519,23 @@ def make_invmap(array, use_numpy=True):
invmap = {x: i for i, x in enumerate(uniques)} invmap = {x: i for i, x in enumerate(uniques)}
remapped = np.asarray([invmap[x] for x in array]) remapped = np.asarray([invmap[x] for x in array])
return uniques, invmap, remapped return uniques, invmap, remapped
def expand_as_pair(input_):
"""Return a pair of same element if the input is not a pair.
"""
if isinstance(input_, tuple):
return input_
else:
return input_, input_
def check_eq_shape(input_):
"""If input_ is a pair of features, check if the feature shape of source
nodes is equal to the feature shape of destination nodes.
"""
srcdata, dstdata = expand_as_pair(input_)
src_feat_shape = tuple(F.shape(srcdata))[1:]
dst_feat_shape = tuple(F.shape(dstdata))[1:]
if src_feat_shape != dst_feat_shape:
raise DGLError("The feature shape of source nodes: {} \
should be equal to the feature shape of destination \
nodes: {}.".format(src_feat_shape, dst_feat_shape))
...@@ -7,7 +7,7 @@ import dgl ...@@ -7,7 +7,7 @@ import dgl
import dgl.nn.mxnet as nn import dgl.nn.mxnet as nn
import dgl.function as fn import dgl.function as fn
import backend as F import backend as F
from test_utils.graph_cases import get_cases from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from mxnet import autograd, gluon, nd from mxnet import autograd, gluon, nd
def check_close(a, b): def check_close(a, b):
...@@ -133,20 +133,29 @@ def test_tagconv(): ...@@ -133,20 +133,29 @@ def test_tagconv():
assert h1.shape[-1] == 2 assert h1.shape[-1] == 2
def test_gat_conv(): def test_gat_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
gat = nn.GATConv(10, 20, 5) # n_heads = 5 gat = nn.GATConv(10, 20, 5) # n_heads = 5
gat.initialize(ctx=ctx) gat.initialize(ctx=ctx)
print(gat) print(gat)
# test#1: basic # test#1: basic
h0 = F.randn((20, 10)) feat = F.randn((20, 10))
h1 = gat(g, h0) h = gat(g, feat)
assert h1.shape == (20, 5, 20) assert h.shape == (20, 5, 20)
# test#2: bipartite
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
gat = nn.GATConv((5, 10), 2, 4)
gat.initialize(ctx=ctx)
feat = (F.randn((100, 5)), F.randn((200, 10)))
h = gat(g, feat)
assert h.shape == (200, 4, 2)
def test_sage_conv(): @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn'])
for aggre_type in ['mean', 'pool', 'gcn']: def test_sage_conv(aggre_type):
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
sage = nn.SAGEConv(5, 10, aggre_type) sage = nn.SAGEConv(5, 10, aggre_type)
...@@ -207,9 +216,14 @@ def test_agnn_conv(): ...@@ -207,9 +216,14 @@ def test_agnn_conv():
print(agnn_conv) print(agnn_conv)
# test#1: basic # test#1: basic
h0 = F.randn((20, 10)) feat = F.randn((20, 10))
h1 = agnn_conv(g, h0) h = agnn_conv(g, feat)
assert h1.shape == (20, 10) assert h.shape == (20, 10)
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
feat = (F.randn((100, 5)), F.randn((200, 5)))
h = agnn_conv(g, feat)
assert h.shape == (200, 5)
def test_appnp_conv(): def test_appnp_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
...@@ -246,27 +260,27 @@ def test_dense_cheb_conv(): ...@@ -246,27 +260,27 @@ def test_dense_cheb_conv():
out_dense_cheb = dense_cheb(adj, feat, 2.0) out_dense_cheb = dense_cheb(adj, feat, 2.0)
assert F.allclose(out_cheb, out_dense_cheb) assert F.allclose(out_cheb, out_dense_cheb)
def test_dense_graph_conv(): @pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', [random_graph(100), random_bipartite(100, 200)])
def test_dense_graph_conv(g, norm_type):
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.3), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).tostype('default') adj = g.adjacency_matrix(ctx=ctx).tostype('default')
conv = nn.GraphConv(5, 2, norm='none', bias=True) conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
dense_conv = nn.DenseGraphConv(5, 2, norm=False, bias=True) dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
conv.initialize(ctx=ctx) conv.initialize(ctx=ctx)
dense_conv.initialize(ctx=ctx) dense_conv.initialize(ctx=ctx)
dense_conv.weight.set_data( dense_conv.weight.set_data(
conv.weight.data()) conv.weight.data())
dense_conv.bias.set_data( dense_conv.bias.set_data(
conv.bias.data()) conv.bias.data())
feat = F.randn((100, 5)) feat = F.randn((g.number_of_src_nodes(), 5))
out_conv = conv(g, feat) out_conv = conv(g, feat)
out_dense_conv = dense_conv(adj, feat) out_dense_conv = dense_conv(adj, feat)
assert F.allclose(out_conv, out_dense_conv) assert F.allclose(out_conv, out_dense_conv)
def test_dense_sage_conv(): @pytest.mark.parametrize('g', [random_graph(100), random_bipartite(100, 200)])
def test_dense_sage_conv(g):
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).tostype('default') adj = g.adjacency_matrix(ctx=ctx).tostype('default')
sage = nn.SAGEConv(5, 2, 'gcn') sage = nn.SAGEConv(5, 2, 'gcn')
dense_sage = nn.DenseSAGEConv(5, 2) dense_sage = nn.DenseSAGEConv(5, 2)
...@@ -276,14 +290,20 @@ def test_dense_sage_conv(): ...@@ -276,14 +290,20 @@ def test_dense_sage_conv():
sage.fc_neigh.weight.data()) sage.fc_neigh.weight.data())
dense_sage.fc.bias.set_data( dense_sage.fc.bias.set_data(
sage.fc_neigh.bias.data()) sage.fc_neigh.bias.data())
feat = F.randn((100, 5)) if len(g.ntypes) == 2:
feat = (
F.randn((g.number_of_src_nodes(), 5)),
F.randn((g.number_of_dst_nodes(), 5))
)
else:
feat = F.randn((g.number_of_nodes(), 5))
out_sage = sage(g, feat) out_sage = sage(g, feat)
out_dense_sage = dense_sage(adj, feat) out_dense_sage = dense_sage(adj, feat)
assert F.allclose(out_sage, out_dense_sage) assert F.allclose(out_sage, out_dense_sage)
def test_edge_conv(): @pytest.mark.parametrize('g', [random_dglgraph(20), random_graph(20), random_bipartite(20, 10)])
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) def test_edge_conv(g):
ctx = F.ctx() ctx = F.ctx()
edge_conv = nn.EdgeConv(5, 2) edge_conv = nn.EdgeConv(5, 2)
...@@ -291,9 +311,13 @@ def test_edge_conv(): ...@@ -291,9 +311,13 @@ def test_edge_conv():
print(edge_conv) print(edge_conv)
# test #1: basic # test #1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_src_nodes(), 5))
if not g.is_homograph():
# bipartite
h1 = edge_conv(g, (h0, h0[:10]))
else:
h1 = edge_conv(g, h0) h1 = edge_conv(g, h0)
assert h1.shape == (g.number_of_nodes(), 2) assert h1.shape == (g.number_of_dst_nodes(), 2)
def test_gin_conv(): def test_gin_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
...@@ -304,38 +328,79 @@ def test_gin_conv(): ...@@ -304,38 +328,79 @@ def test_gin_conv():
print(gin_conv) print(gin_conv)
# test #1: basic # test #1: basic
h0 = F.randn((g.number_of_nodes(), 5)) feat = F.randn((g.number_of_nodes(), 5))
h1 = gin_conv(g, h0) h = gin_conv(g, feat)
assert h1.shape == (g.number_of_nodes(), 5) assert h.shape == (20, 5)
# test #2: bipartite
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
feat = (F.randn((100, 5)), F.randn((200, 5)))
h = gin_conv(g, feat)
assert h.shape == (200, 5)
def test_gmm_conv(): def test_gmm_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max') gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max')
gmm_conv.initialize(ctx=ctx) gmm_conv.initialize(ctx=ctx)
print(gmm_conv) # test #1: basic
h0 = F.randn((g.number_of_nodes(), 5))
pseudo = F.randn((g.number_of_edges(), 5))
h1 = gmm_conv(g, h0, pseudo)
assert h1.shape == (g.number_of_nodes(), 2)
g = dgl.graph(nx.erdos_renyi_graph(20, 0.3))
gmm_conv = nn.GMMConv(5, 2, 5, 3, 'max')
gmm_conv.initialize(ctx=ctx)
# test #1: basic # test #1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
pseudo = F.randn((g.number_of_edges(), 5)) pseudo = F.randn((g.number_of_edges(), 5))
h1 = gmm_conv(g, h0, pseudo) h1 = gmm_conv(g, h0, pseudo)
assert h1.shape == (g.number_of_nodes(), 2) assert h1.shape == (g.number_of_nodes(), 2)
g = dgl.bipartite(sp.sparse.random(20, 10, 0.1))
gmm_conv = nn.GMMConv((5, 4), 2, 5, 3, 'max')
gmm_conv.initialize(ctx=ctx)
# test #1: basic
h0 = F.randn((g.number_of_src_nodes(), 5))
hd = F.randn((g.number_of_dst_nodes(), 4))
pseudo = F.randn((g.number_of_edges(), 5))
h1 = gmm_conv(g, (h0, hd), pseudo)
assert h1.shape == (g.number_of_dst_nodes(), 2)
def test_nn_conv(): def test_nn_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max') nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max')
nn_conv.initialize(ctx=ctx) nn_conv.initialize(ctx=ctx)
print(nn_conv) # test #1: basic
h0 = F.randn((g.number_of_nodes(), 5))
etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
h1 = nn_conv(g, h0, etypes)
assert h1.shape == (g.number_of_nodes(), 2)
g = dgl.graph(nx.erdos_renyi_graph(20, 0.3))
nn_conv = nn.NNConv(5, 2, gluon.nn.Embedding(3, 5 * 2), 'max')
nn_conv.initialize(ctx=ctx)
# test #1: basic # test #1: basic
h0 = F.randn((g.number_of_nodes(), 5)) h0 = F.randn((g.number_of_nodes(), 5))
etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx) etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
h1 = nn_conv(g, h0, etypes) h1 = nn_conv(g, h0, etypes)
assert h1.shape == (g.number_of_nodes(), 2) assert h1.shape == (g.number_of_nodes(), 2)
g = dgl.bipartite(sp.sparse.random(20, 10, 0.3))
nn_conv = nn.NNConv((5, 4), 2, gluon.nn.Embedding(3, 5 * 2), 'max')
nn_conv.initialize(ctx=ctx)
# test #1: basic
h0 = F.randn((g.number_of_src_nodes(), 5))
hd = F.randn((g.number_of_dst_nodes(), 4))
etypes = nd.random.randint(0, 4, g.number_of_edges()).as_in_context(ctx)
h1 = nn_conv(g, (h0, hd), etypes)
assert h1.shape == (g.number_of_dst_nodes(), 2)
def test_sg_conv(): def test_sg_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx() ctx = F.ctx()
......
...@@ -5,7 +5,7 @@ import dgl.nn.pytorch as nn ...@@ -5,7 +5,7 @@ import dgl.nn.pytorch as nn
import dgl.function as fn import dgl.function as fn
import backend as F import backend as F
import pytest import pytest
from test_utils.graph_cases import get_cases from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
...@@ -413,10 +413,17 @@ def test_gat_conv(): ...@@ -413,10 +413,17 @@ def test_gat_conv():
feat = F.randn((100, 5)) feat = F.randn((100, 5))
gat = gat.to(ctx) gat = gat.to(ctx)
h = gat(g, feat) h = gat(g, feat)
assert h.shape[-1] == 2 and h.shape[-2] == 4 assert h.shape == (100, 4, 2)
def test_sage_conv(): g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
for aggre_type in ['mean', 'pool', 'gcn', 'lstm']: gat = nn.GATConv((5, 10), 2, 4)
feat = (F.randn((100, 5)), F.randn((200, 10)))
gat = gat.to(ctx)
h = gat(g, feat)
assert h.shape == (200, 4, 2)
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(aggre_type):
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
sage = nn.SAGEConv(5, 10, aggre_type) sage = nn.SAGEConv(5, 10, aggre_type)
...@@ -470,10 +477,10 @@ def test_appnp_conv(): ...@@ -470,10 +477,10 @@ def test_appnp_conv():
h = appnp(g, feat) h = appnp(g, feat)
assert h.shape[-1] == 5 assert h.shape[-1] == 5
def test_gin_conv(): @pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
for aggregator_type in ['mean', 'max', 'sum']: def test_gin_conv(aggregator_type):
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
gin = nn.GINConv( gin = nn.GINConv(
th.nn.Linear(5, 12), th.nn.Linear(5, 12),
aggregator_type aggregator_type
...@@ -481,16 +488,33 @@ def test_gin_conv(): ...@@ -481,16 +488,33 @@ def test_gin_conv():
feat = F.randn((100, 5)) feat = F.randn((100, 5))
gin = gin.to(ctx) gin = gin.to(ctx)
h = gin(g, feat) h = gin(g, feat)
assert h.shape[-1] == 12 assert h.shape == (100, 12)
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
gin = nn.GINConv(
th.nn.Linear(5, 12),
aggregator_type
)
feat = (F.randn((100, 5)), F.randn((200, 5)))
gin = gin.to(ctx)
h = gin(g, feat)
assert h.shape == (200, 12)
def test_agnn_conv(): def test_agnn_conv():
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
agnn = nn.AGNNConv(1) agnn = nn.AGNNConv(1)
feat = F.randn((100, 5)) feat = F.randn((100, 5))
agnn = agnn.to(ctx) agnn = agnn.to(ctx)
h = agnn(g, feat) h = agnn(g, feat)
assert h.shape[-1] == 5 assert h.shape == (100, 5)
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
agnn = nn.AGNNConv(1)
feat = (F.randn((100, 5)), F.randn((200, 5)))
agnn = agnn.to(ctx)
h = agnn(g, feat)
assert h.shape == (200, 5)
def test_gated_graph_conv(): def test_gated_graph_conv():
ctx = F.ctx() ctx = F.ctx()
...@@ -517,6 +541,27 @@ def test_nn_conv(): ...@@ -517,6 +541,27 @@ def test_nn_conv():
# currently we only do shape check # currently we only do shape check
assert h.shape[-1] == 10 assert h.shape[-1] == 10
g = dgl.graph(sp.sparse.random(100, 100, density=0.1))
edge_func = th.nn.Linear(4, 5 * 10)
nnconv = nn.NNConv(5, 10, edge_func, 'mean')
feat = F.randn((100, 5))
efeat = F.randn((g.number_of_edges(), 4))
nnconv = nnconv.to(ctx)
h = nnconv(g, feat, efeat)
# currently we only do shape check
assert h.shape[-1] == 10
g = dgl.bipartite(sp.sparse.random(50, 100, density=0.1))
edge_func = th.nn.Linear(4, 5 * 10)
nnconv = nn.NNConv((5, 2), 10, edge_func, 'mean')
feat = F.randn((50, 5))
feat_dst = F.randn((100, 2))
efeat = F.randn((g.number_of_edges(), 4))
nnconv = nnconv.to(ctx)
h = nnconv(g, (feat, feat_dst), efeat)
# currently we only do shape check
assert h.shape[-1] == 10
def test_gmm_conv(): def test_gmm_conv():
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
...@@ -528,35 +573,78 @@ def test_gmm_conv(): ...@@ -528,35 +573,78 @@ def test_gmm_conv():
# currently we only do shape check # currently we only do shape check
assert h.shape[-1] == 10 assert h.shape[-1] == 10
def test_dense_graph_conv(): g = dgl.graph(sp.sparse.random(100, 100, density=0.1), readonly=True)
gmmconv = nn.GMMConv(5, 10, 3, 4, 'mean')
feat = F.randn((100, 5))
pseudo = F.randn((g.number_of_edges(), 3))
gmmconv = gmmconv.to(ctx)
h = gmmconv(g, feat, pseudo)
# currently we only do shape check
assert h.shape[-1] == 10
g = dgl.bipartite(sp.sparse.random(100, 50, density=0.1), readonly=True)
gmmconv = nn.GMMConv((5, 2), 10, 3, 4, 'mean')
feat = F.randn((100, 5))
feat_dst = F.randn((50, 2))
pseudo = F.randn((g.number_of_edges(), 3))
gmmconv = gmmconv.to(ctx)
h = gmmconv(g, (feat, feat_dst), pseudo)
# currently we only do shape check
assert h.shape[-1] == 10
@pytest.mark.parametrize('norm_type', ['both', 'right', 'none'])
@pytest.mark.parametrize('g', [random_graph(100), random_bipartite(100, 200)])
def test_dense_graph_conv(norm_type, g):
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) # TODO(minjie): enable the following option after #1385
adj = g.adjacency_matrix(ctx=ctx).to_dense() adj = g.adjacency_matrix(ctx=ctx).to_dense()
conv = nn.GraphConv(5, 2, norm='none', bias=True) conv = nn.GraphConv(5, 2, norm=norm_type, bias=True)
dense_conv = nn.DenseGraphConv(5, 2, norm=False, bias=True) dense_conv = nn.DenseGraphConv(5, 2, norm=norm_type, bias=True)
dense_conv.weight.data = conv.weight.data dense_conv.weight.data = conv.weight.data
dense_conv.bias.data = conv.bias.data dense_conv.bias.data = conv.bias.data
feat = F.randn((100, 5)) feat = F.randn((g.number_of_src_nodes(), 5))
conv = conv.to(ctx) conv = conv.to(ctx)
dense_conv = dense_conv.to(ctx) dense_conv = dense_conv.to(ctx)
out_conv = conv(g, feat) out_conv = conv(g, feat)
out_dense_conv = dense_conv(adj, feat) out_dense_conv = dense_conv(adj, feat)
assert F.allclose(out_conv, out_dense_conv) assert F.allclose(out_conv, out_dense_conv)
def test_dense_sage_conv(): @pytest.mark.parametrize('g', [random_graph(100), random_bipartite(100, 200)])
def test_dense_sage_conv(g):
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
adj = g.adjacency_matrix(ctx=ctx).to_dense() adj = g.adjacency_matrix(ctx=ctx).to_dense()
sage = nn.SAGEConv(5, 2, 'gcn') sage = nn.SAGEConv(5, 2, 'gcn')
dense_sage = nn.DenseSAGEConv(5, 2) dense_sage = nn.DenseSAGEConv(5, 2)
dense_sage.fc.weight.data = sage.fc_neigh.weight.data dense_sage.fc.weight.data = sage.fc_neigh.weight.data
dense_sage.fc.bias.data = sage.fc_neigh.bias.data dense_sage.fc.bias.data = sage.fc_neigh.bias.data
feat = F.randn((100, 5)) if len(g.ntypes) == 2:
feat = (
F.randn((g.number_of_src_nodes(), 5)),
F.randn((g.number_of_dst_nodes(), 5))
)
else:
feat = F.randn((g.number_of_nodes(), 5))
sage = sage.to(ctx) sage = sage.to(ctx)
dense_sage = dense_sage.to(ctx) dense_sage = dense_sage.to(ctx)
out_sage = sage(g, feat) out_sage = sage(g, feat)
out_dense_sage = dense_sage(adj, feat) out_dense_sage = dense_sage(adj, feat)
assert F.allclose(out_sage, out_dense_sage) assert F.allclose(out_sage, out_dense_sage), g
@pytest.mark.parametrize('g', [random_dglgraph(20), random_graph(20), random_bipartite(20, 10)])
def test_edge_conv(g):
ctx = F.ctx()
edge_conv = nn.EdgeConv(5, 2).to(ctx)
print(edge_conv)
# test #1: basic
h0 = F.randn((g.number_of_src_nodes(), 5))
if not g.is_homograph():
# bipartite
h1 = edge_conv(g, (h0, h0[:10]))
else:
h1 = edge_conv(g, h0)
assert h1.shape == (g.number_of_dst_nodes(), 2)
def test_dense_cheb_conv(): def test_dense_cheb_conv():
for k in range(1, 4): for k in range(1, 4):
......
...@@ -6,7 +6,7 @@ import dgl ...@@ -6,7 +6,7 @@ import dgl
import dgl.nn.tensorflow as nn import dgl.nn.tensorflow as nn
import dgl.function as fn import dgl.function as fn
import backend as F import backend as F
from test_utils.graph_cases import get_cases from test_utils.graph_cases import get_cases, random_graph, random_bipartite, random_dglgraph
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
...@@ -167,7 +167,6 @@ def test_edge_softmax(): ...@@ -167,7 +167,6 @@ def test_edge_softmax():
for j in range(30): for j in range(30):
g.add_edge(i, j) g.add_edge(i, j)
score = F.randn((900, 1)) score = F.randn((900, 1))
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
tape.watch(score) tape.watch(score)
...@@ -311,10 +310,15 @@ def test_gat_conv(): ...@@ -311,10 +310,15 @@ def test_gat_conv():
gat = nn.GATConv(5, 2, 4) gat = nn.GATConv(5, 2, 4)
feat = F.randn((100, 5)) feat = F.randn((100, 5))
h = gat(g, feat) h = gat(g, feat)
assert h.shape[-1] == 2 and h.shape[-2] == 4 assert h.shape == (100, 4, 2)
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
gat = nn.GATConv((5, 10), 2, 4)
feat = (F.randn((100, 5)), F.randn((200, 10)))
h = gat(g, feat)
assert h.shape == (200, 4, 2)
def test_sage_conv(): @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
for aggre_type in ['mean', 'pool', 'gcn', 'lstm']: def test_sage_conv(aggre_type):
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
sage = nn.SAGEConv(5, 10, aggre_type) sage = nn.SAGEConv(5, 10, aggre_type)
...@@ -361,8 +365,8 @@ def test_appnp_conv(): ...@@ -361,8 +365,8 @@ def test_appnp_conv():
h = appnp(g, feat) h = appnp(g, feat)
assert h.shape[-1] == 5 assert h.shape[-1] == 5
def test_gin_conv(): @pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
for aggregator_type in ['mean', 'max', 'sum']: def test_gin_conv(aggregator_type):
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
gin = nn.GINConv( gin = nn.GINConv(
tf.keras.layers.Dense(12), tf.keras.layers.Dense(12),
...@@ -371,7 +375,16 @@ def test_gin_conv(): ...@@ -371,7 +375,16 @@ def test_gin_conv():
feat = F.randn((100, 5)) feat = F.randn((100, 5))
gin = gin gin = gin
h = gin(g, feat) h = gin(g, feat)
assert h.shape[-1] == 12 assert h.shape == (100, 12)
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
gin = nn.GINConv(
tf.keras.layers.Dense(12),
aggregator_type
)
feat = (F.randn((100, 5)), F.randn((200, 5)))
h = gin(g, feat)
assert h.shape == (200, 12)
def myagg(alist, dsttype): def myagg(alist, dsttype):
rst = alist[0] rst = alist[0]
...@@ -477,7 +490,6 @@ def test_hetero_conv(agg): ...@@ -477,7 +490,6 @@ def test_hetero_conv(agg):
assert mod3.carg1 == 0 assert mod3.carg1 == 0
assert mod3.carg2 == 1 assert mod3.carg2 == 1
if __name__ == '__main__': if __name__ == '__main__':
test_graph_conv() test_graph_conv()
test_edge_softmax() test_edge_softmax()
...@@ -501,4 +513,3 @@ if __name__ == '__main__': ...@@ -501,4 +513,3 @@ if __name__ == '__main__':
# test_dense_sage_conv() # test_dense_sage_conv()
# test_dense_cheb_conv() # test_dense_cheb_conv()
# test_sequential() # test_sequential()
from collections import defaultdict from collections import defaultdict
import dgl import dgl
import networkx as nx import networkx as nx
import scipy.sparse as ssp
case_registry = defaultdict(list) case_registry = defaultdict(list)
...@@ -33,3 +34,12 @@ def bipartite1(): ...@@ -33,3 +34,12 @@ def bipartite1():
@register_case(['bipartite', 'small', 'hetero']) @register_case(['bipartite', 'small', 'hetero'])
def bipartite_full(): def bipartite_full():
return dgl.bipartite([(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]) return dgl.bipartite([(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)])
def random_dglgraph(size):
return dgl.DGLGraph(nx.erdos_renyi_graph(size, 0.3))
def random_graph(size):
return dgl.graph(nx.erdos_renyi_graph(size, 0.3))
def random_bipartite(size_src, size_dst):
return dgl.bipartite(ssp.random(size_src, size_dst, 0.1))