Unverified Commit 83667741 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Support canonical edge types in HeteroGraphConv (#4440)



* fix

* Update hetero.py

* why did i remove this
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent 5ebd3bf0
......@@ -122,7 +122,13 @@ class HeteroGraphConv(nn.Module):
"""
def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__()
self.mod_dict = mods
mods = {str(k): v for k, v in mods.items()}
# Register as child modules
self.mods = nn.ModuleDict(mods)
# PyTorch ModuleDict doesn't have get() method, so I have to store two
# dictionaries so that I can index with both canonical edge type and
# edge type with the get() method.
# Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items():
......@@ -134,6 +140,16 @@ class HeteroGraphConv(nn.Module):
else:
self.agg_fn = aggregate
def _get_module(self, etype):
mod = self.mod_dict.get(etype, None)
if mod is not None:
return mod
if isinstance(etype, tuple):
# etype is canonical
_, etype, _ = etype
return self.mod_dict[etype]
raise KeyError('Cannot find module with edge type %s' % etype)
def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
"""Forward computation
......@@ -171,7 +187,7 @@ class HeteroGraphConv(nn.Module):
rel_graph = g[stype, etype, dtype]
if stype not in src_inputs or dtype not in dst_inputs:
continue
dstdata = self.mods[etype](
dstdata = self._get_module((stype, etype, dtype))(
rel_graph,
(src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()),
......@@ -182,7 +198,7 @@ class HeteroGraphConv(nn.Module):
rel_graph = g[stype, etype, dtype]
if stype not in inputs:
continue
dstdata = self.mods[etype](
dstdata = self._get_module((stype, etype, dtype))(
rel_graph,
(inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()),
......
......@@ -1146,17 +1146,26 @@ def myagg(alist, dsttype):
@parametrize_idtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype):
@pytest.mark.parametrize('canonical_keys', [False, True])
def test_hetero_conv(agg, idtype, canonical_keys):
g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
idtype=idtype, device=F.ctx())
conv = nn.HeteroGraphConv({
'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
agg)
if not canonical_keys:
conv = nn.HeteroGraphConv({
'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
agg)
else:
conv = nn.HeteroGraphConv({
('user', 'follows', 'user'): nn.GraphConv(2, 3, allow_zero_in_degree=True),
('user', 'plays', 'game'): nn.GraphConv(2, 4, allow_zero_in_degree=True),
('store', 'sells', 'game'): nn.GraphConv(3, 4, allow_zero_in_degree=True)},
agg)
conv = conv.to(F.ctx())
# test pickle
......@@ -1621,4 +1630,4 @@ def test_dgn_conv(in_size, out_size, aggregators, scalers, delta,
aggregators_non_eig = [aggr for aggr in aggregators if not aggr.startswith('dir')]
model = nn.DGNConv(in_size, out_size, aggregators_non_eig, scalers, delta, dropout,
num_towers, edge_feat_size, residual).to(dev)
model(g, h, edge_feat=e)
\ No newline at end of file
model(g, h, edge_feat=e)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment