Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ec495239
Unverified
Commit
ec495239
authored
Sep 07, 2023
by
Rhett Ying
Committed by
GitHub
Sep 07, 2023
Browse files
[GraphBolt] enable to invoke gb samplers in functional form (#6297)
parent
79a95477
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
158 additions
and
1 deletion
+158
-1
python/dgl/graphbolt/base.py
python/dgl/graphbolt/base.py
+2
-0
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+3
-0
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+4
-0
python/dgl/graphbolt/impl/uniform_negative_sampler.py
python/dgl/graphbolt/impl/uniform_negative_sampler.py
+3
-0
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+2
-1
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+2
-0
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+2
-0
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+69
-0
tests/python/pytorch/graphbolt/test_base.py
tests/python/pytorch/graphbolt/test_base.py
+6
-0
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+29
-0
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+36
-0
No files found.
python/dgl/graphbolt/base.py
View file @
ec495239
"""Base types and utilities for Graph Bolt."""
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
IterDataPipe
from
..utils
import
recursive_apply
...
...
@@ -53,6 +54,7 @@ def _to(x, device):
return
x
.
to
(
device
)
if
hasattr
(
x
,
"to"
)
else
x
@
functional_datapipe
(
"copy_to"
)
class
CopyTo
(
IterDataPipe
):
"""DataPipe that transfers each element yielded from the previous DataPipe
to the given device.
...
...
python/dgl/graphbolt/feature_fetcher.py
View file @
ec495239
...
...
@@ -2,9 +2,12 @@
from
typing
import
Dict
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
@
functional_datapipe
(
"fetch_feature"
)
class
FeatureFetcher
(
Mapper
):
"""A feature fetcher used to fetch features for node/edge in graphbolt."""
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
ec495239
"""Neighbor subgraph samplers for GraphBolt."""
from
torch.utils.data
import
functional_datapipe
from
..subgraph_sampler
import
SubgraphSampler
from
..utils
import
unique_and_compact_node_pairs
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
@
functional_datapipe
(
"sample_neighbor"
)
class
NeighborSampler
(
SubgraphSampler
):
"""
Neighbor sampler is responsible for sampling a subgraph from given data. It
...
...
@@ -106,6 +109,7 @@ class NeighborSampler(SubgraphSampler):
return
seeds
,
subgraphs
@
functional_datapipe
(
"sample_layer_neighbor"
)
class
LayerNeighborSampler
(
NeighborSampler
):
"""
Layer-Neighbor sampler is responsible for sampling a subgraph from given
...
...
python/dgl/graphbolt/impl/uniform_negative_sampler.py
View file @
ec495239
"""Uniform negative sampler for GraphBolt."""
from
torch.utils.data
import
functional_datapipe
from
..negative_sampler
import
NegativeSampler
@
functional_datapipe
(
"sample_uniform_negative"
)
class
UniformNegativeSampler
(
NegativeSampler
):
"""
Negative samplers randomly select negative destination nodes for each
...
...
python/dgl/graphbolt/item_sampler.py
View file @
ec495239
...
...
@@ -4,7 +4,7 @@ from collections.abc import Mapping
from
functools
import
partial
from
typing
import
Callable
,
Iterator
,
Optional
from
torch.utils.data
import
default_collate
from
torch.utils.data
import
default_collate
,
functional_datapipe
from
torchdata.datapipes.iter
import
IterableWrapper
,
IterDataPipe
from
..base
import
dgl_warning
...
...
@@ -78,6 +78,7 @@ def minibatcher_default(batch, names):
return
minibatch
@
functional_datapipe
(
"sample_item"
)
class
ItemSampler
(
IterDataPipe
):
"""Item Sampler.
...
...
python/dgl/graphbolt/negative_sampler.py
View file @
ec495239
...
...
@@ -3,11 +3,13 @@
from
_collections_abc
import
Mapping
import
torch
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
.data_format
import
LinkPredictionEdgeFormat
@
functional_datapipe
(
"sample_negative"
)
class
NegativeSampler
(
Mapper
):
"""
A negative sampler used to generate negative samples and return
...
...
python/dgl/graphbolt/subgraph_sampler.py
View file @
ec495239
...
...
@@ -3,12 +3,14 @@
from
collections
import
defaultdict
from
typing
import
Dict
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
.base
import
etype_str_to_tuple
from
.utils
import
unique_and_compact
@
functional_datapipe
(
"sample_subgraph"
)
class
SubgraphSampler
(
Mapper
):
"""A subgraph sampler used to sample a subgraph from a given set of nodes
from a larger graph."""
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
ec495239
...
...
@@ -5,6 +5,75 @@ import torch
from
torchdata.datapipes.iter
import
Mapper
def test_NegativeSampler_invoke():
    """Check that ``NegativeSampler`` can be built both through its class
    constructor and through the ``sample_negative`` functional form.

    Iterating either pipeline is expected to raise ``NotImplementedError``.
    """
    seed_count = 30
    pairs = torch.arange(0, 2 * seed_count).reshape(-1, 2)
    source_pipe = gb.ItemSampler(
        gb.ItemSet(pairs, names="node_pairs"), batch_size=10
    )
    ratio = 2

    # Class-constructor form: the upstream datapipe is the first argument.
    by_constructor = gb.NegativeSampler(
        source_pipe, ratio, gb.LinkPredictionEdgeFormat.INDEPENDENT
    )
    with pytest.raises(NotImplementedError):
        next(iter(by_constructor))

    # Functional form: called as a method on the same upstream datapipe.
    by_method = source_pipe.sample_negative(
        ratio, gb.LinkPredictionEdgeFormat.INDEPENDENT
    )
    with pytest.raises(NotImplementedError):
        next(iter(by_method))
def test_UniformNegativeSampler_invoke():
    """Check that ``UniformNegativeSampler`` yields the same shape of output
    whether it is built through its class constructor or through the
    ``sample_uniform_negative`` functional form.
    """
    # Build the graph and the upstream item sampler shared by both forms.
    graph = gb_test_utils.rand_csc_graph(100, 0.05)
    num_seeds = 30
    batch_size = 10
    negative_ratio = 2
    pairs = torch.arange(0, 2 * num_seeds).reshape(-1, 2)
    item_sampler = gb.ItemSampler(
        gb.ItemSet(pairs, names="node_pairs"), batch_size=batch_size
    )

    def _verify(negative_sampler):
        # Each batch holds the positive pairs plus `negative_ratio` negatives
        # per positive pair.
        expected_size = batch_size * (negative_ratio + 1)
        for minibatch in negative_sampler:
            src, dst = minibatch.node_pairs
            labels = minibatch.labels
            assert len(src) == expected_size
            assert len(dst) == expected_size
            assert len(labels) == expected_size
            # Positives come first (label 1), negatives follow (label 0).
            positive_labels = labels[:batch_size]
            negative_labels = labels[batch_size:]
            assert torch.all(torch.eq(positive_labels, 1))
            assert torch.all(torch.eq(negative_labels, 0))

    # Class-constructor form.
    _verify(
        gb.UniformNegativeSampler(
            item_sampler,
            negative_ratio,
            gb.LinkPredictionEdgeFormat.INDEPENDENT,
            graph,
        )
    )

    # Functional form on the same upstream datapipe.
    _verify(
        item_sampler.sample_uniform_negative(
            negative_ratio,
            gb.LinkPredictionEdgeFormat.INDEPENDENT,
            graph,
        )
    )
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
def
test_NegativeSampler_Independent_Format
(
negative_ratio
):
# Construct CSCSamplingGraph.
...
...
tests/python/pytorch/graphbolt/test_base.py
View file @
ec495239
...
...
@@ -11,8 +11,14 @@ import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
    """Check that ``CopyTo`` moves every yielded element to CUDA, both via the
    class constructor and via the ``copy_to`` functional form.

    Both forms are built from the same un-copied item sampler. Chaining
    ``copy_to`` onto a pipeline that already went through ``CopyTo`` would
    make the second assertion pass even if ``copy_to`` did nothing.
    """
    item_sampler = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)

    # Invoke CopyTo via class constructor.
    dp = gb.CopyTo(item_sampler, "cuda")
    for data in dp:
        assert data.device.type == "cuda"

    # Invoke CopyTo via functional form, starting again from the item sampler.
    dp = item_sampler.copy_to("cuda")
    for data in dp:
        assert data.device.type == "cuda"
...
...
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
ec495239
...
...
@@ -4,6 +4,35 @@ import torch
from
torchdata.datapipes.iter
import
Mapper
def test_FeatureFetcher_invoke():
    """Check that ``FeatureFetcher`` can be built both through its class
    constructor and through the ``fetch_feature`` functional form.

    Each pipeline starts from the same item sampler so the functional chain
    is tested on its own rather than stacked on the constructor-built one.
    """
    # Prepare graph and required datapipes.
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    a = torch.randint(0, 10, (graph.num_nodes,))
    b = torch.randint(0, 10, (graph.num_edges,))

    features = {}
    keys = [("node", None, "a"), ("edge", None, "b")]
    features[keys[0]] = gb.TorchBasedFeature(a)
    features[keys[1]] = gb.TorchBasedFeature(b)
    feature_store = gb.BasicFeatureStore(features)

    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    # Invoke FeatureFetcher via class constructor.
    datapipe = gb.NeighborSampler(item_sampler, graph, fanouts)
    datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"])
    # 10 seed nodes in batches of 2 give 5 minibatches.
    assert len(list(datapipe)) == 5

    # Invoke FeatureFetcher via functional form, starting again from the
    # item sampler instead of the pipeline built above.
    datapipe = item_sampler.sample_neighbor(graph, fanouts).fetch_feature(
        feature_store, ["a"], ["b"]
    )
    assert len(list(datapipe)) == 5
def
test_FeatureFetcher_homo
():
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
a
=
torch
.
randint
(
0
,
10
,
(
graph
.
num_nodes
,))
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
ec495239
...
...
@@ -6,6 +6,42 @@ import torchdata.datapipes as dp
from
torchdata.datapipes.iter
import
Mapper
def test_SubgraphSampler_invoke():
    """Check that ``SubgraphSampler`` can be built both through its class
    constructor and through the ``sample_subgraph`` functional form.

    Iterating either pipeline is expected to raise ``NotImplementedError``.
    Both are built from the item sampler. If the functional form were chained
    onto the constructor-built sampler, the upstream stage would raise first
    and the second check would pass without exercising the functional form.
    """
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)

    # Invoke via class constructor.
    datapipe = gb.SubgraphSampler(item_sampler)
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))

    # Invoke via functional form, starting again from the item sampler.
    datapipe = item_sampler.sample_subgraph()
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))
@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_invoke(labor):
    """Check that ``NeighborSampler`` / ``LayerNeighborSampler`` can be built
    both through the class constructor and through the functional form.

    ``labor`` selects ``LayerNeighborSampler`` (``sample_layer_neighbor``)
    when true and ``NeighborSampler`` (``sample_neighbor``) otherwise. Each
    pipeline starts from the same item sampler so the functional form is not
    stacked on the constructor-built sampler.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    # Invoke via class constructor.
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(item_sampler, graph, fanouts)
    # 10 seed nodes in batches of 2 give 5 minibatches.
    assert len(list(datapipe)) == 5

    # Invoke via functional form, starting again from the item sampler.
    if labor:
        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
    else:
        datapipe = item_sampler.sample_neighbor(graph, fanouts)
    assert len(list(datapipe)) == 5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Node
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment