"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "97863ab85bcbcface04970252502a25122986fa3"
Unverified Commit 4672c022 authored by Tianjun Xiao's avatar Tianjun Xiao Committed by GitHub
Browse files

[hotfix] no in-degree check on sage (#2014)

* no degree check on sage

* fix pytest

* no in degree check for heteroconv
parent eb9c067b
...@@ -107,6 +107,11 @@ class HeteroGraphConv(nn.Block): ...@@ -107,6 +107,11 @@ class HeteroGraphConv(nn.Block):
for name, mod in mods.items(): for name, mod in mods.items():
self.register_child(mod, name) self.register_child(mod, name)
self.mods = mods self.mods = mods
# Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items():
if hasattr(v, '_allow_zero_in_degree'):
v._allow_zero_in_degree = True
if isinstance(aggregate, str): if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate) self.agg_fn = get_aggregate_fn(aggregate)
else: else:
......
...@@ -5,7 +5,6 @@ from torch import nn ...@@ -5,7 +5,6 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from .... import function as fn from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape from ....utils import expand_as_pair, check_eq_shape
...@@ -53,28 +52,6 @@ class SAGEConv(nn.Module): ...@@ -53,28 +52,6 @@ class SAGEConv(nn.Module):
activation : callable activation function/layer or None, optional activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features. If not None, applies an activation function to the updated node features.
Default: ``None``. Default: ``None``.
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Default: ``False``.
Notes
-----
Zero in-degree nodes will lead to invalid output value. This is because no message
will be passed to those nodes; the aggregation function will be applied on empty input.
A common practice to avoid this is to add a self-loop for each node in the graph if
it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph
>>> g = dgl.add_self_loop(g)
Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph
since the edge type cannot be decided for self_loop edges. Set ``allow_zero_in_degree``
to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.
A common practice to handle this is to filter out the nodes with zero-in-degree when used
after conv.
Examples Examples
-------- --------
...@@ -118,8 +95,7 @@ class SAGEConv(nn.Module): ...@@ -118,8 +95,7 @@ class SAGEConv(nn.Module):
feat_drop=0., feat_drop=0.,
bias=True, bias=True,
norm=None, norm=None,
activation=None, activation=None):
allow_zero_in_degree=False):
super(SAGEConv, self).__init__() super(SAGEConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
...@@ -128,7 +104,6 @@ class SAGEConv(nn.Module): ...@@ -128,7 +104,6 @@ class SAGEConv(nn.Module):
self.norm = norm self.norm = norm
self.feat_drop = nn.Dropout(feat_drop) self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation self.activation = activation
self._allow_zero_in_degree = allow_zero_in_degree
# aggregator type: mean/pool/lstm/gcn # aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool': if aggregator_type == 'pool':
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats) self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
...@@ -197,18 +172,6 @@ class SAGEConv(nn.Module): ...@@ -197,18 +172,6 @@ class SAGEConv(nn.Module):
is size of output feature. is size of output feature.
""" """
with graph.local_scope(): with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if isinstance(feat, tuple): if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0]) feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1]) feat_dst = self.feat_drop(feat[1])
......
...@@ -104,6 +104,11 @@ class HeteroGraphConv(nn.Module): ...@@ -104,6 +104,11 @@ class HeteroGraphConv(nn.Module):
def __init__(self, mods, aggregate='sum'): def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__() super(HeteroGraphConv, self).__init__()
self.mods = nn.ModuleDict(mods) self.mods = nn.ModuleDict(mods)
# Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items():
if hasattr(v, '_allow_zero_in_degree'):
v._allow_zero_in_degree = True
if isinstance(aggregate, str): if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate) self.agg_fn = get_aggregate_fn(aggregate)
else: else:
......
...@@ -104,6 +104,11 @@ class HeteroGraphConv(layers.Layer): ...@@ -104,6 +104,11 @@ class HeteroGraphConv(layers.Layer):
def __init__(self, mods, aggregate='sum'): def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__() super(HeteroGraphConv, self).__init__()
self.mods = mods self.mods = mods
# Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items():
if hasattr(v, '_allow_zero_in_degree'):
v._allow_zero_in_degree = True
if isinstance(aggregate, str): if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate) self.agg_fn = get_aggregate_fn(aggregate)
else: else:
......
...@@ -392,7 +392,7 @@ def test_gat_conv_bi(g, idtype): ...@@ -392,7 +392,7 @@ def test_gat_conv_bi(g, idtype):
assert h.shape == (g.number_of_dst_nodes(), 4, 2) assert h.shape == (g.number_of_dst_nodes(), 4, 2)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm']) @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv(idtype, g, aggre_type): def test_sage_conv(idtype, g, aggre_type):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
...@@ -403,7 +403,7 @@ def test_sage_conv(idtype, g, aggre_type): ...@@ -403,7 +403,7 @@ def test_sage_conv(idtype, g, aggre_type):
assert h.shape[-1] == 10 assert h.shape[-1] == 10
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm']) @pytest.mark.parametrize('aggre_type', ['mean', 'pool', 'gcn', 'lstm'])
def test_sage_conv_bi(idtype, g, aggre_type): def test_sage_conv_bi(idtype, g, aggre_type):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
...@@ -422,14 +422,14 @@ def test_sage_conv2(idtype): ...@@ -422,14 +422,14 @@ def test_sage_conv2(idtype):
g = dgl.bipartite([], num_nodes=(5, 3)) g = dgl.bipartite([], num_nodes=(5, 3))
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
sage = nn.SAGEConv((3, 3), 2, 'gcn', allow_zero_in_degree=True) sage = nn.SAGEConv((3, 3), 2, 'gcn')
feat = (F.randn((5, 3)), F.randn((3, 3))) feat = (F.randn((5, 3)), F.randn((3, 3)))
sage = sage.to(ctx) sage = sage.to(ctx)
h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx()))) h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
assert h.shape[-1] == 2 assert h.shape[-1] == 2
assert h.shape[0] == 3 assert h.shape[0] == 3
for aggre_type in ['mean', 'pool', 'lstm']: for aggre_type in ['mean', 'pool', 'lstm']:
sage = nn.SAGEConv((3, 1), 2, aggre_type, allow_zero_in_degree=True) sage = nn.SAGEConv((3, 1), 2, aggre_type)
feat = (F.randn((5, 3)), F.randn((3, 1))) feat = (F.randn((5, 3)), F.randn((3, 1)))
sage = sage.to(ctx) sage = sage.to(ctx)
h = sage(g, feat) h = sage(g, feat)
...@@ -610,7 +610,7 @@ def test_dense_graph_conv(norm_type, g, idtype): ...@@ -610,7 +610,7 @@ def test_dense_graph_conv(norm_type, g, idtype):
assert F.allclose(out_conv, out_dense_conv) assert F.allclose(out_conv, out_dense_conv)
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo', 'bipartite']))
def test_dense_sage_conv(g, idtype): def test_dense_sage_conv(g, idtype):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
ctx = F.ctx() ctx = F.ctx()
...@@ -813,9 +813,9 @@ def test_hetero_conv(agg, idtype): ...@@ -813,9 +813,9 @@ def test_hetero_conv(agg, idtype):
# test with pair input # test with pair input
conv = nn.HeteroGraphConv({ conv = nn.HeteroGraphConv({
'follows': nn.SAGEConv(2, 3, 'mean', allow_zero_in_degree=True), 'follows': nn.SAGEConv(2, 3, 'mean'),
'plays': nn.SAGEConv((2, 4), 4, 'mean', allow_zero_in_degree=True), 'plays': nn.SAGEConv((2, 4), 4, 'mean'),
'sells': nn.SAGEConv(3, 4, 'mean', allow_zero_in_degree=True)}, 'sells': nn.SAGEConv(3, 4, 'mean')},
agg) agg)
conv = conv.to(F.ctx()) conv = conv.to(F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment