"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "db3757aa06eee0a4ae4c8dc13d65d1760ac26b54"
Unverified Commit 83667741 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Support canonical edge types in HeteroGraphConv (#4440)



* fix

* Update hetero.py

* why did i remove this
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent 5ebd3bf0
...@@ -122,7 +122,13 @@ class HeteroGraphConv(nn.Module): ...@@ -122,7 +122,13 @@ class HeteroGraphConv(nn.Module):
""" """
def __init__(self, mods, aggregate='sum'): def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__() super(HeteroGraphConv, self).__init__()
self.mod_dict = mods
mods = {str(k): v for k, v in mods.items()}
# Register as child modules
self.mods = nn.ModuleDict(mods) self.mods = nn.ModuleDict(mods)
# PyTorch ModuleDict doesn't have get() method, so I have to store two
# dictionaries so that I can index with both canonical edge type and
# edge type with the get() method.
# Do not break if graph has 0-in-degree nodes. # Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph. # Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items(): for _, v in self.mods.items():
...@@ -134,6 +140,16 @@ class HeteroGraphConv(nn.Module): ...@@ -134,6 +140,16 @@ class HeteroGraphConv(nn.Module):
else: else:
self.agg_fn = aggregate self.agg_fn = aggregate
def _get_module(self, etype):
mod = self.mod_dict.get(etype, None)
if mod is not None:
return mod
if isinstance(etype, tuple):
# etype is canonical
_, etype, _ = etype
return self.mod_dict[etype]
raise KeyError('Cannot find module with edge type %s' % etype)
def forward(self, g, inputs, mod_args=None, mod_kwargs=None): def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
"""Forward computation """Forward computation
...@@ -171,7 +187,7 @@ class HeteroGraphConv(nn.Module): ...@@ -171,7 +187,7 @@ class HeteroGraphConv(nn.Module):
rel_graph = g[stype, etype, dtype] rel_graph = g[stype, etype, dtype]
if stype not in src_inputs or dtype not in dst_inputs: if stype not in src_inputs or dtype not in dst_inputs:
continue continue
dstdata = self.mods[etype]( dstdata = self._get_module((stype, etype, dtype))(
rel_graph, rel_graph,
(src_inputs[stype], dst_inputs[dtype]), (src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()), *mod_args.get(etype, ()),
...@@ -182,7 +198,7 @@ class HeteroGraphConv(nn.Module): ...@@ -182,7 +198,7 @@ class HeteroGraphConv(nn.Module):
rel_graph = g[stype, etype, dtype] rel_graph = g[stype, etype, dtype]
if stype not in inputs: if stype not in inputs:
continue continue
dstdata = self.mods[etype]( dstdata = self._get_module((stype, etype, dtype))(
rel_graph, rel_graph,
(inputs[stype], inputs[dtype]), (inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()), *mod_args.get(etype, ()),
......
...@@ -1146,17 +1146,26 @@ def myagg(alist, dsttype): ...@@ -1146,17 +1146,26 @@ def myagg(alist, dsttype):
@parametrize_idtype @parametrize_idtype
@pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg]) @pytest.mark.parametrize('agg', ['sum', 'max', 'min', 'mean', 'stack', myagg])
def test_hetero_conv(agg, idtype): @pytest.mark.parametrize('canonical_keys', [False, True])
def test_hetero_conv(agg, idtype, canonical_keys):
g = dgl.heterograph({ g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]), ('user', 'follows', 'user'): ([0, 0, 2, 1], [1, 2, 1, 3]),
('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]), ('user', 'plays', 'game'): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])}, ('store', 'sells', 'game'): ([0, 0, 1, 1], [0, 3, 1, 2])},
idtype=idtype, device=F.ctx()) idtype=idtype, device=F.ctx())
conv = nn.HeteroGraphConv({ if not canonical_keys:
'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True), conv = nn.HeteroGraphConv({
'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True), 'follows': nn.GraphConv(2, 3, allow_zero_in_degree=True),
'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)}, 'plays': nn.GraphConv(2, 4, allow_zero_in_degree=True),
agg) 'sells': nn.GraphConv(3, 4, allow_zero_in_degree=True)},
agg)
else:
conv = nn.HeteroGraphConv({
('user', 'follows', 'user'): nn.GraphConv(2, 3, allow_zero_in_degree=True),
('user', 'plays', 'game'): nn.GraphConv(2, 4, allow_zero_in_degree=True),
('store', 'sells', 'game'): nn.GraphConv(3, 4, allow_zero_in_degree=True)},
agg)
conv = conv.to(F.ctx()) conv = conv.to(F.ctx())
# test pickle # test pickle
...@@ -1621,4 +1630,4 @@ def test_dgn_conv(in_size, out_size, aggregators, scalers, delta, ...@@ -1621,4 +1630,4 @@ def test_dgn_conv(in_size, out_size, aggregators, scalers, delta,
aggregators_non_eig = [aggr for aggr in aggregators if not aggr.startswith('dir')] aggregators_non_eig = [aggr for aggr in aggregators if not aggr.startswith('dir')]
model = nn.DGNConv(in_size, out_size, aggregators_non_eig, scalers, delta, dropout, model = nn.DGNConv(in_size, out_size, aggregators_non_eig, scalers, delta, dropout,
num_towers, edge_feat_size, residual).to(dev) num_towers, edge_feat_size, residual).to(dev)
model(g, h, edge_feat=e) model(g, h, edge_feat=e)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment