Unverified Commit 9fee20b9 authored by Mufei Li, committed by GitHub

[Sampler] [Example] SAINTSampler and Simplify GraphSAINT Example (#3879)



* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Fix
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 0dd500e3
@@ -46,6 +46,14 @@ Samplers
MultiLayerFullNeighborSampler
ClusterGCNSampler
ShaDowKHopSampler
SAINTSampler
Sampler Transformations
-----------------------
.. autosummary::
:toctree: ../../generated/
as_edge_prediction_sampler
BlockSampler
@@ -8,6 +8,8 @@ Author's code: https://github.com/GraphSAINT/GraphSAINT
Contributor: Jiahang Li ([@ljh1064126026](https://github.com/ljh1064126026)) Tang Liu ([@lt610](https://github.com/lt610))
For built-in GraphSAINT subgraph samplers with online sampling, use `dgl.dataloading.SAINTSampler`.
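A minimal sketch (not part of this patch) of driving training with the built-in online sampler; the graph `g`, its `'feat'`/`'label'` node data, and the budget values below are assumptions for illustration:

```python
import torch
from dgl.dataloading import SAINTSampler, DataLoader

# Random-walk-based SAINT sampler: 3000 walk roots, walks of length 2 (illustrative values).
sampler = SAINTSampler(mode='walk', budget=(3000, 2),
                       prefetch_ndata=['feat', 'label'])
# Each of the 1000 dummy indices triggers one freshly sampled (online) subgraph.
dataloader = DataLoader(g, torch.arange(1000), sampler, num_workers=4)
for subg in dataloader:
    pass  # train on the node-induced subgraph `subg`
```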
## Dependencies
- Python 3.7.10
@@ -69,7 +71,7 @@ python train_sampling.py --task $task $online
* Paper: results from the paper
* Running: results from experiments with the authors' code
* DGL: results from experiments with the DGL example. The experiment config comes from `config.py`. You can modify parameters in `config.py` to see how different setups perform.
> Note that we implement both offline and online sampling in the training phase. Offline sampling means all subgraphs used during training come from pre-sampled subgraphs. Online sampling means we discard all pre-sampled subgraphs and re-sample new subgraphs during training.
@@ -132,7 +134,7 @@ python train_sampling.py --task $task $online
- We ran each experiment 10 times to measure the mean and standard deviation of the sampling and normalization time. Here we only measure this time, without training the model to the end. Moreover, for efficient testing, the hardware and config employed here are not the same as in the experiments above, so the sampling time may differ slightly from that reported there. We do, however, keep the environment consistent across all experiments below.
> The only config values that differ from the section above are `num_workers_sampler`, `batch_size_sampler` and `num_workers`, which only affect sampling speed. Other parameters are kept consistent across the two sections, so the model's performance is not affected.
> Each value is reported as (average, std).
@@ -2,6 +2,7 @@
from .. import backend as F
from .neighbor_sampler import *
from .cluster_gcn import *
from .graphsaint import *
from .shadow import *
from .base import *
from . import negative_sampler
"""GraphSAINT samplers."""
from ..base import DGLError
from ..random import choice
from ..sampling import random_walk, pack_traces
from .base import set_node_lazy_features, set_edge_lazy_features, Sampler
try:
import torch
except ImportError:
pass
class SAINTSampler(Sampler):
"""Random node/edge/walk sampler from
`GraphSAINT: Graph Sampling Based Inductive Learning Method
<https://arxiv.org/abs/1907.04931>`__
For each call, the sampler samples a node subset and then returns a node-induced subgraph.
There are three options for sampling node subsets:
- For the :attr:`'node'` sampler, the probability of sampling a node is proportional
to its out-degree.
- The :attr:`'edge'` sampler first samples an edge subset and then uses the
end nodes of the sampled edges.
- The :attr:`'walk'` sampler uses the nodes visited by random walks. It uniformly selects
a number of root nodes and then performs a fixed-length random walk from each root node.
Parameters
----------
mode : str
The sampler to use, which can be :attr:`'node'`, :attr:`'edge'`, or :attr:`'walk'`.
budget : int or tuple[int]
Sampler configuration.
- For :attr:`'node'` sampler, budget specifies the number of nodes
in each sampled subgraph.
- For :attr:`'edge'` sampler, budget specifies the number of edges
to sample for inducing a subgraph.
- For :attr:`'walk'` sampler, budget is a tuple. budget[0] specifies
the number of root nodes to generate random walks. budget[1] specifies
the length of a random walk.
cache : bool, optional
If False, the sampler will not cache the probability arrays used for sampling. Setting
it to False is required if you want to use the sampler across different graphs.
prefetch_ndata : list[str], optional
The node data to prefetch for the subgraph.
See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching.
prefetch_edata : list[str], optional
The edge data to prefetch for the subgraph.
See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching.
output_device : device, optional
The device of the output subgraphs.
Examples
--------
>>> import torch
>>> from dgl.dataloading import SAINTSampler, DataLoader
>>> num_iters = 1000
>>> # Assume g.ndata['feat'] and g.ndata['label'] hold node features and labels
>>> sampler = SAINTSampler(mode='node', budget=6000,
...                        prefetch_ndata=['feat', 'label'])
>>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4)
>>> for subg in dataloader:
... train_on(subg)
"""
def __init__(self, mode, budget, cache=True, prefetch_ndata=None,
prefetch_edata=None, output_device='cpu'):
super().__init__()
self.budget = budget
if mode == 'node':
self.sampler = self.node_sampler
elif mode == 'edge':
self.sampler = self.edge_sampler
elif mode == 'walk':
self.sampler = self.walk_sampler
else:
raise DGLError(f"Expect mode to be 'node', 'edge' or 'walk', got {mode}.")
self.cache = cache
self.prob = None
self.prefetch_ndata = prefetch_ndata or []
self.prefetch_edata = prefetch_edata or []
self.output_device = output_device
def node_sampler(self, g):
"""Node ID sampler for random node sampler"""
# Alternatively, this can be realized by uniformly sampling an edge subset,
# and then taking the src nodes of the sampled edges. However, the number of edges
# is typically much larger than the number of nodes.
if self.cache and self.prob is not None:
prob = self.prob
else:
prob = g.out_degrees().float().clamp(min=1)
if self.cache:
self.prob = prob
return torch.multinomial(prob, num_samples=self.budget,
replacement=True).unique().type(g.idtype)
def edge_sampler(self, g):
"""Node ID sampler for random edge sampler"""
src, dst = g.edges()
if self.cache and self.prob is not None:
prob = self.prob
else:
in_deg = g.in_degrees().float().clamp(min=1)
out_deg = g.out_degrees().float().clamp(min=1)
# We can reduce the sample space by half if graphs are always symmetric.
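# Edge (u, v) is picked with probability proportional to
# 1 / out_degree(u) + 1 / in_degree(v), favoring edges incident to low-degree nodes.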
prob = 1. / in_deg[dst.long()] + 1. / out_deg[src.long()]
prob /= prob.sum()
if self.cache:
self.prob = prob
sampled_edges = torch.unique(choice(len(prob), size=self.budget, prob=prob))
sampled_nodes = torch.cat([src[sampled_edges], dst[sampled_edges]])
return sampled_nodes.unique().type(g.idtype)
def walk_sampler(self, g):
"""Node ID sampler for random walk sampler"""
num_roots, walk_length = self.budget
sampled_roots = torch.randint(0, g.num_nodes(), (num_roots,))
traces, types = random_walk(g, nodes=sampled_roots, length=walk_length)
sampled_nodes, _, _, _ = pack_traces(traces, types)
return sampled_nodes.unique().type(g.idtype)
def sample(self, g, indices):
"""Sampling function
Parameters
----------
g : DGLGraph
The graph to sample from.
indices : Tensor
A placeholder; not used by this sampler.
Returns
-------
DGLGraph
The sampled subgraph.
"""
node_ids = self.sampler(g)
sg = g.subgraph(node_ids, relabel_nodes=True, output_device=self.output_device)
set_node_lazy_features(sg, self.prefetch_ndata)
set_edge_lazy_features(sg, self.prefetch_edata)
return sg
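A brief sketch (not part of the patch) of exercising the three modes directly, outside a DataLoader; the toy graph and budget values below are illustrative assumptions:

```python
import dgl
import torch
from dgl.dataloading import SAINTSampler

# Toy graph: 1000 nodes, 5000 random edges (assumption for illustration).
g = dgl.rand_graph(1000, 5000)

# 'node': sample ~100 nodes with probability proportional to out-degree.
node_subg = SAINTSampler('node', budget=100).sample(g, None)
# 'edge': sample 200 edges and keep their end nodes.
edge_subg = SAINTSampler('edge', budget=200).sample(g, None)
# 'walk': 50 random-walk roots, each walk of length 4.
walk_subg = SAINTSampler('walk', budget=(50, 4)).sample(g, None)

print(node_subg.num_nodes(), edge_subg.num_nodes(), walk_subg.num_nodes())
```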
@@ -51,6 +51,24 @@ def test_shadow(num_workers):
if i == 5:
break
@pytest.mark.parametrize('num_workers', [0, 4])
@pytest.mark.parametrize('mode', ['node', 'edge', 'walk'])
def test_saint(num_workers, mode):
g = dgl.data.CoraFullDataset()[0]
if mode == 'node':
budget = 100
elif mode == 'edge':
budget = 200
elif mode == 'walk':
budget = (3, 2)
sampler = dgl.dataloading.SAINTSampler(mode, budget)
dataloader = dgl.dataloading.DataLoader(
g, torch.arange(100), sampler, num_workers=num_workers)
assert len(dataloader) == 100
for sg in dataloader:
pass
@pytest.mark.parametrize('num_workers', [0, 4])
def test_neighbor_nonuniform(num_workers):