Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
330571b6
Unverified
Commit
330571b6
authored
Jul 06, 2023
by
peizhou001
Committed by
GitHub
Jul 06, 2023
Browse files
[Graphbolt] Fix asseration bug (#5952)
parent
8adb53bb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
7 deletions
+23
-7
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+3
-0
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
+8
-4
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
+12
-3
No files found.
graphbolt/src/csc_sampling_graph.cc
View file @
330571b6
...
@@ -346,6 +346,9 @@ torch::Tensor PickByEtype(
...
@@ -346,6 +346,9 @@ torch::Tensor PickByEtype(
const
auto
end
=
offset
+
num_neighbors
;
const
auto
end
=
offset
+
num_neighbors
;
while
(
etype_begin
<
end
)
{
while
(
etype_begin
<
end
)
{
scalar_t
etype
=
type_per_edge_data
[
etype_begin
];
scalar_t
etype
=
type_per_edge_data
[
etype_begin
];
TORCH_CHECK
(
etype
>=
0
&&
etype
<
fanouts
.
size
(),
"Etype values exceed the number of fanouts."
);
int64_t
fanout
=
fanouts
[
etype
];
int64_t
fanout
=
fanouts
[
etype
];
auto
etype_end_it
=
std
::
upper_bound
(
auto
etype_end_it
=
std
::
upper_bound
(
type_per_edge_data
+
etype_begin
,
type_per_edge_data
+
end
,
type_per_edge_data
+
etype_begin
,
type_per_edge_data
+
end
,
...
...
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
View file @
330571b6
...
@@ -270,6 +270,14 @@ class CSCSamplingGraph:
...
@@ -270,6 +270,14 @@ class CSCSamplingGraph:
# Ensure nodes is 1-D tensor.
# Ensure nodes is 1-D tensor.
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
fanouts
.
dim
()
==
1
,
"Fanouts should be 1-D tensor."
assert
fanouts
.
dim
()
==
1
,
"Fanouts should be 1-D tensor."
expected_fanout_len
=
1
if
self
.
metadata
and
self
.
metadata
.
edge_type_to_id
:
expected_fanout_len
=
len
(
self
.
metadata
.
edge_type_to_id
)
assert
len
(
fanouts
)
in
[
expected_fanout_len
,
1
,
],
"Fanouts should have the same number of elements as etypes or
\
should have a length of 1."
if
fanouts
.
size
(
0
)
>
1
:
if
fanouts
.
size
(
0
)
>
1
:
assert
(
assert
(
self
.
type_per_edge
is
not
None
self
.
type_per_edge
is
not
None
...
@@ -279,10 +287,6 @@ class CSCSamplingGraph:
...
@@ -279,10 +287,6 @@ class CSCSamplingGraph:
(
fanouts
>=
0
)
|
(
fanouts
==
-
1
)
(
fanouts
>=
0
)
|
(
fanouts
==
-
1
)
),
"Fanouts should consist of values that are either -1 or
\
),
"Fanouts should consist of values that are either -1 or
\
greater than or equal to 0."
greater than or equal to 0."
if
self
.
metadata
and
self
.
metadata
.
edge_type_to_id
:
assert
len
(
self
.
metadata
.
edge_type_to_id
)
==
fanouts
.
size
(
0
),
"Fanouts should have the same number of elements as etypes."
if
probs_or_mask
is
not
None
:
if
probs_or_mask
is
not
None
:
assert
probs_or_mask
.
dim
()
==
1
,
"Probs should be 1-D tensor."
assert
probs_or_mask
.
dim
()
==
1
,
"Probs should be 1-D tensor."
assert
(
assert
(
...
...
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
View file @
330571b6
...
@@ -407,9 +407,13 @@ def test_sample_neighbors():
...
@@ -407,9 +407,13 @@ def test_sample_neighbors():
type_per_edge
=
torch
.
LongTensor
([
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
1
])
type_per_edge
=
torch
.
LongTensor
([
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
1
])
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
assert
indptr
[
-
1
]
==
len
(
indices
)
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
,
"n3"
:
2
}
etypes
=
{(
"n1"
,
"e1"
,
"n2"
):
0
,
(
"n1"
,
"e2"
,
"n3"
):
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
# Construct CSCSamplingGraph.
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
,
type_per_edge
=
type_per_edge
)
graph
=
gb
.
from_csc
(
indptr
,
indices
,
type_per_edge
=
type_per_edge
,
metadata
=
metadata
)
# Generate subgraph via sample neighbors.
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
...
@@ -467,9 +471,14 @@ def test_sample_neighbors_fanouts(fanouts, expected_sampled_num):
...
@@ -467,9 +471,14 @@ def test_sample_neighbors_fanouts(fanouts, expected_sampled_num):
type_per_edge
=
torch
.
LongTensor
([
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
1
])
type_per_edge
=
torch
.
LongTensor
([
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
1
])
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
assert
indptr
[
-
1
]
==
len
(
indices
)
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
,
"n3"
:
2
}
etypes
=
{(
"n1"
,
"e1"
,
"n2"
):
0
,
(
"n1"
,
"e2"
,
"n3"
):
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
# Construct CSCSamplingGraph.
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
,
type_per_edge
=
type_per_edge
)
graph
=
gb
.
from_csc
(
indptr
,
indices
,
type_per_edge
=
type_per_edge
,
metadata
=
metadata
)
# Generate subgraph via sample neighbors.
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment