Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e117adac
Unverified
Commit
e117adac
authored
Feb 01, 2024
by
Rhett Ying
Committed by
GitHub
Feb 01, 2024
Browse files
[GraphBolt] fix testcases on warning messages (#7054)
parent
571340da
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
11 deletions
+22
-11
tests/python/pytorch/graphbolt/test_base.py
tests/python/pytorch/graphbolt/test_base.py
+5
-3
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+6
-4
tests/python/pytorch/graphbolt/test_item_sampler.py
tests/python/pytorch/graphbolt/test_item_sampler.py
+11
-4
No files found.
tests/python/pytorch/graphbolt/test_base.py
View file @
e117adac
...
...
@@ -13,17 +13,19 @@ from . import gb_test_utils
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"cpu"
,
"CopyTo needs GPU to test"
)
def
test_CopyTo
():
item_sampler
=
gb
.
ItemSampler
(
gb
.
ItemSet
(
torch
.
randn
(
20
)),
4
)
item_sampler
=
gb
.
ItemSampler
(
gb
.
ItemSet
(
torch
.
arange
(
20
),
names
=
"seed_nodes"
),
4
)
# Invoke CopyTo via class constructor.
dp
=
gb
.
CopyTo
(
item_sampler
,
"cuda"
)
for
data
in
dp
:
assert
data
.
device
.
type
==
"cuda"
assert
data
.
seed_nodes
.
device
.
type
==
"cuda"
# Invoke CopyTo via functional form.
dp
=
item_sampler
.
copy_to
(
"cuda"
)
for
data
in
dp
:
assert
data
.
device
.
type
==
"cuda"
assert
data
.
seed_nodes
.
device
.
type
==
"cuda"
@
pytest
.
mark
.
parametrize
(
...
...
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
e117adac
...
...
@@ -77,7 +77,8 @@ def test_FeatureFetcher_with_edges_homo():
[[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
graph
.
total_num_edges
)]
)
def
add_node_and_edge_ids
(
seeds
):
def
add_node_and_edge_ids
(
minibatch
):
seeds
=
minibatch
.
seed_nodes
subgraphs
=
[]
for
_
in
range
(
3
):
sampled_csc
=
gb
.
CSCFormatBase
(
...
...
@@ -103,7 +104,7 @@ def test_FeatureFetcher_with_edges_homo():
features
[
keys
[
1
]]
=
gb
.
TorchBasedFeature
(
b
)
feature_store
=
gb
.
BasicFeatureStore
(
features
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
))
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
)
,
names
=
"seed_nodes"
)
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
converter_dp
=
Mapper
(
item_sampler_dp
,
add_node_and_edge_ids
)
fetcher_dp
=
gb
.
FeatureFetcher
(
converter_dp
,
feature_store
,
[
"a"
],
[
"b"
])
...
...
@@ -170,7 +171,8 @@ def test_FeatureFetcher_with_edges_hetero():
a
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
20
)])
b
=
torch
.
tensor
([[
random
.
randint
(
0
,
10
)]
for
_
in
range
(
50
)])
def
add_node_and_edge_ids
(
seeds
):
def
add_node_and_edge_ids
(
minibatch
):
seeds
=
minibatch
.
seed_nodes
subgraphs
=
[]
original_edge_ids
=
{
"n1:e1:n2"
:
torch
.
randint
(
0
,
50
,
(
10
,)),
...
...
@@ -213,7 +215,7 @@ def test_FeatureFetcher_with_edges_hetero():
itemset
=
gb
.
ItemSetDict
(
{
"n1"
:
gb
.
ItemSet
(
torch
.
randint
(
0
,
20
,
(
10
,))),
"n1"
:
gb
.
ItemSet
(
torch
.
randint
(
0
,
20
,
(
10
,))
,
names
=
"seed_nodes"
),
}
)
item_sampler_dp
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
...
...
tests/python/pytorch/graphbolt/test_item_sampler.py
View file @
e117adac
...
...
@@ -204,9 +204,16 @@ def test_ItemSet_graphs(batch_size, shuffle, drop_last):
dgl
.
rand_graph
(
num_nodes
*
(
i
+
1
),
num_edges
*
(
i
+
1
))
for
i
in
range
(
num_graphs
)
]
item_set
=
gb
.
ItemSet
(
graphs
)
item_set
=
gb
.
ItemSet
(
graphs
,
names
=
"graphs"
)
# DGLGraph is not supported in gb.MiniBatch yet. Let's use a customized
# minibatcher to return the original graphs.
customized_minibatcher
=
lambda
batch
,
names
:
batch
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
item_set
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
,
minibatcher
=
customized_minibatcher
,
)
minibatch_num_nodes
=
[]
minibatch_num_edges
=
[]
...
...
@@ -459,13 +466,13 @@ def test_ItemSet_seeds_labels(batch_size, shuffle, drop_last):
def
test_append_with_other_datapipes
():
num_ids
=
100
batch_size
=
4
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
num_ids
))
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
num_ids
)
,
names
=
"seed_nodes"
)
data_pipe
=
gb
.
ItemSampler
(
item_set
,
batch_size
)
# torchdata.datapipes.iter.Enumerator
data_pipe
=
data_pipe
.
enumerate
()
for
i
,
(
idx
,
data
)
in
enumerate
(
data_pipe
):
assert
i
==
idx
assert
len
(
data
)
==
batch_size
assert
len
(
data
.
seed_nodes
)
==
batch_size
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment