Unverified Commit fb9d2138 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Fix SAGEConv (#1792)

parent 7ea777e1
...@@ -107,6 +107,12 @@ class SAGEConv(nn.Block): ...@@ -107,6 +107,12 @@ class SAGEConv(nn.Block):
h_self = feat_dst h_self = feat_dst
# Handle the case of graphs without edges
if graph.number_of_edges() == 0:
dst_neigh = mx.nd.zeros((graph.number_of_dst_nodes(), self._in_src_feats))
dst_neigh = dst_neigh.as_in_context(feat_dst.context)
graph.dstdata['neigh'] = dst_neigh
if self._aggre_type == 'mean': if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh')) graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
......
"""Torch Module for GraphSAGE layer""" """Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
...@@ -124,6 +125,11 @@ class SAGEConv(nn.Module): ...@@ -124,6 +125,11 @@ class SAGEConv(nn.Module):
h_self = feat_dst h_self = feat_dst
# Handle the case of graphs without edges
if graph.number_of_edges() == 0:
graph.dstdata['neigh'] = torch.zeros(
feat_dst.shape[0], self._in_src_feats).to(feat_dst)
if self._aggre_type == 'mean': if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh')) graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
......
...@@ -110,6 +110,11 @@ class SAGEConv(layers.Layer): ...@@ -110,6 +110,11 @@ class SAGEConv(layers.Layer):
h_self = feat_dst h_self = feat_dst
# Handle the case of graphs without edges
if graph.number_of_edges() == 0:
graph.dstdata['neigh'] = tf.cast(tf.zeros(
(graph.number_of_dst_nodes(), self._in_src_feats)), tf.float32)
if self._aggre_type == 'mean': if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh')) graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
......
...@@ -190,6 +190,22 @@ def test_sage_conv(aggre_type): ...@@ -190,6 +190,22 @@ def test_sage_conv(aggre_type):
assert h.shape[-1] == 2 assert h.shape[-1] == 2
assert h.shape[0] == 200 assert h.shape[0] == 200
# Test the case for graphs without edges
g = dgl.bipartite([], num_nodes=(5, 3))
sage = nn.SAGEConv((3, 3), 2, 'gcn')
feat = (F.randn((5, 3)), F.randn((3, 3)))
sage.initialize(ctx=ctx)
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 3
for aggre_type in ['mean', 'pool']:
sage = nn.SAGEConv((3, 1), 2, aggre_type)
feat = (F.randn((5, 3)), F.randn((3, 1)))
sage.initialize(ctx=ctx)
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 3
def test_gg_conv(): def test_gg_conv():
g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3)) g = dgl.DGLGraph(nx.erdos_renyi_graph(20, 0.3))
ctx = F.ctx() ctx = F.ctx()
......
...@@ -486,6 +486,22 @@ def test_sage_conv(aggre_type): ...@@ -486,6 +486,22 @@ def test_sage_conv(aggre_type):
assert h.shape[-1] == 2 assert h.shape[-1] == 2
assert h.shape[0] == 200 assert h.shape[0] == 200
# Test the case for graphs without edges
g = dgl.bipartite([], num_nodes=(5, 3))
sage = nn.SAGEConv((3, 3), 2, 'gcn')
feat = (F.randn((5, 3)), F.randn((3, 3)))
sage = sage.to(ctx)
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 3
for aggre_type in ['mean', 'pool', 'lstm']:
sage = nn.SAGEConv((3, 1), 2, aggre_type)
feat = (F.randn((5, 3)), F.randn((3, 1)))
sage = sage.to(ctx)
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 3
def test_sgc_conv(): def test_sgc_conv():
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
......
...@@ -378,6 +378,20 @@ def test_sage_conv(aggre_type): ...@@ -378,6 +378,20 @@ def test_sage_conv(aggre_type):
assert h.shape[-1] == 2 assert h.shape[-1] == 2
assert h.shape[0] == 200 assert h.shape[0] == 200
# Test the case for graphs without edges
g = dgl.bipartite([], num_nodes=(5, 3))
sage = nn.SAGEConv((3, 3), 2, 'gcn')
feat = (F.randn((5, 3)), F.randn((3, 3)))
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 3
for aggre_type in ['mean', 'pool', 'lstm']:
sage = nn.SAGEConv((3, 1), 2, aggre_type)
feat = (F.randn((5, 3)), F.randn((3, 1)))
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 3
def test_sgc_conv(): def test_sgc_conv():
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment