Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ec495239
Unverified
Commit
ec495239
authored
Sep 07, 2023
by
Rhett Ying
Committed by
GitHub
Sep 07, 2023
Browse files
[GraphBolt] enable to invoke gb samplers in functional form (#6297)
parent
79a95477
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
158 additions
and
1 deletion
+158
-1
python/dgl/graphbolt/base.py
python/dgl/graphbolt/base.py
+2
-0
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+3
-0
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+4
-0
python/dgl/graphbolt/impl/uniform_negative_sampler.py
python/dgl/graphbolt/impl/uniform_negative_sampler.py
+3
-0
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+2
-1
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+2
-0
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+2
-0
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
+69
-0
tests/python/pytorch/graphbolt/test_base.py
tests/python/pytorch/graphbolt/test_base.py
+6
-0
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+29
-0
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+36
-0
No files found.
python/dgl/graphbolt/base.py
View file @
ec495239
"""Base types and utilities for Graph Bolt."""
"""Base types and utilities for Graph Bolt."""
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
IterDataPipe
from
torchdata.datapipes.iter
import
IterDataPipe
from
..utils
import
recursive_apply
from
..utils
import
recursive_apply
...
@@ -53,6 +54,7 @@ def _to(x, device):
...
@@ -53,6 +54,7 @@ def _to(x, device):
return
x
.
to
(
device
)
if
hasattr
(
x
,
"to"
)
else
x
return
x
.
to
(
device
)
if
hasattr
(
x
,
"to"
)
else
x
@
functional_datapipe
(
"copy_to"
)
class
CopyTo
(
IterDataPipe
):
class
CopyTo
(
IterDataPipe
):
"""DataPipe that transfers each element yielded from the previous DataPipe
"""DataPipe that transfers each element yielded from the previous DataPipe
to the given device.
to the given device.
...
...
python/dgl/graphbolt/feature_fetcher.py
View file @
ec495239
...
@@ -2,9 +2,12 @@
...
@@ -2,9 +2,12 @@
from
typing
import
Dict
from
typing
import
Dict
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
@
functional_datapipe
(
"fetch_feature"
)
class
FeatureFetcher
(
Mapper
):
class
FeatureFetcher
(
Mapper
):
"""A feature fetcher used to fetch features for node/edge in graphbolt."""
"""A feature fetcher used to fetch features for node/edge in graphbolt."""
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
ec495239
"""Neighbor subgraph samplers for GraphBolt."""
"""Neighbor subgraph samplers for GraphBolt."""
from
torch.utils.data
import
functional_datapipe
from
..subgraph_sampler
import
SubgraphSampler
from
..subgraph_sampler
import
SubgraphSampler
from
..utils
import
unique_and_compact_node_pairs
from
..utils
import
unique_and_compact_node_pairs
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
@
functional_datapipe
(
"sample_neighbor"
)
class
NeighborSampler
(
SubgraphSampler
):
class
NeighborSampler
(
SubgraphSampler
):
"""
"""
Neighbor sampler is responsible for sampling a subgraph from given data. It
Neighbor sampler is responsible for sampling a subgraph from given data. It
...
@@ -106,6 +109,7 @@ class NeighborSampler(SubgraphSampler):
...
@@ -106,6 +109,7 @@ class NeighborSampler(SubgraphSampler):
return
seeds
,
subgraphs
return
seeds
,
subgraphs
@
functional_datapipe
(
"sample_layer_neighbor"
)
class
LayerNeighborSampler
(
NeighborSampler
):
class
LayerNeighborSampler
(
NeighborSampler
):
"""
"""
Layer-Neighbor sampler is responsible for sampling a subgraph from given
Layer-Neighbor sampler is responsible for sampling a subgraph from given
...
...
python/dgl/graphbolt/impl/uniform_negative_sampler.py
View file @
ec495239
"""Uniform negative sampler for GraphBolt."""
"""Uniform negative sampler for GraphBolt."""
from
torch.utils.data
import
functional_datapipe
from
..negative_sampler
import
NegativeSampler
from
..negative_sampler
import
NegativeSampler
@
functional_datapipe
(
"sample_uniform_negative"
)
class
UniformNegativeSampler
(
NegativeSampler
):
class
UniformNegativeSampler
(
NegativeSampler
):
"""
"""
Negative samplers randomly select negative destination nodes for each
Negative samplers randomly select negative destination nodes for each
...
...
python/dgl/graphbolt/item_sampler.py
View file @
ec495239
...
@@ -4,7 +4,7 @@ from collections.abc import Mapping
...
@@ -4,7 +4,7 @@ from collections.abc import Mapping
from
functools
import
partial
from
functools
import
partial
from
typing
import
Callable
,
Iterator
,
Optional
from
typing
import
Callable
,
Iterator
,
Optional
from
torch.utils.data
import
default_collate
from
torch.utils.data
import
default_collate
,
functional_datapipe
from
torchdata.datapipes.iter
import
IterableWrapper
,
IterDataPipe
from
torchdata.datapipes.iter
import
IterableWrapper
,
IterDataPipe
from
..base
import
dgl_warning
from
..base
import
dgl_warning
...
@@ -78,6 +78,7 @@ def minibatcher_default(batch, names):
...
@@ -78,6 +78,7 @@ def minibatcher_default(batch, names):
return
minibatch
return
minibatch
@
functional_datapipe
(
"sample_item"
)
class
ItemSampler
(
IterDataPipe
):
class
ItemSampler
(
IterDataPipe
):
"""Item Sampler.
"""Item Sampler.
...
...
python/dgl/graphbolt/negative_sampler.py
View file @
ec495239
...
@@ -3,11 +3,13 @@
...
@@ -3,11 +3,13 @@
from
_collections_abc
import
Mapping
from
_collections_abc
import
Mapping
import
torch
import
torch
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
from
.data_format
import
LinkPredictionEdgeFormat
from
.data_format
import
LinkPredictionEdgeFormat
@
functional_datapipe
(
"sample_negative"
)
class
NegativeSampler
(
Mapper
):
class
NegativeSampler
(
Mapper
):
"""
"""
A negative sampler used to generate negative samples and return
A negative sampler used to generate negative samples and return
...
...
python/dgl/graphbolt/subgraph_sampler.py
View file @
ec495239
...
@@ -3,12 +3,14 @@
...
@@ -3,12 +3,14 @@
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Dict
from
typing
import
Dict
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
from
.base
import
etype_str_to_tuple
from
.base
import
etype_str_to_tuple
from
.utils
import
unique_and_compact
from
.utils
import
unique_and_compact
@
functional_datapipe
(
"sample_subgraph"
)
class
SubgraphSampler
(
Mapper
):
class
SubgraphSampler
(
Mapper
):
"""A subgraph sampler used to sample a subgraph from a given set of nodes
"""A subgraph sampler used to sample a subgraph from a given set of nodes
from a larger graph."""
from a larger graph."""
...
...
tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
View file @
ec495239
...
@@ -5,6 +5,75 @@ import torch
...
@@ -5,6 +5,75 @@ import torch
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
def
test_NegativeSampler_invoke
():
# Instantiate graph and required datapipes.
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
2
*
num_seeds
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
negative_ratio
=
2
# Invoke NegativeSampler via class constructor.
negative_sampler
=
gb
.
NegativeSampler
(
item_sampler
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
)
with
pytest
.
raises
(
NotImplementedError
):
next
(
iter
(
negative_sampler
))
# Invoke NegativeSampler via functional form.
negative_sampler
=
item_sampler
.
sample_negative
(
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
)
with
pytest
.
raises
(
NotImplementedError
):
next
(
iter
(
negative_sampler
))
def
test_UniformNegativeSampler_invoke
():
# Instantiate graph and required datapipes.
graph
=
gb_test_utils
.
rand_csc_graph
(
100
,
0.05
)
num_seeds
=
30
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
2
*
num_seeds
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
batch_size
=
10
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
)
negative_ratio
=
2
# Verify iteration over UniformNegativeSampler.
def
_verify
(
negative_sampler
):
for
data
in
negative_sampler
:
src
,
dst
=
data
.
node_pairs
labels
=
data
.
labels
# Assertion
assert
len
(
src
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
dst
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
len
(
labels
)
==
batch_size
*
(
negative_ratio
+
1
)
assert
torch
.
all
(
torch
.
eq
(
labels
[:
batch_size
],
1
))
assert
torch
.
all
(
torch
.
eq
(
labels
[
batch_size
:],
0
))
# Invoke UniformNegativeSampler via class constructor.
negative_sampler
=
gb
.
UniformNegativeSampler
(
item_sampler
,
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
graph
,
)
_verify
(
negative_sampler
)
# Invoke UniformNegativeSampler via functional form.
negative_sampler
=
item_sampler
.
sample_uniform_negative
(
negative_ratio
,
gb
.
LinkPredictionEdgeFormat
.
INDEPENDENT
,
graph
,
)
_verify
(
negative_sampler
)
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
@
pytest
.
mark
.
parametrize
(
"negative_ratio"
,
[
1
,
5
,
10
,
20
])
def
test_NegativeSampler_Independent_Format
(
negative_ratio
):
def
test_NegativeSampler_Independent_Format
(
negative_ratio
):
# Construct CSCSamplingGraph.
# Construct CSCSamplingGraph.
...
...
tests/python/pytorch/graphbolt/test_base.py
View file @
ec495239
...
@@ -11,8 +11,14 @@ import torch
...
@@ -11,8 +11,14 @@ import torch
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"cpu"
,
"CopyTo needs GPU to test"
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"cpu"
,
"CopyTo needs GPU to test"
)
def
test_CopyTo
():
def
test_CopyTo
():
dp
=
gb
.
ItemSampler
(
gb
.
ItemSet
(
torch
.
randn
(
20
)),
4
)
dp
=
gb
.
ItemSampler
(
gb
.
ItemSet
(
torch
.
randn
(
20
)),
4
)
# Invoke CopyTo via class constructor.
dp
=
gb
.
CopyTo
(
dp
,
"cuda"
)
dp
=
gb
.
CopyTo
(
dp
,
"cuda"
)
for
data
in
dp
:
assert
data
.
device
.
type
==
"cuda"
# Invoke CopyTo via functional form.
dp
=
dp
.
copy_to
(
"cuda"
)
for
data
in
dp
:
for
data
in
dp
:
assert
data
.
device
.
type
==
"cuda"
assert
data
.
device
.
type
==
"cuda"
...
...
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
ec495239
...
@@ -4,6 +4,35 @@ import torch
...
@@ -4,6 +4,35 @@ import torch
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
def
test_FeatureFetcher_invoke
():
# Prepare graph and required datapipes.
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
a
=
torch
.
randint
(
0
,
10
,
(
graph
.
num_nodes
,))
b
=
torch
.
randint
(
0
,
10
,
(
graph
.
num_edges
,))
features
=
{}
keys
=
[(
"node"
,
None
,
"a"
),
(
"edge"
,
None
,
"b"
)]
features
[
keys
[
0
]]
=
gb
.
TorchBasedFeature
(
a
)
features
[
keys
[
1
]]
=
gb
.
TorchBasedFeature
(
b
)
feature_store
=
gb
.
BasicFeatureStore
(
features
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
# Invoke FeatureFetcher via class constructor.
datapipe
=
gb
.
NeighborSampler
(
datapipe
,
graph
,
fanouts
)
datapipe
=
gb
.
FeatureFetcher
(
datapipe
,
feature_store
,
[
"a"
],
[
"b"
])
assert
len
(
list
(
datapipe
))
==
5
# Invoke FeatureFetcher via functional form.
datapipe
=
datapipe
.
sample_neighbor
(
graph
,
fanouts
).
fetch_feature
(
feature_store
,
[
"a"
],
[
"b"
]
)
assert
len
(
list
(
datapipe
))
==
5
def
test_FeatureFetcher_homo
():
def
test_FeatureFetcher_homo
():
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
a
=
torch
.
randint
(
0
,
10
,
(
graph
.
num_nodes
,))
a
=
torch
.
randint
(
0
,
10
,
(
graph
.
num_nodes
,))
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
ec495239
...
@@ -6,6 +6,42 @@ import torchdata.datapipes as dp
...
@@ -6,6 +6,42 @@ import torchdata.datapipes as dp
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
def
test_SubgraphSampler_invoke
():
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
# Invoke via class constructor.
datapipe
=
gb
.
SubgraphSampler
(
datapipe
)
with
pytest
.
raises
(
NotImplementedError
):
next
(
iter
(
datapipe
))
# Invoke via functional form.
datapipe
=
datapipe
.
sample_subgraph
()
with
pytest
.
raises
(
NotImplementedError
):
next
(
iter
(
datapipe
))
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_NeighborSampler_invoke
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
# Invoke via class constructor.
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
datapipe
=
Sampler
(
datapipe
,
graph
,
fanouts
)
assert
len
(
list
(
datapipe
))
==
5
# Invoke via functional form.
if
labor
:
datapipe
=
datapipe
.
sample_layer_neighbor
(
graph
,
fanouts
)
else
:
datapipe
=
datapipe
.
sample_neighbor
(
graph
,
fanouts
)
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Node
(
labor
):
def
test_SubgraphSampler_Node
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment