Unverified Commit 4ea2bd45 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Graphbolt] Add MiniBatchTransformer to support exclude edges. (#6330)

parent 47c6fb1f
...@@ -14,6 +14,7 @@ from .feature_store import * ...@@ -14,6 +14,7 @@ from .feature_store import *
from .impl import * from .impl import *
from .itemset import * from .itemset import *
from .item_sampler import * from .item_sampler import *
from .minibatch_transformer import *
from .negative_sampler import * from .negative_sampler import *
from .sampled_subgraph import * from .sampled_subgraph import *
from .subgraph_sampler import * from .subgraph_sampler import *
......
...@@ -4,11 +4,11 @@ from typing import Dict ...@@ -4,11 +4,11 @@ from typing import Dict
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper from .minibatch_transformer import MiniBatchTransformer
@functional_datapipe("fetch_feature") @functional_datapipe("fetch_feature")
class FeatureFetcher(Mapper): class FeatureFetcher(MiniBatchTransformer):
"""A feature fetcher used to fetch features for node/edge in graphbolt.""" """A feature fetcher used to fetch features for node/edge in graphbolt."""
def __init__( def __init__(
......
...@@ -56,7 +56,6 @@ class NeighborSampler(SubgraphSampler): ...@@ -56,7 +56,6 @@ class NeighborSampler(SubgraphSampler):
Examples Examples
------- -------
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
...@@ -165,7 +164,6 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -165,7 +164,6 @@ class LayerNeighborSampler(NeighborSampler):
Examples Examples
------- -------
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
......
...@@ -9,6 +9,7 @@ import dgl ...@@ -9,6 +9,7 @@ import dgl
from .base import etype_str_to_tuple from .base import etype_str_to_tuple
from .sampled_subgraph import SampledSubgraph from .sampled_subgraph import SampledSubgraph
from .utils import add_reverse_edges
__all__ = ["MiniBatch"] __all__ = ["MiniBatch"]
...@@ -225,3 +226,50 @@ class MiniBatch: ...@@ -225,3 +226,50 @@ class MiniBatch:
block.edata[dgl.EID] = subgraph.reverse_edge_ids block.edata[dgl.EID] = subgraph.reverse_edge_ids
return blocks return blocks
def exclude_edges(
    minibatch: MiniBatch,
    edges: Union[
        Dict[str, Tuple[torch.Tensor, torch.Tensor]],
        Tuple[torch.Tensor, torch.Tensor],
    ],
):
    """
    Exclude edges from the sampled subgraphs in the minibatch.

    Parameters
    ----------
    minibatch : MiniBatch
        The minibatch whose sampled subgraphs will be filtered in place.
    edges : Dict[str, Tuple[torch.Tensor, torch.Tensor]] or Tuple[torch.Tensor, torch.Tensor]
        The edges to be excluded, keyed by edge type for heterogeneous
        graphs or given as a single node-pair tuple otherwise.

    Returns
    -------
    MiniBatch
        The same minibatch, with every sampled subgraph replaced by a copy
        that has the given edges removed.
    """
    filtered_subgraphs = []
    for subgraph in minibatch.sampled_subgraphs:
        filtered_subgraphs.append(subgraph.exclude_edges(edges))
    minibatch.sampled_subgraphs = filtered_subgraphs
    return minibatch
def exclude_seed_edges(minibatch: MiniBatch):
    """Exclude the minibatch's own seed edges from its sampled subgraphs.

    Convenience wrapper around :func:`exclude_edges` that uses the
    minibatch's ``node_pairs`` as the set of edges to drop.
    """
    seed_edges = minibatch.node_pairs
    return exclude_edges(minibatch, seed_edges)
def exclude_seed_edges_and_reverse(
    minibatch: MiniBatch, reverse_etypes: Dict[str, str] = None
):
    """
    Exclude seed edges and their reverse edges from the sampled subgraphs in
    the minibatch.

    Parameters
    ----------
    minibatch : MiniBatch
        The minibatch.
    reverse_etypes : Dict[str, str], optional
        Mapping from each original edge type to its reverse edge type,
        used to derive the reverse edges to exclude alongside the seeds.
    """
    # Augment the seed edges with their reverses before excluding.
    to_exclude = add_reverse_edges(minibatch.node_pairs, reverse_etypes)
    return exclude_edges(minibatch, to_exclude)
"""Mini-batch transformer"""
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
from .minibatch import MiniBatch


@functional_datapipe("transform")
class MiniBatchTransformer(Mapper):
    """A mini-batch transformer used to manipulate mini-batch.

    Applies a user-provided callable to every mini-batch flowing through
    the datapipe and validates that the result is still a ``MiniBatch``.
    """

    def __init__(
        self,
        datapipe,
        transformer,
    ):
        """
        Initialization for a mini-batch transformer.

        Parameters
        ----------
        datapipe : DataPipe
            The datapipe.
        transformer : callable
            The function applied to each minibatch which is responsible for
            transforming the minibatch. It must return a ``MiniBatch``.
        """
        # Mapper invokes self._transformer per item; the user callable is
        # stored so _transformer can wrap it with output validation.
        super().__init__(datapipe, self._transformer)
        self.transformer = transformer

    def _transformer(self, minibatch):
        """Apply the user transformer and validate its output type."""
        minibatch = self.transformer(minibatch)
        assert isinstance(
            minibatch, MiniBatch
        ), "The transformer output should be an instance of MiniBatch"
        return minibatch
...@@ -3,11 +3,12 @@ ...@@ -3,11 +3,12 @@
from _collections_abc import Mapping from _collections_abc import Mapping
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
from .minibatch_transformer import MiniBatchTransformer
@functional_datapipe("sample_negative") @functional_datapipe("sample_negative")
class NegativeSampler(Mapper): class NegativeSampler(MiniBatchTransformer):
""" """
A negative sampler used to generate negative samples and return A negative sampler used to generate negative samples and return
a mix of positive and negative samples. a mix of positive and negative samples.
......
...@@ -4,14 +4,14 @@ from collections import defaultdict ...@@ -4,14 +4,14 @@ from collections import defaultdict
from typing import Dict from typing import Dict
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
from .base import etype_str_to_tuple from .base import etype_str_to_tuple
from .minibatch_transformer import MiniBatchTransformer
from .utils import unique_and_compact from .utils import unique_and_compact
@functional_datapipe("sample_subgraph") @functional_datapipe("sample_subgraph")
class SubgraphSampler(Mapper): class SubgraphSampler(MiniBatchTransformer):
"""A subgraph sampler used to sample a subgraph from a given set of nodes """A subgraph sampler used to sample a subgraph from a given set of nodes
from a larger graph.""" from a larger graph."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment