Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6aba92e9
Unverified
Commit
6aba92e9
authored
Sep 11, 2023
by
Rhett Ying
Committed by
GitHub
Sep 11, 2023
Browse files
[GraphBolt] fix test cases about datapipe (#6305)
parent
1328baf7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
13 deletions
+13
-13
tests/python/pytorch/graphbolt/test_base.py
tests/python/pytorch/graphbolt/test_base.py
+3
-3
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+3
-3
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+7
-7
No files found.
tests/python/pytorch/graphbolt/test_base.py
View file @
6aba92e9
...
@@ -10,15 +10,15 @@ import torch
...
@@ -10,15 +10,15 @@ import torch
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"cpu"
,
"CopyTo needs GPU to test"
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"cpu"
,
"CopyTo needs GPU to test"
)
def
test_CopyTo
():
def
test_CopyTo
():
dp
=
gb
.
ItemSampler
(
gb
.
ItemSet
(
torch
.
randn
(
20
)),
4
)
item_sampler
=
gb
.
ItemSampler
(
gb
.
ItemSet
(
torch
.
randn
(
20
)),
4
)
# Invoke CopyTo via class constructor.
# Invoke CopyTo via class constructor.
dp
=
gb
.
CopyTo
(
dp
,
"cuda"
)
dp
=
gb
.
CopyTo
(
item_sampler
,
"cuda"
)
for
data
in
dp
:
for
data
in
dp
:
assert
data
.
device
.
type
==
"cuda"
assert
data
.
device
.
type
==
"cuda"
# Invoke CopyTo via functional form.
# Invoke CopyTo via functional form.
dp
=
dp
.
copy_to
(
"cuda"
)
dp
=
item_sampler
.
copy_to
(
"cuda"
)
for
data
in
dp
:
for
data
in
dp
:
assert
data
.
device
.
type
==
"cuda"
assert
data
.
device
.
type
==
"cuda"
...
...
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
6aba92e9
...
@@ -17,17 +17,17 @@ def test_FeatureFetcher_invoke():
...
@@ -17,17 +17,17 @@ def test_FeatureFetcher_invoke():
feature_store
=
gb
.
BasicFeatureStore
(
features
)
feature_store
=
gb
.
BasicFeatureStore
(
features
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
# Invoke FeatureFetcher via class constructor.
# Invoke FeatureFetcher via class constructor.
datapipe
=
gb
.
NeighborSampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
gb
.
NeighborSampler
(
item_sampler
,
graph
,
fanouts
)
datapipe
=
gb
.
FeatureFetcher
(
datapipe
,
feature_store
,
[
"a"
],
[
"b"
])
datapipe
=
gb
.
FeatureFetcher
(
datapipe
,
feature_store
,
[
"a"
],
[
"b"
])
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
# Invoke FeatureFetcher via functional form.
# Invoke FeatureFetcher via functional form.
datapipe
=
datapipe
.
sample_neighbor
(
graph
,
fanouts
).
fetch_feature
(
datapipe
=
item_sampler
.
sample_neighbor
(
graph
,
fanouts
).
fetch_feature
(
feature_store
,
[
"a"
],
[
"b"
]
feature_store
,
[
"a"
],
[
"b"
]
)
)
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
6aba92e9
...
@@ -7,15 +7,15 @@ from torchdata.datapipes.iter import Mapper
...
@@ -7,15 +7,15 @@ from torchdata.datapipes.iter import Mapper
def
test_SubgraphSampler_invoke
():
def
test_SubgraphSampler_invoke
():
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
# Invoke via class constructor.
# Invoke via class constructor.
datapipe
=
gb
.
SubgraphSampler
(
datapipe
)
datapipe
=
gb
.
SubgraphSampler
(
item_sampler
)
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
next
(
iter
(
datapipe
))
next
(
iter
(
datapipe
))
# Invokde via functional form.
# Invokde via functional form.
datapipe
=
datapipe
.
sample_subgraph
()
datapipe
=
item_sampler
.
sample_subgraph
()
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
next
(
iter
(
datapipe
))
next
(
iter
(
datapipe
))
...
@@ -24,20 +24,20 @@ def test_SubgraphSampler_invoke():
...
@@ -24,20 +24,20 @@ def test_SubgraphSampler_invoke():
def
test_NeighborSampler_invoke
(
labor
):
def
test_NeighborSampler_invoke
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
# Invoke via class constructor.
# Invoke via class constructor.
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
datapipe
=
Sampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
Sampler
(
item_sampler
,
graph
,
fanouts
)
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
# Invokde via functional form.
# Invokde via functional form.
if
labor
:
if
labor
:
datapipe
=
datapipe
.
sample_layer_neighbor
(
graph
,
fanouts
)
datapipe
=
item_sampler
.
sample_layer_neighbor
(
graph
,
fanouts
)
else
:
else
:
datapipe
=
datapipe
.
sample_neighbor
(
graph
,
fanouts
)
datapipe
=
item_sampler
.
sample_neighbor
(
graph
,
fanouts
)
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment