"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a16957159e24e53df0a40bfadbcd0a61b4b3ea8c"
Unverified Commit 70fdb69f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[doc] add example for SubgraphSampler (#6752)

parent 9430bec6
...@@ -113,7 +113,7 @@ class SparseNeighborSampler(SubgraphSampler): ...@@ -113,7 +113,7 @@ class SparseNeighborSampler(SubgraphSampler):
fanout = torch.LongTensor([int(fanout)]) fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout) self.fanouts.insert(0, fanout)
def _sample_subgraphs(self, seeds): def sample_subgraphs(self, seeds):
sampled_matrices = [] sampled_matrices = []
src = seeds src = seeds
......
...@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler): ...@@ -67,7 +67,7 @@ class InSubgraphSampler(SubgraphSampler):
self.output_cscformat = output_cscformat self.output_cscformat = output_cscformat
self.sampler = graph.in_subgraph self.sampler = graph.in_subgraph
def _sample_subgraphs(self, seeds): def sample_subgraphs(self, seeds):
subgraph = self.sampler(seeds, self.output_cscformat) subgraph = self.sampler(seeds, self.output_cscformat)
if not self.output_cscformat: if not self.output_cscformat:
( (
......
...@@ -116,7 +116,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -116,7 +116,7 @@ class NeighborSampler(SubgraphSampler):
self.output_cscformat = output_cscformat self.output_cscformat = output_cscformat
self.sampler = graph.sample_neighbors self.sampler = graph.sample_neighbors
def _sample_subgraphs(self, seeds): def sample_subgraphs(self, seeds):
subgraphs = [] subgraphs = []
num_layers = len(self.fanouts) num_layers = len(self.fanouts)
# Enrich seeds with all node types. # Enrich seeds with all node types.
......
...@@ -21,6 +21,9 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -21,6 +21,9 @@ class SubgraphSampler(MiniBatchTransformer):
Functional name: :obj:`sample_subgraph`. Functional name: :obj:`sample_subgraph`.
This class is the base class of all subgraph samplers. Any subclass of
SubgraphSampler should implement the :meth:`sample_subgraphs` method.
Parameters Parameters
---------- ----------
datapipe : DataPipe datapipe : DataPipe
...@@ -51,7 +54,7 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -51,7 +54,7 @@ class SubgraphSampler(MiniBatchTransformer):
( (
minibatch.input_nodes, minibatch.input_nodes,
minibatch.sampled_subgraphs, minibatch.sampled_subgraphs,
) = self._sample_subgraphs(seeds) ) = self.sample_subgraphs(seeds)
return minibatch return minibatch
def _node_pairs_preprocess(self, minibatch): def _node_pairs_preprocess(self, minibatch):
...@@ -134,7 +137,7 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -134,7 +137,7 @@ class SubgraphSampler(MiniBatchTransformer):
compacted_negative_dsts if has_neg_dst else None, compacted_negative_dsts if has_neg_dst else None,
) )
def _sample_subgraphs(self, seeds): def sample_subgraphs(self, seeds):
"""Sample subgraphs from the given seeds. """Sample subgraphs from the given seeds.
Any subclass of SubgraphSampler should implement this method. Any subclass of SubgraphSampler should implement this method.
...@@ -148,7 +151,27 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -148,7 +151,27 @@ class SubgraphSampler(MiniBatchTransformer):
------- -------
Union[torch.Tensor, Dict[str, torch.Tensor]] Union[torch.Tensor, Dict[str, torch.Tensor]]
The input nodes. The input nodes.
SampledSubgraph List[SampledSubgraph]
The sampled subgraphs. The sampled subgraphs.
Examples
--------
>>> @functional_datapipe("my_sample_subgraph")
>>> class MySubgraphSampler(SubgraphSampler):
>>> def __init__(self, datapipe, graph, fanouts):
>>> super().__init__(datapipe)
>>> self.graph = graph
>>> self.fanouts = fanouts
>>> def sample_subgraphs(self, seeds):
>>> # Sample subgraphs from the given seeds.
>>> subgraphs = []
>>> subgraphs_nodes = []
>>> for fanout in reversed(self.fanouts):
>>> subgraph = self.graph.sample_neighbors(seeds, fanout)
>>> subgraphs.insert(0, subgraph)
>>> subgraphs_nodes.append(subgraph.nodes)
>>> seeds = subgraph.nodes
>>> subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
>>> return subgraphs_nodes, subgraphs
""" """
raise NotImplementedError raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment