Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
0a51dc54
Unverified
Commit
0a51dc54
authored
Mar 18, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Mar 18, 2020
Browse files
[Bug] Fix dsttype in GraphSAGE minibatch model (#1371)
* fix for new ntype API for blocks * adding two new interfaces
parent
635dfb4a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
70 additions
and
6 deletions
+70
-6
examples/pytorch/graphsage/train_sampling.py
examples/pytorch/graphsage/train_sampling.py
+2
-2
examples/pytorch/graphsage/train_sampling_multi_gpu.py
examples/pytorch/graphsage/train_sampling_multi_gpu.py
+2
-2
python/dgl/heterograph.py
python/dgl/heterograph.py
+58
-2
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+8
-0
No files found.
examples/pytorch/graphsage/train_sampling.py
View file @
0a51dc54
...
@@ -64,7 +64,7 @@ class SAGE(nn.Module):
...
@@ -64,7 +64,7 @@ class SAGE(nn.Module):
# appropriate nodes on the LHS.
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
# would be (num_nodes_RHS, D)
h_dst
=
h
[:
block
.
number_of_nodes
(
block
.
dsttype
)]
h_dst
=
h
[:
block
.
number_of_
dst_
nodes
()]
# Then we compute the updated representation on the RHS.
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
# The shape of h now becomes (num_nodes_RHS, D)
h
=
layer
(
block
,
(
h
,
h_dst
))
h
=
layer
(
block
,
(
h
,
h_dst
))
...
@@ -98,7 +98,7 @@ class SAGE(nn.Module):
...
@@ -98,7 +98,7 @@ class SAGE(nn.Module):
input_nodes
=
block
.
srcdata
[
dgl
.
NID
]
input_nodes
=
block
.
srcdata
[
dgl
.
NID
]
h
=
x
[
input_nodes
].
to
(
device
)
h
=
x
[
input_nodes
].
to
(
device
)
h_dst
=
h
[:
block
.
number_of_nodes
(
block
.
dsttype
)]
h_dst
=
h
[:
block
.
number_of_
dst_
nodes
()]
h
=
layer
(
block
,
(
h
,
h_dst
))
h
=
layer
(
block
,
(
h
,
h_dst
))
if
l
!=
len
(
self
.
layers
)
-
1
:
if
l
!=
len
(
self
.
layers
)
-
1
:
h
=
self
.
activation
(
h
)
h
=
self
.
activation
(
h
)
...
...
examples/pytorch/graphsage/train_sampling_multi_gpu.py
View file @
0a51dc54
...
@@ -65,7 +65,7 @@ class SAGE(nn.Module):
...
@@ -65,7 +65,7 @@ class SAGE(nn.Module):
# appropriate nodes on the LHS.
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
# would be (num_nodes_RHS, D)
h_dst
=
h
[:
block
.
number_of_nodes
(
block
.
dsttype
)]
h_dst
=
h
[:
block
.
number_of_
dst_
nodes
()]
# Then we compute the updated representation on the RHS.
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
# The shape of h now becomes (num_nodes_RHS, D)
h
=
layer
(
block
,
(
h
,
h_dst
))
h
=
layer
(
block
,
(
h
,
h_dst
))
...
@@ -99,7 +99,7 @@ class SAGE(nn.Module):
...
@@ -99,7 +99,7 @@ class SAGE(nn.Module):
input_nodes
=
block
.
srcdata
[
dgl
.
NID
]
input_nodes
=
block
.
srcdata
[
dgl
.
NID
]
h
=
x
[
input_nodes
].
to
(
device
)
h
=
x
[
input_nodes
].
to
(
device
)
h_dst
=
h
[:
block
.
number_of_nodes
(
block
.
dsttype
)]
h_dst
=
h
[:
block
.
number_of_
dst_
nodes
()]
h
=
layer
(
block
,
(
h
,
h_dst
))
h
=
layer
(
block
,
(
h
,
h_dst
))
if
l
!=
len
(
self
.
layers
)
-
1
:
if
l
!=
len
(
self
.
layers
)
-
1
:
h
=
self
.
activation
(
h
)
h
=
self
.
activation
(
h
)
...
...
python/dgl/heterograph.py
View file @
0a51dc54
...
@@ -568,7 +568,7 @@ class DGLHeteroGraph(object):
...
@@ -568,7 +568,7 @@ class DGLHeteroGraph(object):
if
len
(
self
.
_srctypes_invmap
)
!=
1
:
if
len
(
self
.
_srctypes_invmap
)
!=
1
:
raise
DGLError
(
'SRC node type name must be specified if there are more than one '
raise
DGLError
(
'SRC node type name must be specified if there are more than one '
'SRC node types.'
)
'SRC node types.'
)
return
0
return
next
(
iter
(
self
.
_srctypes_invmap
.
values
()))
ntid
=
self
.
_srctypes_invmap
.
get
(
ntype
,
None
)
ntid
=
self
.
_srctypes_invmap
.
get
(
ntype
,
None
)
if
ntid
is
None
:
if
ntid
is
None
:
raise
DGLError
(
'SRC node type "{}" does not exist.'
.
format
(
ntype
))
raise
DGLError
(
'SRC node type "{}" does not exist.'
.
format
(
ntype
))
...
@@ -593,7 +593,7 @@ class DGLHeteroGraph(object):
...
@@ -593,7 +593,7 @@ class DGLHeteroGraph(object):
if
len
(
self
.
_dsttypes_invmap
)
!=
1
:
if
len
(
self
.
_dsttypes_invmap
)
!=
1
:
raise
DGLError
(
'DST node type name must be specified if there are more than one '
raise
DGLError
(
'DST node type name must be specified if there are more than one '
'DST node types.'
)
'DST node types.'
)
return
0
return
next
(
iter
(
self
.
_dsttypes_invmap
.
values
()))
ntid
=
self
.
_dsttypes_invmap
.
get
(
ntype
,
None
)
ntid
=
self
.
_dsttypes_invmap
.
get
(
ntype
,
None
)
if
ntid
is
None
:
if
ntid
is
None
:
raise
DGLError
(
'DST node type "{}" does not exist.'
.
format
(
ntype
))
raise
DGLError
(
'DST node type "{}" does not exist.'
.
format
(
ntype
))
...
@@ -972,6 +972,62 @@ class DGLHeteroGraph(object):
...
@@ -972,6 +972,62 @@ class DGLHeteroGraph(object):
"""
"""
return
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id
(
ntype
))
return
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id
(
ntype
))
def
number_of_src_nodes
(
self
,
ntype
=
None
):
"""Return the number of nodes of the given SRC node type in the heterograph.
The heterograph is usually a unidirectional bipartite graph.
Parameters
----------
ntype : str, optional
Node type.
If omitted, there should be only one node type in the SRC category.
Returns
-------
int
The number of nodes
Examples
--------
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g.number_of_src_nodes('user')
2
>>> g.number_of_src_nodes()
2
>>> g.number_of_nodes('user')
2
"""
return
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id_from_src
(
ntype
))
def
number_of_dst_nodes
(
self
,
ntype
=
None
):
"""Return the number of nodes of the given DST node type in the heterograph.
The heterograph is usually a unidirectional bipartite graph.
Parameters
----------
ntype : str, optional
Node type.
If omitted, there should be only one node type in the DST category.
Returns
-------
int
The number of nodes
Examples
--------
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g.number_of_dst_nodes('game')
3
>>> g.number_of_dst_nodes()
3
>>> g.number_of_nodes('game')
3
"""
return
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id_from_dst
(
ntype
))
def
number_of_edges
(
self
,
etype
=
None
):
def
number_of_edges
(
self
,
etype
=
None
):
"""Return the number of edges of the given type in the heterograph.
"""Return the number of edges of the given type in the heterograph.
...
...
tests/compute/test_heterograph.py
View file @
0a51dc54
...
@@ -1482,6 +1482,10 @@ def test_bipartite():
...
@@ -1482,6 +1482,10 @@ def test_bipartite():
assert
g1
.
dsttypes
==
[
'B'
]
assert
g1
.
dsttypes
==
[
'B'
]
assert
g1
.
number_of_nodes
(
'A'
)
==
2
assert
g1
.
number_of_nodes
(
'A'
)
==
2
assert
g1
.
number_of_nodes
(
'B'
)
==
6
assert
g1
.
number_of_nodes
(
'B'
)
==
6
assert
g1
.
number_of_src_nodes
(
'A'
)
==
2
assert
g1
.
number_of_src_nodes
()
==
2
assert
g1
.
number_of_dst_nodes
(
'B'
)
==
6
assert
g1
.
number_of_dst_nodes
()
==
6
assert
g1
.
number_of_edges
()
==
3
assert
g1
.
number_of_edges
()
==
3
g1
.
srcdata
[
'h'
]
=
F
.
randn
((
2
,
5
))
g1
.
srcdata
[
'h'
]
=
F
.
randn
((
2
,
5
))
assert
F
.
array_equal
(
g1
.
srcnodes
[
'A'
].
data
[
'h'
],
g1
.
srcdata
[
'h'
])
assert
F
.
array_equal
(
g1
.
srcnodes
[
'A'
].
data
[
'h'
],
g1
.
srcdata
[
'h'
])
...
@@ -1501,6 +1505,10 @@ def test_bipartite():
...
@@ -1501,6 +1505,10 @@ def test_bipartite():
assert
g3
.
number_of_nodes
(
'A'
)
==
2
assert
g3
.
number_of_nodes
(
'A'
)
==
2
assert
g3
.
number_of_nodes
(
'B'
)
==
6
assert
g3
.
number_of_nodes
(
'B'
)
==
6
assert
g3
.
number_of_nodes
(
'C'
)
==
1
assert
g3
.
number_of_nodes
(
'C'
)
==
1
assert
g3
.
number_of_src_nodes
(
'A'
)
==
2
assert
g3
.
number_of_src_nodes
()
==
2
assert
g3
.
number_of_dst_nodes
(
'B'
)
==
6
assert
g3
.
number_of_dst_nodes
(
'C'
)
==
1
g3
.
srcdata
[
'h'
]
=
F
.
randn
((
2
,
5
))
g3
.
srcdata
[
'h'
]
=
F
.
randn
((
2
,
5
))
assert
F
.
array_equal
(
g3
.
srcnodes
[
'A'
].
data
[
'h'
],
g3
.
srcdata
[
'h'
])
assert
F
.
array_equal
(
g3
.
srcnodes
[
'A'
].
data
[
'h'
],
g3
.
srcdata
[
'h'
])
assert
F
.
array_equal
(
g3
.
nodes
[
'A'
].
data
[
'h'
],
g3
.
srcdata
[
'h'
])
assert
F
.
array_equal
(
g3
.
nodes
[
'A'
].
data
[
'h'
],
g3
.
srcdata
[
'h'
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment