Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
24dc71fc
Unverified
Commit
24dc71fc
authored
Apr 03, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Apr 03, 2020
Browse files
[BUG] Fix #1409 (#1411)
* [BUG] Fix #1409 * fix test
parent
d3560b71
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
4 deletions
+52
-4
src/graph/unit_graph.cc
src/graph/unit_graph.cc
+40
-4
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+12
-0
No files found.
src/graph/unit_graph.cc
View file @
24dc71fc
...
@@ -999,9 +999,27 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const
...
@@ -999,9 +999,27 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const
// We prefer to generate a subgraph from out-csr.
// We prefer to generate a subgraph from out-csr.
SparseFormat
fmt
=
SelectFormat
(
SparseFormat
::
kCSR
);
SparseFormat
fmt
=
SelectFormat
(
SparseFormat
::
kCSR
);
HeteroSubgraph
sg
=
GetFormat
(
fmt
)
->
VertexSubgraph
(
vids
);
HeteroSubgraph
sg
=
GetFormat
(
fmt
)
->
VertexSubgraph
(
vids
);
CSRPtr
subcsr
=
std
::
dynamic_pointer_cast
<
CSR
>
(
sg
.
graph
);
HeteroSubgraph
ret
;
HeteroSubgraph
ret
;
ret
.
graph
=
HeteroGraphPtr
(
new
UnitGraph
(
meta_graph
(),
nullptr
,
subcsr
,
nullptr
));
CSRPtr
subcsr
=
nullptr
;
CSRPtr
subcsc
=
nullptr
;
COOPtr
subcoo
=
nullptr
;
switch
(
fmt
)
{
case
SparseFormat
::
kCSR
:
subcsr
=
std
::
dynamic_pointer_cast
<
CSR
>
(
sg
.
graph
);
break
;
case
SparseFormat
::
kCSC
:
subcsc
=
std
::
dynamic_pointer_cast
<
CSR
>
(
sg
.
graph
);
break
;
case
SparseFormat
::
kCOO
:
subcoo
=
std
::
dynamic_pointer_cast
<
COO
>
(
sg
.
graph
);
break
;
default:
LOG
(
FATAL
)
<<
"[BUG] unsupported format "
<<
static_cast
<
int
>
(
fmt
);
return
ret
;
}
ret
.
graph
=
HeteroGraphPtr
(
new
UnitGraph
(
meta_graph
(),
subcsc
,
subcsr
,
subcoo
));
ret
.
induced_vertices
=
std
::
move
(
sg
.
induced_vertices
);
ret
.
induced_vertices
=
std
::
move
(
sg
.
induced_vertices
);
ret
.
induced_edges
=
std
::
move
(
sg
.
induced_edges
);
ret
.
induced_edges
=
std
::
move
(
sg
.
induced_edges
);
return
ret
;
return
ret
;
...
@@ -1011,9 +1029,27 @@ HeteroSubgraph UnitGraph::EdgeSubgraph(
...
@@ -1011,9 +1029,27 @@ HeteroSubgraph UnitGraph::EdgeSubgraph(
const
std
::
vector
<
IdArray
>&
eids
,
bool
preserve_nodes
)
const
{
const
std
::
vector
<
IdArray
>&
eids
,
bool
preserve_nodes
)
const
{
SparseFormat
fmt
=
SelectFormat
(
SparseFormat
::
kCOO
);
SparseFormat
fmt
=
SelectFormat
(
SparseFormat
::
kCOO
);
auto
sg
=
GetFormat
(
fmt
)
->
EdgeSubgraph
(
eids
,
preserve_nodes
);
auto
sg
=
GetFormat
(
fmt
)
->
EdgeSubgraph
(
eids
,
preserve_nodes
);
COOPtr
subcoo
=
std
::
dynamic_pointer_cast
<
COO
>
(
sg
.
graph
);
HeteroSubgraph
ret
;
HeteroSubgraph
ret
;
ret
.
graph
=
HeteroGraphPtr
(
new
UnitGraph
(
meta_graph
(),
nullptr
,
nullptr
,
subcoo
));
CSRPtr
subcsr
=
nullptr
;
CSRPtr
subcsc
=
nullptr
;
COOPtr
subcoo
=
nullptr
;
switch
(
fmt
)
{
case
SparseFormat
::
kCSR
:
subcsr
=
std
::
dynamic_pointer_cast
<
CSR
>
(
sg
.
graph
);
break
;
case
SparseFormat
::
kCSC
:
subcsc
=
std
::
dynamic_pointer_cast
<
CSR
>
(
sg
.
graph
);
break
;
case
SparseFormat
::
kCOO
:
subcoo
=
std
::
dynamic_pointer_cast
<
COO
>
(
sg
.
graph
);
break
;
default:
LOG
(
FATAL
)
<<
"[BUG] unsupported format "
<<
static_cast
<
int
>
(
fmt
);
return
ret
;
}
ret
.
graph
=
HeteroGraphPtr
(
new
UnitGraph
(
meta_graph
(),
subcsc
,
subcsr
,
subcoo
));
ret
.
induced_vertices
=
std
::
move
(
sg
.
induced_vertices
);
ret
.
induced_vertices
=
std
::
move
(
sg
.
induced_vertices
);
ret
.
induced_edges
=
std
::
move
(
sg
.
induced_edges
);
ret
.
induced_edges
=
std
::
move
(
sg
.
induced_edges
);
return
ret
;
return
ret
;
...
...
tests/compute/test_heterograph.py
View file @
24dc71fc
...
@@ -940,6 +940,18 @@ def test_subgraph():
...
@@ -940,6 +940,18 @@ def test_subgraph():
sg5
=
g
.
edge_type_subgraph
([
'follows'
,
'plays'
,
'wishes'
])
sg5
=
g
.
edge_type_subgraph
([
'follows'
,
'plays'
,
'wishes'
])
_check_typed_subgraph1
(
g
,
sg5
)
_check_typed_subgraph1
(
g
,
sg5
)
# Test for restricted format
for
fmt
in
[
'csr'
,
'csc'
,
'coo'
]:
g
=
dgl
.
graph
([(
0
,
1
),
(
1
,
2
)],
restrict_format
=
fmt
)
sg
=
g
.
subgraph
({
g
.
ntypes
[
0
]:
[
1
,
0
]})
nids
=
F
.
asnumpy
(
sg
.
ndata
[
dgl
.
NID
])
assert
np
.
array_equal
(
nids
,
np
.
array
([
1
,
0
]))
src
,
dst
=
sg
.
all_edges
(
order
=
'eid'
)
src
=
F
.
asnumpy
(
src
)
dst
=
F
.
asnumpy
(
dst
)
assert
np
.
array_equal
(
src
,
np
.
array
([
1
]))
assert
np
.
array_equal
(
dst
,
np
.
array
([
0
]))
def
test_apply
():
def
test_apply
():
def
node_udf
(
nodes
):
def
node_udf
(
nodes
):
return
{
'h'
:
nodes
.
data
[
'h'
]
*
2
}
return
{
'h'
:
nodes
.
data
[
'h'
]
*
2
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment