Unverified Commit 70fdb69f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[doc] add example for SubgraphSampler (#6752)

parent 9430bec6
......@@ -113,7 +113,7 @@ class SparseNeighborSampler(SubgraphSampler):
fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout)
def _sample_subgraphs(self, seeds):
def sample_subgraphs(self, seeds):
sampled_matrices = []
src = seeds
......
......@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler):
self.output_cscformat = output_cscformat
self.sampler = graph.in_subgraph
def _sample_subgraphs(self, seeds):
def sample_subgraphs(self, seeds):
subgraph = self.sampler(seeds, self.output_cscformat)
if not self.output_cscformat:
(
......
......@@ -116,7 +116,7 @@ class NeighborSampler(SubgraphSampler):
self.output_cscformat = output_cscformat
self.sampler = graph.sample_neighbors
def _sample_subgraphs(self, seeds):
def sample_subgraphs(self, seeds):
subgraphs = []
num_layers = len(self.fanouts)
# Enrich seeds with all node types.
......
......@@ -21,6 +21,9 @@ class SubgraphSampler(MiniBatchTransformer):
Functional name: :obj:`sample_subgraph`.
This class is the base class of all subgraph samplers. Any subclass of
SubgraphSampler should implement the :meth:`sample_subgraphs` method.
Parameters
----------
datapipe : DataPipe
......@@ -51,7 +54,7 @@ class SubgraphSampler(MiniBatchTransformer):
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self._sample_subgraphs(seeds)
) = self.sample_subgraphs(seeds)
return minibatch
def _node_pairs_preprocess(self, minibatch):
......@@ -134,7 +137,7 @@ class SubgraphSampler(MiniBatchTransformer):
compacted_negative_dsts if has_neg_dst else None,
)
def _sample_subgraphs(self, seeds):
def sample_subgraphs(self, seeds):
"""Sample subgraphs from the given seeds.
Any subclass of SubgraphSampler should implement this method.
......@@ -148,7 +151,27 @@ class SubgraphSampler(MiniBatchTransformer):
-------
Union[torch.Tensor, Dict[str, torch.Tensor]]
The input nodes.
SampledSubgraph
List[SampledSubgraph]
The sampled subgraphs.
Examples
--------
>>> @functional_datapipe("my_sample_subgraph")
>>> class MySubgraphSampler(SubgraphSampler):
>>> def __init__(self, datapipe, graph, fanouts):
>>> super().__init__(datapipe)
>>> self.graph = graph
>>> self.fanouts = fanouts
>>> def sample_subgraphs(self, seeds):
>>> # Sample subgraphs from the given seeds.
>>> subgraphs = []
>>> subgraphs_nodes = []
>>> for fanout in reversed(self.fanouts):
>>> subgraph = self.graph.sample_neighbors(seeds, fanout)
>>> subgraphs.insert(0, subgraph)
>>> subgraphs_nodes.append(subgraph.nodes)
>>> seeds = subgraph.nodes
>>> subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
>>> return subgraphs_nodes, subgraphs
"""
raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment