Unverified Commit 4672c022 authored by Tianjun Xiao's avatar Tianjun Xiao Committed by GitHub
Browse files

[hotfix] no in-degree check on sage (#2014)

* no degree check on sage

* fix pytest

* no in degree check for heteroconv
parent eb9c067b
......@@ -107,6 +107,11 @@ class HeteroGraphConv(nn.Block):
for name, mod in mods.items():
self.register_child(mod, name)
self.mods = mods
# Do not break if the graph has 0-in-degree nodes, because there is no
# general rule for adding self-loops to a heterograph.
for _, v in self.mods.items():
if hasattr(v, '_allow_zero_in_degree'):
v._allow_zero_in_degree = True
if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate)
else:
......
......@@ -5,7 +5,6 @@ from torch import nn
from torch.nn import functional as F
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape
......@@ -53,28 +52,6 @@ class SAGEConv(nn.Module):
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Default: ``False``.
Notes
-----
Zero in-degree nodes will lead to invalid output value. This is because no message
will be passed to those nodes, so the aggregation function will be applied to empty input.
A common practice to avoid this is to add a self-loop for each node in the graph if
it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph
>>> g = dgl.add_self_loop(g)
Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graphs,
since the edge type cannot be decided for self-loop edges. Set ``allow_zero_in_degree``
to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.
A common practice to handle this is to filter out the nodes with zero in-degree when using
the output after conv.
Examples
--------
......@@ -118,8 +95,7 @@ class SAGEConv(nn.Module):
feat_drop=0.,
bias=True,
norm=None,
activation=None,
allow_zero_in_degree=False):
activation=None):
super(SAGEConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
......@@ -128,7 +104,6 @@ class SAGEConv(nn.Module):
self.norm = norm
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
self._allow_zero_in_degree = allow_zero_in_degree
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
......@@ -197,18 +172,6 @@ class SAGEConv(nn.Module):
is size of output feature.
"""
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
......
......@@ -104,6 +104,11 @@ class HeteroGraphConv(nn.Module):
def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__()
self.mods = nn.ModuleDict(mods)
# Do not break if the graph has 0-in-degree nodes, because there is no
# general rule for adding self-loops to a heterograph.
for _, v in self.mods.items():
if hasattr(v, '_allow_zero_in_degree'):
v._allow_zero_in_degree = True
if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate)
else:
......
......@@ -104,6 +104,11 @@ class HeteroGraphConv(layers.Layer):
def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__()
self.mods = mods
# Do not break if the graph has 0-in-degree nodes, because there is no
# general rule for adding self-loops to a heterograph.
for _, v in self.mods.items():
if hasattr(v, '_allow_zero_in_degree'):
v._allow_zero_in_degree = True
if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate)
else:
......
......@@ -392,7 +392,7 @@ def test_gat_conv_bi(g, idtype):
assert h.shape == (g.number_of_dst_nodes(), 4, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type):
g = g.astype(idtype).to(F.ctx())
......@@ -403,7 +403,7 @@ def test_sage_conv(idtype, g, aggre_type):
assert h.shape[-1] == 10
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv_bi(idtype, g, aggre_type):
g = g.astype(idtype).to(F.ctx())
......@@ -422,14 +422,14 @@ def test_sage_conv2(idtype):
g = dgl.bipartite([], num_nodes=(5, 3))
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
sage = nn.SAGEConv((3, 3), 2, 'gcn', allow_zero_in_degree=True)
sage = nn.SAGEConv((3, 3), 2, 'gcn')
feat = (F.randn((5, 3)), F.randn((3, 3)))
sage = sage.to(ctx)
h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
assert h.shape[-1] == 2
assert h.shape[0] == 3
for aggre_type in ['mean', 'pool', 'lstm']:
sage = nn.SAGEConv((3, 1), 2, aggre_type, allow_zero_in_degree=True)
sage = nn.SAGEConv((3, 1), 2, aggre_type)
feat = (F.randn((5, 3)), F.randn((3, 1)))
sage = sage.to(ctx)
h = sage(g, feat)
......@@ -610,7 +610,7 @@ def test_dense_graph_conv(norm_type, g, idtype):
assert F.allclose(out_conv, out_dense_conv)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
def test_dense_sage_conv(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......@@ -813,9 +813,9 @@ def test_hetero_conv(agg, idtype):
# test with pair input
conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean', allow_zero_in_degree=True),
'plays': nn.SAGEConv((2, 4), 4, 'mean', allow_zero_in_degree=True),
'sells': nn.SAGEConv(3, 4, 'mean', allow_zero_in_degree=True)},
'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean')},
agg)
conv = conv.to(F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment