Unverified Commit 0a51dc54 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix dsttype in GraphSAGE minibatch model (#1371)

* fix for new ntype API for blocks

* adding two new interfaces
parent 635dfb4a
...@@ -64,7 +64,7 @@ class SAGE(nn.Module): ...@@ -64,7 +64,7 @@ class SAGE(nn.Module):
# appropriate nodes on the LHS. # appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D) # would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_nodes(block.dsttype)] h_dst = h[:block.number_of_dst_nodes()]
# Then we compute the updated representation on the RHS. # Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D) # The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst)) h = layer(block, (h, h_dst))
...@@ -98,7 +98,7 @@ class SAGE(nn.Module): ...@@ -98,7 +98,7 @@ class SAGE(nn.Module):
input_nodes = block.srcdata[dgl.NID] input_nodes = block.srcdata[dgl.NID]
h = x[input_nodes].to(device) h = x[input_nodes].to(device)
h_dst = h[:block.number_of_nodes(block.dsttype)] h_dst = h[:block.number_of_dst_nodes()]
h = layer(block, (h, h_dst)) h = layer(block, (h, h_dst))
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = self.activation(h) h = self.activation(h)
......
...@@ -65,7 +65,7 @@ class SAGE(nn.Module): ...@@ -65,7 +65,7 @@ class SAGE(nn.Module):
# appropriate nodes on the LHS. # appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D) # would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_nodes(block.dsttype)] h_dst = h[:block.number_of_dst_nodes()]
# Then we compute the updated representation on the RHS. # Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D) # The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst)) h = layer(block, (h, h_dst))
...@@ -99,7 +99,7 @@ class SAGE(nn.Module): ...@@ -99,7 +99,7 @@ class SAGE(nn.Module):
input_nodes = block.srcdata[dgl.NID] input_nodes = block.srcdata[dgl.NID]
h = x[input_nodes].to(device) h = x[input_nodes].to(device)
h_dst = h[:block.number_of_nodes(block.dsttype)] h_dst = h[:block.number_of_dst_nodes()]
h = layer(block, (h, h_dst)) h = layer(block, (h, h_dst))
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = self.activation(h) h = self.activation(h)
......
...@@ -568,7 +568,7 @@ class DGLHeteroGraph(object): ...@@ -568,7 +568,7 @@ class DGLHeteroGraph(object):
if len(self._srctypes_invmap) != 1: if len(self._srctypes_invmap) != 1:
raise DGLError('SRC node type name must be specified if there are more than one ' raise DGLError('SRC node type name must be specified if there are more than one '
'SRC node types.') 'SRC node types.')
return 0 return next(iter(self._srctypes_invmap.values()))
ntid = self._srctypes_invmap.get(ntype, None) ntid = self._srctypes_invmap.get(ntype, None)
if ntid is None: if ntid is None:
raise DGLError('SRC node type "{}" does not exist.'.format(ntype)) raise DGLError('SRC node type "{}" does not exist.'.format(ntype))
...@@ -593,7 +593,7 @@ class DGLHeteroGraph(object): ...@@ -593,7 +593,7 @@ class DGLHeteroGraph(object):
if len(self._dsttypes_invmap) != 1: if len(self._dsttypes_invmap) != 1:
raise DGLError('DST node type name must be specified if there are more than one ' raise DGLError('DST node type name must be specified if there are more than one '
'DST node types.') 'DST node types.')
return 0 return next(iter(self._dsttypes_invmap.values()))
ntid = self._dsttypes_invmap.get(ntype, None) ntid = self._dsttypes_invmap.get(ntype, None)
if ntid is None: if ntid is None:
raise DGLError('DST node type "{}" does not exist.'.format(ntype)) raise DGLError('DST node type "{}" does not exist.'.format(ntype))
...@@ -972,6 +972,62 @@ class DGLHeteroGraph(object): ...@@ -972,6 +972,62 @@ class DGLHeteroGraph(object):
""" """
return self._graph.number_of_nodes(self.get_ntype_id(ntype)) return self._graph.number_of_nodes(self.get_ntype_id(ntype))
def number_of_src_nodes(self, ntype=None):
"""Return the number of nodes of the given SRC node type in the heterograph.
The heterograph is usually a unidirectional bipartite graph.
Parameters
----------
ntype : str, optional
Node type.
If omitted, there should be only one node type in the SRC category.
Returns
-------
int
The number of nodes
Examples
--------
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g.number_of_src_nodes('user')
2
>>> g.number_of_src_nodes()
2
>>> g.number_of_nodes('user')
2
"""
return self._graph.number_of_nodes(self.get_ntype_id_from_src(ntype))
def number_of_dst_nodes(self, ntype=None):
"""Return the number of nodes of the given DST node type in the heterograph.
The heterograph is usually a unidirectional bipartite graph.
Parameters
----------
ntype : str, optional
Node type.
If omitted, there should be only one node type in the DST category.
Returns
-------
int
The number of nodes
Examples
--------
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g.number_of_dst_nodes('game')
3
>>> g.number_of_dst_nodes()
3
>>> g.number_of_nodes('game')
3
"""
return self._graph.number_of_nodes(self.get_ntype_id_from_dst(ntype))
def number_of_edges(self, etype=None): def number_of_edges(self, etype=None):
"""Return the number of edges of the given type in the heterograph. """Return the number of edges of the given type in the heterograph.
......
...@@ -1482,6 +1482,10 @@ def test_bipartite(): ...@@ -1482,6 +1482,10 @@ def test_bipartite():
assert g1.dsttypes == ['B'] assert g1.dsttypes == ['B']
assert g1.number_of_nodes('A') == 2 assert g1.number_of_nodes('A') == 2
assert g1.number_of_nodes('B') == 6 assert g1.number_of_nodes('B') == 6
assert g1.number_of_src_nodes('A') == 2
assert g1.number_of_src_nodes() == 2
assert g1.number_of_dst_nodes('B') == 6
assert g1.number_of_dst_nodes() == 6
assert g1.number_of_edges() == 3 assert g1.number_of_edges() == 3
g1.srcdata['h'] = F.randn((2, 5)) g1.srcdata['h'] = F.randn((2, 5))
assert F.array_equal(g1.srcnodes['A'].data['h'], g1.srcdata['h']) assert F.array_equal(g1.srcnodes['A'].data['h'], g1.srcdata['h'])
...@@ -1501,6 +1505,10 @@ def test_bipartite(): ...@@ -1501,6 +1505,10 @@ def test_bipartite():
assert g3.number_of_nodes('A') == 2 assert g3.number_of_nodes('A') == 2
assert g3.number_of_nodes('B') == 6 assert g3.number_of_nodes('B') == 6
assert g3.number_of_nodes('C') == 1 assert g3.number_of_nodes('C') == 1
assert g3.number_of_src_nodes('A') == 2
assert g3.number_of_src_nodes() == 2
assert g3.number_of_dst_nodes('B') == 6
assert g3.number_of_dst_nodes('C') == 1
g3.srcdata['h'] = F.randn((2, 5)) g3.srcdata['h'] = F.randn((2, 5))
assert F.array_equal(g3.srcnodes['A'].data['h'], g3.srcdata['h']) assert F.array_equal(g3.srcnodes['A'].data['h'], g3.srcdata['h'])
assert F.array_equal(g3.nodes['A'].data['h'], g3.srcdata['h']) assert F.array_equal(g3.nodes['A'].data['h'], g3.srcdata['h'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment