Unverified Commit 0a51dc54 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix dsttype in GraphSAGE minibatch model (#1371)

* fix for new ntype API for blocks

* adding two new interfaces
parent 635dfb4a
......@@ -64,7 +64,7 @@ class SAGE(nn.Module):
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_nodes(block.dsttype)]
h_dst = h[:block.number_of_dst_nodes()]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
......@@ -98,7 +98,7 @@ class SAGE(nn.Module):
input_nodes = block.srcdata[dgl.NID]
h = x[input_nodes].to(device)
h_dst = h[:block.number_of_nodes(block.dsttype)]
h_dst = h[:block.number_of_dst_nodes()]
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1:
h = self.activation(h)
......
......@@ -65,7 +65,7 @@ class SAGE(nn.Module):
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_nodes(block.dsttype)]
h_dst = h[:block.number_of_dst_nodes()]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
......@@ -99,7 +99,7 @@ class SAGE(nn.Module):
input_nodes = block.srcdata[dgl.NID]
h = x[input_nodes].to(device)
h_dst = h[:block.number_of_nodes(block.dsttype)]
h_dst = h[:block.number_of_dst_nodes()]
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1:
h = self.activation(h)
......
......@@ -568,7 +568,7 @@ class DGLHeteroGraph(object):
if len(self._srctypes_invmap) != 1:
raise DGLError('SRC node type name must be specified if there are more than one '
'SRC node types.')
return 0
return next(iter(self._srctypes_invmap.values()))
ntid = self._srctypes_invmap.get(ntype, None)
if ntid is None:
raise DGLError('SRC node type "{}" does not exist.'.format(ntype))
......@@ -593,7 +593,7 @@ class DGLHeteroGraph(object):
if len(self._dsttypes_invmap) != 1:
raise DGLError('DST node type name must be specified if there are more than one '
'DST node types.')
return 0
return next(iter(self._dsttypes_invmap.values()))
ntid = self._dsttypes_invmap.get(ntype, None)
if ntid is None:
raise DGLError('DST node type "{}" does not exist.'.format(ntype))
......@@ -972,6 +972,62 @@ class DGLHeteroGraph(object):
"""
return self._graph.number_of_nodes(self.get_ntype_id(ntype))
def number_of_src_nodes(self, ntype=None):
"""Return the number of nodes of the given SRC node type in the heterograph.
The heterograph is usually a unidirectional bipartite graph.
Parameters
----------
ntype : str, optional
Node type.
If omitted, there should be only one node type in the SRC category.
Returns
-------
int
The number of nodes
Examples
--------
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g.number_of_src_nodes('user')
2
>>> g.number_of_src_nodes()
2
>>> g.number_of_nodes('user')
2
"""
return self._graph.number_of_nodes(self.get_ntype_id_from_src(ntype))
def number_of_dst_nodes(self, ntype=None):
"""Return the number of nodes of the given DST node type in the heterograph.
The heterograph is usually a unidirectional bipartite graph.
Parameters
----------
ntype : str, optional
Node type.
If omitted, there should be only one node type in the DST category.
Returns
-------
int
The number of nodes
Examples
--------
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g.number_of_dst_nodes('game')
3
>>> g.number_of_dst_nodes()
3
>>> g.number_of_nodes('game')
3
"""
return self._graph.number_of_nodes(self.get_ntype_id_from_dst(ntype))
def number_of_edges(self, etype=None):
"""Return the number of edges of the given type in the heterograph.
......
......@@ -1482,6 +1482,10 @@ def test_bipartite():
assert g1.dsttypes == ['B']
assert g1.number_of_nodes('A') == 2
assert g1.number_of_nodes('B') == 6
assert g1.number_of_src_nodes('A') == 2
assert g1.number_of_src_nodes() == 2
assert g1.number_of_dst_nodes('B') == 6
assert g1.number_of_dst_nodes() == 6
assert g1.number_of_edges() == 3
g1.srcdata['h'] = F.randn((2, 5))
assert F.array_equal(g1.srcnodes['A'].data['h'], g1.srcdata['h'])
......@@ -1501,6 +1505,10 @@ def test_bipartite():
assert g3.number_of_nodes('A') == 2
assert g3.number_of_nodes('B') == 6
assert g3.number_of_nodes('C') == 1
assert g3.number_of_src_nodes('A') == 2
assert g3.number_of_src_nodes() == 2
assert g3.number_of_dst_nodes('B') == 6
assert g3.number_of_dst_nodes('C') == 1
g3.srcdata['h'] = F.randn((2, 5))
assert F.array_equal(g3.srcnodes['A'].data['h'], g3.srcdata['h'])
assert F.array_equal(g3.nodes['A'].data['h'], g3.srcdata['h'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment