Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c6abbb13
"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "57d2f31f20f124a5fd93077060d2f189faed0eb8"
Unverified
Commit
c6abbb13
authored
Dec 19, 2023
by
Mingbang Wang
Committed by
GitHub
Dec 19, 2023
Browse files
[GraphBolt] Add a check assertion for data type of nodes (#6767)
parent
6451807b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
9 deletions
+74
-9
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+4
-0
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+70
-9
No files found.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
c6abbb13
...
@@ -629,6 +629,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -629,6 +629,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
def
_check_sampler_arguments
(
self
,
nodes
,
fanouts
,
probs_name
):
def
_check_sampler_arguments
(
self
,
nodes
,
fanouts
,
probs_name
):
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
assert
nodes
.
dtype
==
self
.
indices
.
dtype
,
(
f
"Data type of nodes must be consistent with "
f
"indices.dtype(
{
self
.
indices
.
dtype
}
), but got
{
nodes
.
dtype
}
."
)
assert
fanouts
.
dim
()
==
1
,
"Fanouts should be 1-D tensor."
assert
fanouts
.
dim
()
==
1
,
"Fanouts should be 1-D tensor."
expected_fanout_len
=
1
expected_fanout_len
=
1
if
self
.
edge_type_to_id
:
if
self
.
edge_type_to_id
:
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
c6abbb13
import
os
import
os
import
pickle
import
pickle
import
re
import
tempfile
import
tempfile
import
unittest
import
unittest
...
@@ -972,9 +973,25 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
...
@@ -972,9 +973,25 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
graph
=
gb
.
fused_csc_sampling_graph
(
indptr
,
indices
)
graph
=
gb
.
fused_csc_sampling_graph
(
indptr
,
indices
)
# Generate subgraph via sample neighbors.
# Generate subgraph via sample neighbors.
node
s
=
torch
.
t
ensor
([
1
,
3
,
4
],
dtype
=
indices_dtype
)
fanout
s
=
torch
.
LongT
ensor
([
2
]
)
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
subgraph
=
sampler
(
nodes
,
fanouts
=
torch
.
LongTensor
([
2
]))
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes
=
torch
.
tensor
(
[
1
,
3
,
4
],
dtype
=
(
torch
.
int64
if
indices_dtype
==
torch
.
int32
else
torch
.
int32
),
)
with
pytest
.
raises
(
AssertionError
,
match
=
re
.
escape
(
"Data type of nodes must be consistent with indices.dtype"
),
):
_
=
sampler
(
nodes
,
fanouts
)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes
=
torch
.
tensor
([
1
,
3
,
4
],
dtype
=
indices_dtype
)
subgraph
=
sampler
(
nodes
,
fanouts
)
# Verify in subgraph.
# Verify in subgraph.
sampled_num
=
subgraph
.
node_pairs
[
0
].
size
(
0
)
sampled_num
=
subgraph
.
node_pairs
[
0
].
size
(
0
)
...
@@ -1023,12 +1040,37 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
...
@@ -1023,12 +1040,37 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
)
)
# Sample on both node types.
# Sample on both node types.
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes
=
{
"n1"
:
torch
.
tensor
(
[
0
],
dtype
=
(
torch
.
int64
if
indices_dtype
==
torch
.
int32
else
torch
.
int32
),
),
"n2"
:
torch
.
tensor
(
[
0
],
dtype
=
(
torch
.
int64
if
indices_dtype
==
torch
.
int32
else
torch
.
int32
),
),
}
with
pytest
.
raises
(
AssertionError
,
match
=
re
.
escape
(
"Data type of nodes must be consistent with indices.dtype"
),
):
_
=
sampler
(
nodes
,
fanouts
)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes
=
{
nodes
=
{
"n1"
:
torch
.
tensor
([
0
],
dtype
=
indices_dtype
),
"n1"
:
torch
.
tensor
([
0
],
dtype
=
indices_dtype
),
"n2"
:
torch
.
tensor
([
0
],
dtype
=
indices_dtype
),
"n2"
:
torch
.
tensor
([
0
],
dtype
=
indices_dtype
),
}
}
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
subgraph
=
sampler
(
nodes
,
fanouts
)
subgraph
=
sampler
(
nodes
,
fanouts
)
# Verify in subgraph.
# Verify in subgraph.
...
@@ -1051,20 +1093,39 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
...
@@ -1051,20 +1093,39 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
assert
subgraph
.
original_edge_ids
is
None
assert
subgraph
.
original_edge_ids
is
None
# Sample on single node type.
# Sample on single node type.
nodes
=
{
"n1"
:
torch
.
LongTensor
([
0
])}
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
# 1. Sample with nodes in mismatched dtype with graph's indices.
nodes
=
{
"n1"
:
torch
.
tensor
(
[
0
],
dtype
=
(
torch
.
int64
if
indices_dtype
==
torch
.
int32
else
torch
.
int32
),
)
}
with
pytest
.
raises
(
AssertionError
,
match
=
re
.
escape
(
"Data type of nodes must be consistent with indices.dtype"
),
):
_
=
sampler
(
nodes
,
fanouts
)
# 2. Sample with nodes in matched dtype with graph's indices.
nodes
=
{
"n1"
:
torch
.
tensor
([
0
],
dtype
=
indices_dtype
)}
subgraph
=
sampler
(
nodes
,
fanouts
)
subgraph
=
sampler
(
nodes
,
fanouts
)
# Verify in subgraph.
# Verify in subgraph.
expected_node_pairs
=
{
expected_node_pairs
=
{
"n2:e2:n1"
:
(
"n2:e2:n1"
:
(
torch
.
LongT
ensor
([
0
,
2
]),
torch
.
t
ensor
([
0
,
2
]
,
dtype
=
indices_dtype
),
torch
.
LongT
ensor
([
0
,
0
]),
torch
.
t
ensor
([
0
,
0
]
,
dtype
=
indices_dtype
),
),
),
"n1:e1:n2"
:
(
"n1:e1:n2"
:
(
torch
.
LongT
ensor
([]),
torch
.
t
ensor
([]
,
dtype
=
indices_dtype
),
torch
.
LongT
ensor
([]),
torch
.
t
ensor
([]
,
dtype
=
indices_dtype
),
),
),
}
}
assert
len
(
subgraph
.
node_pairs
)
==
2
assert
len
(
subgraph
.
node_pairs
)
==
2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment