Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4ea2bd45
Unverified
Commit
4ea2bd45
authored
Sep 14, 2023
by
Hongzhi (Steve), Chen
Committed by
GitHub
Sep 14, 2023
Browse files
[Graphbolt] Add MiniBatchTransformer to support exclude edges. (#6330)
parent
47c6fb1f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
93 additions
and
8 deletions
+93
-8
python/dgl/graphbolt/__init__.py
python/dgl/graphbolt/__init__.py
+1
-0
python/dgl/graphbolt/feature_fetcher.py
python/dgl/graphbolt/feature_fetcher.py
+2
-2
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+0
-2
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+48
-0
python/dgl/graphbolt/minibatch_transformer.py
python/dgl/graphbolt/minibatch_transformer.py
+37
-0
python/dgl/graphbolt/negative_sampler.py
python/dgl/graphbolt/negative_sampler.py
+3
-2
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+2
-2
No files found.
python/dgl/graphbolt/__init__.py
View file @
4ea2bd45
...
@@ -14,6 +14,7 @@ from .feature_store import *
...
@@ -14,6 +14,7 @@ from .feature_store import *
from
.impl
import
*
from
.impl
import
*
from
.itemset
import
*
from
.itemset
import
*
from
.item_sampler
import
*
from
.item_sampler
import
*
from
.minibatch_transformer
import
*
from
.negative_sampler
import
*
from
.negative_sampler
import
*
from
.sampled_subgraph
import
*
from
.sampled_subgraph
import
*
from
.subgraph_sampler
import
*
from
.subgraph_sampler
import
*
...
...
python/dgl/graphbolt/feature_fetcher.py
View file @
4ea2bd45
...
@@ -4,11 +4,11 @@ from typing import Dict
...
@@ -4,11 +4,11 @@ from typing import Dict
from
torch.utils.data
import
functional_datapipe
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapp
er
from
.minibatch_transformer
import
MiniBatchTransform
er
@
functional_datapipe
(
"fetch_feature"
)
@
functional_datapipe
(
"fetch_feature"
)
class
FeatureFetcher
(
M
app
er
):
class
FeatureFetcher
(
M
iniBatchTransform
er
):
"""A feature fetcher used to fetch features for node/edge in graphbolt."""
"""A feature fetcher used to fetch features for node/edge in graphbolt."""
def
__init__
(
def
__init__
(
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
4ea2bd45
...
@@ -56,7 +56,6 @@ class NeighborSampler(SubgraphSampler):
...
@@ -56,7 +56,6 @@ class NeighborSampler(SubgraphSampler):
Examples
Examples
-------
-------
>>> import dgl.graphbolt as gb
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> from dgl import graphbolt as gb
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
...
@@ -165,7 +164,6 @@ class LayerNeighborSampler(NeighborSampler):
...
@@ -165,7 +164,6 @@ class LayerNeighborSampler(NeighborSampler):
Examples
Examples
-------
-------
>>> import dgl.graphbolt as gb
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> from dgl import graphbolt as gb
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
...
...
python/dgl/graphbolt/minibatch.py
View file @
4ea2bd45
...
@@ -9,6 +9,7 @@ import dgl
...
@@ -9,6 +9,7 @@ import dgl
from
.base
import
etype_str_to_tuple
from
.base
import
etype_str_to_tuple
from
.sampled_subgraph
import
SampledSubgraph
from
.sampled_subgraph
import
SampledSubgraph
from
.utils
import
add_reverse_edges
__all__
=
[
"MiniBatch"
]
__all__
=
[
"MiniBatch"
]
...
@@ -225,3 +226,50 @@ class MiniBatch:
...
@@ -225,3 +226,50 @@ class MiniBatch:
block
.
edata
[
dgl
.
EID
]
=
subgraph
.
reverse_edge_ids
block
.
edata
[
dgl
.
EID
]
=
subgraph
.
reverse_edge_ids
return
blocks
return
blocks
def exclude_edges(
    minibatch: MiniBatch,
    edges: Union[
        Dict[str, Tuple[torch.Tensor, torch.Tensor]],
        Tuple[torch.Tensor, torch.Tensor],
    ],
):
    """
    Exclude edges from the sampled subgraphs in the minibatch.

    Parameters
    ----------
    minibatch : MiniBatch
        The minibatch whose ``sampled_subgraphs`` are pruned in place.
    edges : Dict[str, Tuple[torch.Tensor, torch.Tensor]] or Tuple[torch.Tensor, torch.Tensor]
        The edges to be excluded, keyed by edge type for heterogeneous
        graphs or given as a single node-pair tuple otherwise.

    Returns
    -------
    MiniBatch
        The same minibatch object, with each sampled subgraph replaced by
        the result of its own ``exclude_edges`` call.
    """
    pruned_subgraphs = []
    for sampled_subgraph in minibatch.sampled_subgraphs:
        pruned_subgraphs.append(sampled_subgraph.exclude_edges(edges))
    minibatch.sampled_subgraphs = pruned_subgraphs
    return minibatch
def exclude_seed_edges(minibatch: MiniBatch):
    """Exclude seed edges from the sampled subgraphs in the minibatch.

    The seed edges are read from ``minibatch.node_pairs`` and removed from
    every sampled subgraph via :func:`exclude_edges`.
    """
    seed_edges = minibatch.node_pairs
    return exclude_edges(minibatch, seed_edges)
def exclude_seed_edges_and_reverse(
    minibatch: MiniBatch, reverse_etypes: Union[Dict[str, str], None] = None
):
    """
    Exclude seed edges and their reverse edges from the sampled subgraphs in
    the minibatch.

    Parameters
    ----------
    minibatch : MiniBatch
        The minibatch.
    reverse_etypes : Dict[str, str], optional
        The mapping from the original edge types to their reverse edge
        types. Defaults to None.

    Returns
    -------
    MiniBatch
        The minibatch with the seed edges and their reverses excluded from
        every sampled subgraph.
    """
    # Augment the seed node pairs with their reverse edges before pruning.
    edges_to_exclude = add_reverse_edges(minibatch.node_pairs, reverse_etypes)
    return exclude_edges(minibatch, edges_to_exclude)
python/dgl/graphbolt/minibatch_transformer.py
0 → 100644
View file @
4ea2bd45
"""Mini-batch transformer"""
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
.minibatch
import
MiniBatch
@functional_datapipe("transform")
class MiniBatchTransformer(Mapper):
    """A mini-batch transformer used to manipulate mini-batch."""

    def __init__(
        self,
        datapipe,
        transformer,
    ):
        """
        Initialization for a mini-batch transformer.

        Parameters
        ----------
        datapipe : DataPipe
            The datapipe.
        transformer : callable
            The function applied to each minibatch which is responsible for
            transforming the minibatch. It must return a ``MiniBatch``.
        """
        # Route every item through _transformer so the output type is checked.
        super().__init__(datapipe, self._transformer)
        self.transformer = transformer

    def _transformer(self, minibatch):
        """Apply the user transformer and validate its output type."""
        minibatch = self.transformer(minibatch)
        assert isinstance(
            minibatch, MiniBatch
        ), "The transformer output should be an instance of MiniBatch"
        return minibatch
python/dgl/graphbolt/negative_sampler.py
View file @
4ea2bd45
...
@@ -3,11 +3,12 @@
...
@@ -3,11 +3,12 @@
from
_collections_abc
import
Mapping
from
_collections_abc
import
Mapping
from
torch.utils.data
import
functional_datapipe
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
.minibatch_transformer
import
MiniBatchTransformer
@
functional_datapipe
(
"sample_negative"
)
@
functional_datapipe
(
"sample_negative"
)
class
NegativeSampler
(
M
app
er
):
class
NegativeSampler
(
M
iniBatchTransform
er
):
"""
"""
A negative sampler used to generate negative samples and return
A negative sampler used to generate negative samples and return
a mix of positive and negative samples.
a mix of positive and negative samples.
...
...
python/dgl/graphbolt/subgraph_sampler.py
View file @
4ea2bd45
...
@@ -4,14 +4,14 @@ from collections import defaultdict
...
@@ -4,14 +4,14 @@ from collections import defaultdict
from
typing
import
Dict
from
typing
import
Dict
from
torch.utils.data
import
functional_datapipe
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
.base
import
etype_str_to_tuple
from
.base
import
etype_str_to_tuple
from
.minibatch_transformer
import
MiniBatchTransformer
from
.utils
import
unique_and_compact
from
.utils
import
unique_and_compact
@
functional_datapipe
(
"sample_subgraph"
)
@
functional_datapipe
(
"sample_subgraph"
)
class
SubgraphSampler
(
M
app
er
):
class
SubgraphSampler
(
M
iniBatchTransform
er
):
"""A subgraph sampler used to sample a subgraph from a given set of nodes
"""A subgraph sampler used to sample a subgraph from a given set of nodes
from a larger graph."""
from a larger graph."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment