Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
13e7c2fa
Unverified
Commit
13e7c2fa
authored
Feb 14, 2024
by
Andrei Ivanov
Committed by
GitHub
Feb 14, 2024
Browse files
[GraphBolt] Improving `subgraph_sampler` tests. (#7047)
parent
8204fe19
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
138 additions
and
143 deletions
+138
-143
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+138
-143
No files found.
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
13e7c2fa
import
unittest
import
warnings
from
enum
import
Enum
from
functools
import
partial
...
...
@@ -9,7 +10,6 @@ import dgl
import
dgl.graphbolt
as
gb
import
pytest
import
torch
from
torchdata.datapipes.iter
import
Mapper
from
.
import
gb_test_utils
...
...
@@ -22,6 +22,12 @@ def _check_sampler_type(sampler_type):
)
def
_check_sampler_len
(
sampler
,
lenExp
):
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
assert
len
(
list
(
sampler
))
==
lenExp
class
SamplerType
(
Enum
):
Normal
=
0
Layer
=
1
...
...
@@ -128,7 +134,7 @@ def test_SubgraphSampler_Node_seed_nodes(sampler_type):
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
sampler
=
_get_sampler
(
sampler_type
)
sampler_dp
=
sampler
(
item_sampler
,
graph
,
fanouts
)
assert
len
(
list
(
sampler_dp
))
==
5
_check_sampler_len
(
sampler_dp
,
5
)
def
to_link_batch
(
data
):
...
...
@@ -161,7 +167,7 @@ def test_SubgraphSampler_Link_node_pairs(sampler_type):
sampler
=
_get_sampler
(
sampler_type
)
datapipe
=
sampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
_check_sampler_len
(
datapipe
,
5
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -190,7 +196,7 @@ def test_SubgraphSampler_Link_With_Negative_node_pairs(sampler_type):
sampler
=
_get_sampler
(
sampler_type
)
datapipe
=
sampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
_check_sampler_len
(
datapipe
,
5
)
def
get_hetero_graph
():
...
...
@@ -239,7 +245,9 @@ def test_SubgraphSampler_Node_seed_nodes_Hetero(sampler_type):
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
sampler
=
_get_sampler
(
sampler_type
)
sampler_dp
=
sampler
(
item_sampler
,
graph
,
fanouts
)
assert
len
(
list
(
sampler_dp
))
==
2
_check_sampler_len
(
sampler_dp
,
2
)
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
for
minibatch
in
sampler_dp
:
assert
len
(
minibatch
.
sampled_subgraphs
)
==
num_layer
...
...
@@ -285,7 +293,7 @@ def test_SubgraphSampler_Link_Hetero_node_pairs(sampler_type):
sampler
=
_get_sampler
(
sampler_type
)
datapipe
=
sampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
_check_sampler_len
(
datapipe
,
5
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -330,7 +338,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_node_pairs(sampler_type):
sampler
=
_get_sampler
(
sampler_type
)
datapipe
=
sampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
_check_sampler_len
(
datapipe
,
5
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -375,7 +383,7 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype_node_pairs(sampler_type):
sampler
=
_get_sampler
(
sampler_type
)
datapipe
=
sampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
_check_sampler_len
(
datapipe
,
5
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -423,7 +431,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype_node_pairs(
sampler
=
_get_sampler
(
sampler_type
)
datapipe
=
sampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
_check_sampler_len
(
datapipe
,
5
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -493,30 +501,26 @@ def test_SubgraphSampler_Random_Hetero_Graph_seed_ndoes(sampler_type, replace):
sampler_dp
=
sampler
(
item_sampler
,
graph
,
fanouts
,
replace
=
replace
)
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
for
data
in
sampler_dp
:
for
sampledsubgraph
in
data
.
sampled_subgraphs
:
for
_
,
value
in
sampledsubgraph
.
sampled_csc
.
items
():
for
idx
in
[
value
.
indices
,
value
.
indptr
]:
assert
torch
.
equal
(
torch
.
ge
(
value
.
indices
,
torch
.
zeros
(
len
(
value
.
indices
)).
to
(
F
.
ctx
()),
),
torch
.
ones
(
len
(
value
.
indices
)).
to
(
F
.
ctx
()),
torch
.
ge
(
idx
,
torch
.
zeros
(
len
(
idx
)).
to
(
F
.
ctx
())),
torch
.
ones
(
len
(
idx
)).
to
(
F
.
ctx
()),
)
node_ids
=
[
sampledsubgraph
.
original_column_node_ids
,
sampledsubgraph
.
original_row_node_ids
,
]
for
ids
in
node_ids
:
for
_
,
value
in
ids
.
items
():
assert
torch
.
equal
(
torch
.
ge
(
value
.
indptr
,
torch
.
zeros
(
len
(
value
.
indptr
)).
to
(
F
.
ctx
())
value
,
torch
.
zeros
(
len
(
value
)).
to
(
F
.
ctx
())
),
torch
.
ones
(
len
(
value
.
indptr
)).
to
(
F
.
ctx
()),
)
for
_
,
value
in
sampledsubgraph
.
original_column_node_ids
.
items
():
assert
torch
.
equal
(
torch
.
ge
(
value
,
torch
.
zeros
(
len
(
value
)).
to
(
F
.
ctx
())),
torch
.
ones
(
len
(
value
)).
to
(
F
.
ctx
()),
)
for
_
,
value
in
sampledsubgraph
.
original_row_node_ids
.
items
():
assert
torch
.
equal
(
torch
.
ge
(
value
,
torch
.
zeros
(
len
(
value
)).
to
(
F
.
ctx
())),
torch
.
ones
(
len
(
value
)).
to
(
F
.
ctx
()),
)
...
...
@@ -570,11 +574,16 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type):
torch
.
tensor
([
0
,
2
,
2
,
3
,
4
,
4
,
5
]).
to
(
F
.
ctx
()),
torch
.
tensor
([
0
,
3
,
4
]).
to
(
F
.
ctx
()),
]
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
for
data
in
datapipe
:
for
step
,
sampled_subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
assert
len
(
sampled_subgraph
.
original_row_node_ids
)
==
length
[
step
]
assert
(
len
(
sampled_subgraph
.
original_row_node_ids
)
==
length
[
step
]
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
.
indices
,
compacted_indices
[
step
]
sampled_subgraph
.
sampled_csc
.
indices
,
compacted_indices
[
step
],
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
.
indptr
,
indptr
[
step
]
...
...
@@ -585,6 +594,51 @@ def test_SubgraphSampler_without_dedpulication_Homo_seed_nodes(sampler_type):
)
def
_assert_hetero_values
(
datapipe
,
original_row_node_ids
,
original_column_node_ids
,
csc_formats
):
for
data
in
datapipe
:
for
step
,
sampled_subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
for
ntype
in
[
"n1"
,
"n2"
]:
assert
torch
.
equal
(
sampled_subgraph
.
original_row_node_ids
[
ntype
],
original_row_node_ids
[
step
][
ntype
].
to
(
F
.
ctx
()),
)
assert
torch
.
equal
(
sampled_subgraph
.
original_column_node_ids
[
ntype
],
original_column_node_ids
[
step
][
ntype
].
to
(
F
.
ctx
()),
)
for
etype
in
[
"n1:e1:n2"
,
"n2:e2:n1"
]:
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
[
etype
].
indices
,
csc_formats
[
step
][
etype
].
indices
.
to
(
F
.
ctx
()),
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
[
etype
].
indptr
,
csc_formats
[
step
][
etype
].
indptr
.
to
(
F
.
ctx
()),
)
def
_assert_homo_values
(
datapipe
,
original_row_node_ids
,
compacted_indices
,
indptr
,
seeds
):
for
data
in
datapipe
:
for
step
,
sampled_subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
assert
torch
.
equal
(
sampled_subgraph
.
original_row_node_ids
,
original_row_node_ids
[
step
],
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
.
indices
,
compacted_indices
[
step
]
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
.
indptr
,
indptr
[
step
]
)
assert
torch
.
equal
(
sampled_subgraph
.
original_column_node_ids
,
seeds
[
step
]
)
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
...
...
@@ -655,25 +709,13 @@ def test_SubgraphSampler_without_dedpulication_Hetero_seed_nodes(sampler_type):
},
]
for
data
in
datapipe
:
for
step
,
sampled_subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
for
ntype
in
[
"n1"
,
"n2"
]:
assert
torch
.
equal
(
sampled_subgraph
.
original_row_node_ids
[
ntype
],
original_row_node_ids
[
step
][
ntype
].
to
(
F
.
ctx
()),
)
assert
torch
.
equal
(
sampled_subgraph
.
original_column_node_ids
[
ntype
],
original_column_node_ids
[
step
][
ntype
].
to
(
F
.
ctx
()),
)
for
etype
in
[
"n1:e1:n2"
,
"n2:e2:n1"
]:
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
[
etype
].
indices
,
csc_formats
[
step
][
etype
].
indices
.
to
(
F
.
ctx
()),
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
[
etype
].
indptr
,
csc_formats
[
step
][
etype
].
indptr
.
to
(
F
.
ctx
()),
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
_assert_hetero_values
(
datapipe
,
original_row_node_ids
,
original_column_node_ids
,
csc_formats
,
)
...
...
@@ -719,20 +761,8 @@ def test_SubgraphSampler_unique_csc_format_Homo_cpu_seed_nodes(labor):
torch
.
tensor
([
0
,
3
,
4
,
5
,
2
]).
to
(
F
.
ctx
()),
torch
.
tensor
([
0
,
3
,
4
]).
to
(
F
.
ctx
()),
]
for
data
in
datapipe
:
for
step
,
sampled_subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
assert
torch
.
equal
(
sampled_subgraph
.
original_row_node_ids
,
original_row_node_ids
[
step
],
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
.
indices
,
compacted_indices
[
step
]
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
.
indptr
,
indptr
[
step
]
)
assert
torch
.
equal
(
sampled_subgraph
.
original_column_node_ids
,
seeds
[
step
]
_assert_homo_values
(
datapipe
,
original_row_node_ids
,
compacted_indices
,
indptr
,
seeds
)
...
...
@@ -778,20 +808,8 @@ def test_SubgraphSampler_unique_csc_format_Homo_gpu_seed_nodes(labor):
torch
.
tensor
([
0
,
3
,
4
,
2
,
5
]).
to
(
F
.
ctx
()),
torch
.
tensor
([
0
,
3
,
4
]).
to
(
F
.
ctx
()),
]
for
data
in
datapipe
:
for
step
,
sampled_subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
assert
torch
.
equal
(
sampled_subgraph
.
original_row_node_ids
,
original_row_node_ids
[
step
],
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
.
indices
,
compacted_indices
[
step
]
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
.
indptr
,
indptr
[
step
]
)
assert
torch
.
equal
(
sampled_subgraph
.
original_column_node_ids
,
seeds
[
step
]
_assert_homo_values
(
datapipe
,
original_row_node_ids
,
compacted_indices
,
indptr
,
seeds
)
...
...
@@ -853,26 +871,8 @@ def test_SubgraphSampler_unique_csc_format_Hetero_seed_nodes(labor):
"n2"
:
torch
.
tensor
([
0
,
1
]),
},
]
for
data
in
datapipe
:
for
step
,
sampled_subgraph
in
enumerate
(
data
.
sampled_subgraphs
):
for
ntype
in
[
"n1"
,
"n2"
]:
assert
torch
.
equal
(
sampled_subgraph
.
original_row_node_ids
[
ntype
],
original_row_node_ids
[
step
][
ntype
].
to
(
F
.
ctx
()),
)
assert
torch
.
equal
(
sampled_subgraph
.
original_column_node_ids
[
ntype
],
original_column_node_ids
[
step
][
ntype
].
to
(
F
.
ctx
()),
)
for
etype
in
[
"n1:e1:n2"
,
"n2:e2:n1"
]:
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
[
etype
].
indices
,
csc_formats
[
step
][
etype
].
indices
.
to
(
F
.
ctx
()),
)
assert
torch
.
equal
(
sampled_subgraph
.
sampled_csc
[
etype
].
indptr
,
csc_formats
[
step
][
etype
].
indptr
.
to
(
F
.
ctx
()),
_assert_hetero_values
(
datapipe
,
original_row_node_ids
,
original_column_node_ids
,
csc_formats
)
...
...
@@ -886,7 +886,9 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type):
items_n1
=
torch
.
tensor
([
0
])
items_n2
=
torch
.
tensor
([
1
])
names
=
"seed_nodes"
item_length
=
2
if
sampler_type
==
SamplerType
.
Temporal
:
item_length
=
3
graph
.
node_attributes
=
{
"timestamp"
:
torch
.
arange
(
graph
.
csc_indptr
.
numel
()
-
1
).
to
(
F
.
ctx
())
}
...
...
@@ -909,30 +911,23 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer_seed_nodes(sampler_type):
fanouts
=
[
torch
.
LongTensor
([
2
,
1
])
for
_
in
range
(
num_layer
)]
sampler
=
_get_sampler
(
sampler_type
)
sampler_dp
=
sampler
(
item_sampler
,
graph
,
fanouts
)
if
sampler_type
==
SamplerType
.
Temporal
:
indices_len
=
[
{
"n1:e1:n2"
:
4
,
"n2:e2:n1"
:
3
,
},
{
"n1:e1:n2"
:
2
,
"n2:e2:n1"
:
1
,
},
]
else
:
indices_len
=
[
{
"n1:e1:n2"
:
4
,
"n2:e2:n1"
:
2
,
"n2:e2:n1"
:
item_length
,
},
{
"n1:e1:n2"
:
2
,
"n2:e2:n1"
:
1
,
},
]
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
for
minibatch
in
sampler_dp
:
for
step
,
sampled_subgraph
in
enumerate
(
minibatch
.
sampled_subgraphs
):
for
step
,
sampled_subgraph
in
enumerate
(
minibatch
.
sampled_subgraphs
):
assert
(
len
(
sampled_subgraph
.
sampled_csc
[
"n1:e1:n2"
].
indices
)
==
indices_len
[
step
][
"n1:e1:n2"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment