"git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "5ead59db0f750ce755b98c517de6bef3898b34d5"
Unverified Commit 436de3d1 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[API Deprecation] Remove _dataloading and tgcn example (#5118)

parent c8bc5588
# Temporal Graph Neural Network (TGN)
## DGL implementation of the TGN paper
This DGL example implements the model proposed in the paper [Temporal Graph Networks for Deep Learning on Dynamic Graphs](https://arxiv.org/abs/2006.10637.pdf).
## TGN implementor
This example was implemented by [Ericcsr](https://github.com/Ericcsr) during his SDE internship at the AWS Shanghai AI Lab.
## Graph Dataset
Jodie Wikipedia Temporal dataset. Dataset summary:
- Num Nodes: 9227
- Num Edges: 157,474
- Num Edge Features: 172
- Edge Feature type: LIWC
- Time Span: 30 days
- Chronological Split: Train: 70% Valid: 15% Test: 15%
Jodie Reddit Temporal dataset. Dataset summary:
- Num Nodes: 11,000
- Num Edges: 672,447
- Num Edge Features: 172
- Edge Feature type: LIWC
- Time Span: 30 days
- Chronological Split: Train: 70% Valid: 15% Test: 15%
## How to run example files
In the tgn folder, run **`train.py`** as follows:
```bash
python train.py --dataset wikipedia
```
If you want to run in fast mode:
```bash
python train.py --dataset wikipedia --fast_mode
```
If you want to run in simple mode:
```bash
python train.py --dataset wikipedia --simple_mode
```
If you want to change memory updating module:
```bash
python train.py --dataset wikipedia --memory_updater [rnn/gru]
```
If you want to use TGAT:
```bash
python train.py --dataset wikipedia --not_use_memory --k_hop 2
```
## Performance
#### Without New Node in test set
| Models/Datasets | Wikipedia | Reddit |
| --------------- | ------------------ | ---------------- |
| TGN simple mode | AP: 98.5 AUC: 98.9 | AP: N/A AUC: N/A |
| TGN fast mode | AP: 98.2 AUC: 98.6 | AP: N/A AUC: N/A |
| TGN | AP: 98.9 AUC: 98.5 | AP: N/A AUC: N/A |
#### With New Node in test set
| Models/Datasets | Wikipedia | Reddit |
| --------------- | ------------------- | ---------------- |
| TGN simple mode | AP: 98.2 AUC: 98.6 | AP: N/A AUC: N/A |
| TGN fast mode | AP: 98.0 AUC: 98.4 | AP: N/A AUC: N/A |
| TGN | AP: 98.2 AUC: 98.1 | AP: N/A AUC: N/A |
## Training Time per Batch
Intel E5 (2 cores), Tesla K80, Wikipedia dataset
| Models/Datasets | Wikipedia | Reddit |
| --------------- | --------- | -------- |
| TGN simple mode | 0.3s | N/A |
| TGN fast mode | 0.28s | N/A |
| TGN | 1.3s | N/A |
### Details explained
**What is Simple Mode**
The simple temporal sampler keeps only the edges that happen before the current timestamp and builds the subgraph of the corresponding nodes.
It then applies standard static-graph neighborhood sampling to that subgraph.
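A minimal sketch of that procedure (illustrative only; the actual implementation is `SimpleTemporalSampler` in the dataloading code below, and the helper name here is hypothetical):
```python
import dgl
import torch

def simple_temporal_frontier(g, seed_nodes, current_ts, fanout):
    # Keep only the incoming edges of the seed nodes.
    sub = dgl.in_subgraph(g, seed_nodes)
    # Drop edges that happen after the current timestamp.
    sub.remove_edges(torch.where(sub.edata["timestamp"] > current_ts)[0])
    # Fall back to static-graph sampling: keep the `fanout` most recent edges per seed node.
    return dgl.sampling.select_topk(sub, fanout, "timestamp", seed_nodes)
```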
**What is Fast Mode**
Normally, temporal encoding requires each node to use the incoming interaction's time frame as its current time. If two nodes have multiple interactions within the same batch, each node then has to maintain multiple embedding features, which slows down batching. To avoid this feature duplication, fast mode uses each node's last memory-update time from the previous batch as the reference time for temporal encoding, which enables fast batching. Also, within each batch, all interactions between two nodes are predicted using the same set of embedding features.
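A rough sketch of the idea (the helper is hypothetical; it assumes a per-node `last_update` tensor like the one maintained via `FastTemporalSampler.attach_last_update`):
```python
import torch

def fast_mode_time_delta(edge_ts, src_nodes, last_update):
    # edge_ts     : [num_edges] interaction timestamps in the current batch
    # src_nodes   : [num_edges] source node ID of each interaction
    # last_update : [num_nodes] last memory-update time from the previous batch
    #
    # Every interaction of a node in this batch is encoded against the same
    # reference time, so each node needs only one embedding per batch.
    return edge_ts - last_update[src_nodes]
```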
**What is New Node test**
This test checks whether the model can predict links involving unseen nodes based on the neighborhood information of seen nodes. The example deliberately selects 10% of the nodes in the test graph and masks out all of their edges during training.
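A condensed sketch of that split, following the logic in `train.py` (the function name is hypothetical):
```python
import copy
import numpy as np
import torch
import dgl

def mask_new_nodes(g, trainval_div, ratio=0.1):
    # Nodes that appear in the test portion of the chronologically ordered edge stream.
    test_nodes = torch.cat(
        [g.edges()[0][trainval_div:], g.edges()[1][trainval_div:]]
    ).unique().numpy()
    new_nodes = np.random.choice(test_nodes, int(ratio * len(test_nodes)), replace=False)
    # Remove every edge touching a selected node so it stays unseen during training.
    in_eids = dgl.in_subgraph(g, new_nodes).edata[dgl.EID]
    out_eids = dgl.out_subgraph(g, new_nodes).edata[dgl.EID]
    g_no_new_node = copy.deepcopy(g)
    g_no_new_node.remove_edges(torch.cat([in_eids, out_eids]).unique())
    return g_no_new_node, new_nodes
```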
**Why the attention module is not exactly the same as in the original TGN paper**
The attention module used in this model is adapted from DGL's GATConv and additionally takes edge features and time encoding into account. It is more memory efficient and faster to compute than the attention module proposed in the paper, and in our tests it reaches the same accuracy.
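For reference, a simplified single-head sketch of how the edge-aware attention score is computed (the full multi-head version is `EdgeGATConv` in the modules code below; the parameter names here are placeholders for learned weights):
```python
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.ops import edge_softmax

def edge_gat_attention(g, nfeat, efeat, W_node, W_edge, a_l, a_r, a_e):
    # W_node / W_edge project node / edge features; a_l, a_r, a_e are attention vectors.
    with g.local_scope():
        h = nfeat @ W_node
        he = efeat @ W_edge
        g.ndata["el"] = (h * a_l).sum(-1, keepdim=True)   # source term
        g.ndata["er"] = (h * a_r).sum(-1, keepdim=True)   # destination term
        g.edata["ee"] = (he * a_e).sum(-1, keepdim=True)  # edge term
        # The attention logit combines source node, edge feature and destination node.
        g.apply_edges(fn.u_add_e("el", "ee", "el_prime"))
        g.apply_edges(fn.e_add_v("el_prime", "er", "e"))
        g.edata["a"] = edge_softmax(g, F.leaky_relu(g.edata["e"]))
        # Aggregate the projected node features weighted by attention (simplified
        # compared to EdgeGATConv, which also folds edge information into the message).
        g.ndata["ft"] = h
        g.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
        return g.ndata["ft"]
```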
import os
import ssl
import numpy as np
import pandas as pd
import torch
from six.moves import urllib
import dgl
# === The data preprocessing code below is based on
# https://github.com/twitter-research/tgn
# Preprocess the raw data: split out each feature
def preprocess(data_name):
u_list, i_list, ts_list, label_list = [], [], [], []
feat_l = []
idx_list = []
with open(data_name) as f:
s = next(f)
for idx, line in enumerate(f):
e = line.strip().split(",")
u = int(e[0])
i = int(e[1])
ts = float(e[2])
label = float(e[3]) # int(e[3])
feat = np.array([float(x) for x in e[4:]])
u_list.append(u)
i_list.append(i)
ts_list.append(ts)
label_list.append(label)
idx_list.append(idx)
feat_l.append(feat)
return pd.DataFrame(
{
"u": u_list,
"i": i_list,
"ts": ts_list,
"label": label_list,
"idx": idx_list,
}
), np.array(feat_l)
# Re-index nodes for DGL convenience
def reindex(df, bipartite=True):
new_df = df.copy()
if bipartite:
assert df.u.max() - df.u.min() + 1 == len(df.u.unique())
assert df.i.max() - df.i.min() + 1 == len(df.i.unique())
upper_u = df.u.max() + 1
new_i = df.i + upper_u
new_df.i = new_i
new_df.u += 1
new_df.i += 1
new_df.idx += 1
else:
new_df.u += 1
new_df.i += 1
new_df.idx += 1
return new_df
# Save the edge list and features in separate files for easier processing
def run(data_name, bipartite=True):
PATH = "./data/{}.csv".format(data_name)
OUT_DF = "./data/ml_{}.csv".format(data_name)
OUT_FEAT = "./data/ml_{}.npy".format(data_name)
OUT_NODE_FEAT = "./data/ml_{}_node.npy".format(data_name)
df, feat = preprocess(PATH)
new_df = reindex(df, bipartite)
empty = np.zeros(feat.shape[1])[np.newaxis, :]
feat = np.vstack([empty, feat])
max_idx = max(new_df.u.max(), new_df.i.max())
rand_feat = np.zeros((max_idx + 1, 172))
new_df.to_csv(OUT_DF)
np.save(OUT_FEAT, feat)
np.save(OUT_NODE_FEAT, rand_feat)
# === code from twitter-research-tgn end ===
# If you have a new dataset that follows the same format as Jodie,
# you can retrieve it directly by name
def TemporalDataset(dataset):
if not os.path.exists("./data/{}.bin".format(dataset)):
if not os.path.exists("./data/{}.csv".format(dataset)):
if not os.path.exists("./data"):
os.mkdir("./data")
url = "https://snap.stanford.edu/jodie/{}.csv".format(dataset)
print("Start Downloading File....")
context = ssl._create_unverified_context()
data = urllib.request.urlopen(url, context=context)
with open("./data/{}.csv".format(dataset), "wb") as handle:
handle.write(data.read())
print("Start Process Data ...")
run(dataset)
raw_connection = pd.read_csv("./data/ml_{}.csv".format(dataset))
raw_feature = np.load("./data/ml_{}.npy".format(dataset))
# Subtract 1 to re-index the nodes from 0
src = raw_connection["u"].to_numpy() - 1
dst = raw_connection["i"].to_numpy() - 1
# Create directed graph
g = dgl.graph((src, dst))
g.edata["timestamp"] = torch.from_numpy(raw_connection["ts"].to_numpy())
g.edata["label"] = torch.from_numpy(raw_connection["label"].to_numpy())
g.edata["feats"] = torch.from_numpy(raw_feature[1:, :]).float()
dgl.save_graphs("./data/{}.bin".format(dataset), [g])
else:
print("Data is exist directly loaded.")
gs, _ = dgl.load_graphs("./data/{}.bin".format(dataset))
g = gs[0]
return g
def TemporalWikipediaDataset():
# Download the dataset
return TemporalDataset("wikipedia")
def TemporalRedditDataset():
return TemporalDataset("reddit")
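# Example usage (illustrative): load the Jodie Wikipedia graph and inspect its edge data.
#   g = TemporalWikipediaDataset()   # downloads ./data/wikipedia.csv on first run
#   print(g.edata["timestamp"].shape, g.edata["feats"].shape)
# Any other Jodie-format dataset can be fetched by name, e.g. TemporalDataset("reddit").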
import torch
import dgl
from dgl._dataloading.dataloader import EdgeCollator
from dgl._dataloading import BlockSampler
from dgl._dataloading.pytorch import _pop_subgraph_storage, _pop_storages, EdgeDataLoader
from dgl.base import DGLError
from functools import partial
import copy
import dgl.function as fn
def _prepare_tensor(g, data, name, is_distributed):
return torch.tensor(data) if is_distributed else dgl.utils.prepare_tensor(g, data, name)
class TemporalSampler(BlockSampler):
""" Temporal Sampler builds computational and temporal dependency of node representations via
temporal neighbors selection and screening.
The sampler expects input node to have same time stamps, in the case of TGN, it should be
either positive [src,dst] pair or negative samples. It will first take in-subgraph of seed
nodes and then screening out edges which happen after that timestamp. Finally it will sample
a fixed number of neighbor edges using random or topk sampling.
Parameters
----------
sampler_type : str
sampler indication string of the final sampler.
If 'topk' then sample topk most recent nodes
If 'uniform' then uniform randomly sample k nodes
k : int
maximum number of neighors to sampler
default 10 neighbors as paper stated
Examples
----------
Please refers to examples/pytorch/tgn/train.py
"""
def __init__(self, sampler_type='topk', k=10):
super(TemporalSampler, self).__init__(1, False)
if sampler_type == 'topk':
self.sampler = partial(
dgl.sampling.select_topk, k=k, weight='timestamp')
elif sampler_type == 'uniform':
self.sampler = partial(dgl.sampling.sample_neighbors, fanout=k)
else:
raise DGLError(
"Sampler string invalid please use \'topk\' or \'uniform\'")
def sampler_frontier(self,
block_id,
g,
seed_nodes,
timestamp):
full_neighbor_subgraph = dgl.in_subgraph(g, seed_nodes)
full_neighbor_subgraph = dgl.add_edges(full_neighbor_subgraph,
seed_nodes, seed_nodes)
temporal_edge_mask = (full_neighbor_subgraph.edata['timestamp'] < timestamp) + (
full_neighbor_subgraph.edata['timestamp'] <= 0)
temporal_subgraph = dgl.edge_subgraph(
full_neighbor_subgraph, temporal_edge_mask)
# Build a map to preserve original IDs
temp2origin = temporal_subgraph.ndata[dgl.NID]
# The newly added edges will hence be preserved
root2sub_dict = dict(
zip(temp2origin.tolist(), temporal_subgraph.nodes().tolist()))
temporal_subgraph.ndata[dgl.NID] = g.ndata[dgl.NID][temp2origin]
seed_nodes = [root2sub_dict[int(n)] for n in seed_nodes]
final_subgraph = self.sampler(g=temporal_subgraph, nodes=seed_nodes)
final_subgraph.remove_self_loop()
return final_subgraph
# Temporal Subgraph
def sample_blocks(self,
g,
seed_nodes,
timestamp):
blocks = []
frontier = self.sampler_frontier(0, g, seed_nodes, timestamp)
#block = transform.to_block(frontier,seed_nodes)
block = frontier
if self.return_eids:
self.assign_block_eids(block, frontier)
blocks.append(block)
return blocks
class TemporalEdgeCollator(EdgeCollator):
""" Temporal Edge collator merge the edges specified by eid: items
Since we cannot keep duplicated nodes on a graph we need to iterate though
the incoming edges and expand the duplicated node and form a batched block
graph capture the temporal and computational dependency.
Parameters
----------
g : DGLGraph
The graph from which the edges are iterated in minibatches and the subgraphs
are generated.
eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs.
graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
g_sampling : DGLGraph, optional
The graph where neighborhood sampling and message passing is performed.
Note that this is not necessarily the same as :attr:`g`.
If None, assume to be the same as :attr:`g`.
exclude : str, optional
Whether and how to exclude dependencies related to the sampled edges in the
minibatch. Possible values are
* None, which excludes nothing.
* ``'reverse_id'``, which excludes the reverse edges of the sampled edges. The said
reverse edges have the same edge type as the sampled edges. Only works
on edge types whose source node type is the same as its destination node type.
* ``'reverse_types'``, which excludes the reverse edges of the sampled edges. The
said reverse edges have different edge types from the sampled edges.
If ``g_sampling`` is given, ``exclude`` is ignored and will be always ``None``.
reverse_eids : Tensor or dict[etype, Tensor], optional
The mapping from original edge ID to its reverse edge ID.
Required and only used when ``exclude`` is set to ``reverse_id``.
For heterogeneous graph this will be a dict of edge type and edge IDs. Note that
only the edge types whose source node type is the same as destination node type
are needed.
reverse_etypes : dict[etype, etype], optional
The mapping from the edge type to its reverse edge type.
Required and only used when ``exclude`` is set to ``reverse_types``.
negative_sampler : callable, optional
The negative sampler. Can be omitted if no negative sampling is needed.
The negative sampler must be a callable that takes in the following arguments:
* The original (heterogeneous) graph.
* The ID array of sampled edges in the minibatch, or the dictionary of edge
types and ID array of sampled edges in the minibatch if the graph is
heterogeneous.
It should return
* A pair of source and destination node ID arrays as negative samples,
or a dictionary of edge types and such pairs if the graph is heterogenenous.
A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`.
Example
----------
Please refer to examples/pytorch/tgn/train.py
"""
def _collate_with_negative_sampling(self, items):
items = _prepare_tensor(self.g_sampling, items, 'items', False)
# Here node id will not change
pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[dgl.EID]
neg_srcdst_raw = self.negative_sampler(self.g, items)
neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst_raw}
dtype = list(neg_srcdst.values())[0][0].dtype
neg_edges = {
etype: neg_srcdst.get(etype, (torch.tensor(
[], dtype=dtype), torch.tensor([], dtype=dtype)))
for etype in self.g.canonical_etypes}
neg_pair_graph = dgl.heterograph(
neg_edges, {ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes})
pair_graph, neg_pair_graph = dgl.transforms.compact_graphs(
[pair_graph, neg_pair_graph])
# Need to remap id
pair_graph.ndata[dgl.NID] = self.g.nodes()[pair_graph.ndata[dgl.NID]]
neg_pair_graph.ndata[dgl.NID] = self.g.nodes()[
neg_pair_graph.ndata[dgl.NID]]
pair_graph.edata[dgl.EID] = induced_edges
batch_graphs = []
nodes_id = []
timestamps = []
for i, edge in enumerate(zip(self.g.edges()[0][items], self.g.edges()[1][items])):
ts = pair_graph.edata['timestamp'][i]
timestamps.append(ts)
subg = self.graph_sampler.sample_blocks(self.g_sampling,
list(edge),
timestamp=ts)[0]
subg.ndata['timestamp'] = ts.repeat(subg.num_nodes())
nodes_id.append(subg.srcdata[dgl.NID])
batch_graphs.append(subg)
timestamps = torch.tensor(timestamps).repeat_interleave(
self.negative_sampler.k)
for i, neg_edge in enumerate(zip(neg_srcdst_raw[0].tolist(), neg_srcdst_raw[1].tolist())):
ts = timestamps[i]
subg = self.graph_sampler.sample_blocks(self.g_sampling,
[neg_edge[1]],
timestamp=ts)[0]
subg.ndata['timestamp'] = ts.repeat(subg.num_nodes())
batch_graphs.append(subg)
blocks = [dgl.batch(batch_graphs)]
input_nodes = torch.cat(nodes_id)
return input_nodes, pair_graph, neg_pair_graph, blocks
def collator(self, items):
"""
The collator interface; the input items are edge IDs of the attached graph
"""
result = super().collate(items)
# Copy the feature from parent graph
_pop_subgraph_storage(result[1], self.g)
_pop_subgraph_storage(result[2], self.g)
_pop_storages(result[-1], self.g_sampling)
return result
class TemporalEdgeDataLoader(EdgeDataLoader):
""" TemporalEdgeDataLoader is an iteratable object to generate blocks for temporal embedding
as well as pos and neg pair graph for memory update.
The batch generated will follow temporal order
Parameters
----------
g : dgl.Heterograph
graph for batching the temporal edge IDs as well as generating the negative subgraph
eids : torch.tensor() or numpy array
range of edge IDs to be batched; useful for splitting the dataset into training, validation and test sets
graph_sampler : dgl.dataloading.BlockSampler
temporal neighbor sampler which samples temporally and computationally dependent blocks for computation
device : str
'cpu' means load dataset on cpu
'cuda' means load dataset on gpu
collator : dgl.dataloading.EdgeCollator
merges the input edge IDs from the PyTorch dataloader into graphs
Example
----------
Please refer to examples/pytorch/tgn/train.py
"""
def __init__(self, g, eids, graph_sampler, device='cpu', collator=TemporalEdgeCollator, **kwargs):
super().__init__(g, eids, graph_sampler, device, **kwargs)
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
if k in self.collator_arglist:
collator_kwargs[k] = v
else:
dataloader_kwargs[k] = v
self.collator = collator(g, eids, graph_sampler, **collator_kwargs)
assert not isinstance(g, dgl.distributed.DistGraph), \
'EdgeDataLoader does not support DistGraph for now. ' \
+ 'Please use DistDataLoader directly.'
self.dataloader = torch.utils.data.DataLoader(
self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs)
self.device = device
# Precompute the CSR and CSC representations so each subprocess does not
# duplicate.
if dataloader_kwargs.get('num_workers', 0) > 0:
g.create_formats_()
def __iter__(self):
return iter(self.dataloader)
# ====== Fast Mode ======
# Part of the code (reservoir sampling) comes from the PyG library
# https://github.com/rusty1s/pytorch_geometric/nn/models/tgn.py
class FastTemporalSampler(BlockSampler):
"""Temporal Sampler which implemented with a fast query lookup table. Sample
temporal and computationally depending subgraph.
The sampler maintains a lookup table of most current k neighbors of each node
each time, the sampler need to be updated with new edges from incoming batch to
update the lookup table.
Parameters
----------
g : dgl.Heterograph
graph to be sampled here it which only exist to provide feature and data reference
k : int
number of neighbors the lookup table is maintaining
device : str
indication str which represent where the data will be stored
'cpu' store the intermediate data on cpu memory
'cuda' store the intermediate data on gpu memory
Example
----------
Please refers to examples/pytorch/tgn/train.py
"""
def __init__(self, g, k, device='cpu'):
self.k = k
self.g = g
num_nodes = g.num_nodes()
self.neighbors = torch.empty(
(num_nodes, k), dtype=torch.long, device=device)
self.e_id = torch.empty(
(num_nodes, k), dtype=torch.long, device=device)
self.__assoc__ = torch.empty(
num_nodes, dtype=torch.long, device=device)
self.last_update = torch.zeros(num_nodes, dtype=torch.double)
self.reset()
def sample_frontier(self,
block_id,
g,
seed_nodes):
n_id = seed_nodes
# Here Assume n_id is the bg nid
neighbors = self.neighbors[n_id]
nodes = n_id.view(-1, 1).repeat(1, self.k)
e_id = self.e_id[n_id]
mask = e_id >= 0
neighbors[~mask] = nodes[~mask]
# Screen out orphan node
orphans = nodes[~mask].unique()
nodes = nodes[mask]
neighbors = neighbors[mask]
e_id = e_id[mask]
neighbors = neighbors.flatten()
nodes = nodes.flatten()
n_id = torch.cat([nodes, neighbors]).unique()
self.__assoc__[n_id] = torch.arange(n_id.size(0), device=n_id.device)
neighbors, nodes = self.__assoc__[neighbors], self.__assoc__[nodes]
subg = dgl.graph((nodes, neighbors))
# Add new nodes to account for orphans that have not been created yet
subg.add_nodes(len(orphans))
# Copy the seed node feature to subgraph
subg.edata['timestamp'] = torch.zeros(subg.num_edges()).double()
subg.edata['timestamp'] = self.g.edata['timestamp'][e_id]
n_id = torch.cat([n_id, orphans])
subg.ndata['timestamp'] = self.last_update[n_id]
subg.edata['feats'] = torch.zeros(
(subg.num_edges(), self.g.edata['feats'].shape[1])).float()
subg.edata['feats'] = self.g.edata['feats'][e_id]
subg = dgl.add_self_loop(subg)
subg.ndata[dgl.NID] = n_id
return subg
def sample_blocks(self,
g,
seed_nodes):
blocks = []
frontier = self.sample_frontier(0, g, seed_nodes)
block = frontier
blocks.append(block)
return blocks
def add_edges(self, src, dst):
"""
Add the edge info of the incoming batch to the lookup table
Parameters
----------
src : torch.Tensor
source nodes of the incoming batch; should be consistent with self.g
dst : torch.Tensor
destination nodes of the incoming batch; should be consistent with self.g
"""
neighbors = torch.cat([src, dst], dim=0)
nodes = torch.cat([dst, src], dim=0)
e_id = torch.arange(self.cur_e_id, self.cur_e_id + src.size(0),
device=src.device).repeat(2)
self.cur_e_id += src.numel()
# Convert newly encountered interaction ids so that they point to
# locations of a "dense" format of shape [num_nodes, size].
nodes, perm = nodes.sort()
neighbors, e_id = neighbors[perm], e_id[perm]
n_id = nodes.unique()
self.__assoc__[n_id] = torch.arange(n_id.numel(), device=n_id.device)
dense_id = torch.arange(nodes.size(0), device=nodes.device) % self.k
dense_id += self.__assoc__[nodes].mul_(self.k)
dense_e_id = e_id.new_full((n_id.numel() * self.k, ), -1)
dense_e_id[dense_id] = e_id
dense_e_id = dense_e_id.view(-1, self.k)
dense_neighbors = e_id.new_empty(n_id.numel() * self.k)
dense_neighbors[dense_id] = neighbors
dense_neighbors = dense_neighbors.view(-1, self.k)
# Collect new and old interactions...
e_id = torch.cat([self.e_id[n_id, :self.k], dense_e_id], dim=-1)
neighbors = torch.cat(
[self.neighbors[n_id, :self.k], dense_neighbors], dim=-1)
# And sort them based on `e_id`.
e_id, perm = e_id.topk(self.k, dim=-1)
self.e_id[n_id] = e_id
self.neighbors[n_id] = torch.gather(neighbors, 1, perm)
def reset(self):
"""
Clean up the lookup table
"""
self.cur_e_id = 0
self.e_id.fill_(-1)
def attach_last_update(self, last_t):
"""
Attach the latest timestamp at which each node has been updated
Parameters
----------
last_t : torch.Tensor
last update timestamp of each node; its size needs to be consistent with self.g
"""
self.last_update = last_t
def sync(self, sampler):
"""
Copy the lookup table information from another sampler.
This method is useful for running the test dataset with new nodes:
when testing on the new-node dataset, the lookup table's state should
be restored from the sampler's state just after validation.
Parameters
----------
sampler : FastTemporalSampler
The sampler from which the current sampler gets the lookup table info
"""
self.cur_e_id = sampler.cur_e_id
self.neighbors = copy.deepcopy(sampler.neighbors)
self.e_id = copy.deepcopy(sampler.e_id)
self.__assoc__ = copy.deepcopy(sampler.__assoc__)
class FastTemporalEdgeCollator(EdgeCollator):
""" Temporal Edge collator merge the edges specified by eid: items
Since we cannot keep duplicated nodes on a graph we need to iterate though
the incoming edges and expand the duplicated node and form a batched block
graph capture the temporal and computational dependency.
Parameters
----------
g : DGLGraph
The graph from which the edges are iterated in minibatches and the subgraphs
are generated.
eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs.
graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
g_sampling : DGLGraph, optional
The graph where neighborhood sampling and message passing is performed.
Note that this is not necessarily the same as :attr:`g`.
If None, assume to be the same as :attr:`g`.
exclude : str, optional
Whether and how to exclude dependencies related to the sampled edges in the
minibatch. Possible values are
* None, which excludes nothing.
* ``'reverse_id'``, which excludes the reverse edges of the sampled edges. The said
reverse edges have the same edge type as the sampled edges. Only works
on edge types whose source node type is the same as its destination node type.
* ``'reverse_types'``, which excludes the reverse edges of the sampled edges. The
said reverse edges have different edge types from the sampled edges.
If ``g_sampling`` is given, ``exclude`` is ignored and will be always ``None``.
reverse_eids : Tensor or dict[etype, Tensor], optional
The mapping from original edge ID to its reverse edge ID.
Required and only used when ``exclude`` is set to ``reverse_id``.
For heterogeneous graph this will be a dict of edge type and edge IDs. Note that
only the edge types whose source node type is the same as destination node type
are needed.
reverse_etypes : dict[etype, etype], optional
The mapping from the edge type to its reverse edge type.
Required and only used when ``exclude`` is set to ``reverse_types``.
negative_sampler : callable, optional
The negative sampler. Can be omitted if no negative sampling is needed.
The negative sampler must be a callable that takes in the following arguments:
* The original (heterogeneous) graph.
* The ID array of sampled edges in the minibatch, or the dictionary of edge
types and ID array of sampled edges in the minibatch if the graph is
heterogeneous.
It should return
* A pair of source and destination node ID arrays as negative samples,
or a dictionary of edge types and such pairs if the graph is heterogenenous.
A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`.
Example
----------
Please refer to examples/pytorch/tgn/train.py
"""
def _collate_with_negative_sampling(self, items):
items = _prepare_tensor(self.g_sampling, items, 'items', False)
# Here node id will not change
pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[dgl.EID]
neg_srcdst_raw = self.negative_sampler(self.g, items)
neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst_raw}
dtype = list(neg_srcdst.values())[0][0].dtype
neg_edges = {
etype: neg_srcdst.get(etype, (torch.tensor(
[], dtype=dtype), torch.tensor([], dtype=dtype)))
for etype in self.g.canonical_etypes}
neg_pair_graph = dgl.heterograph(
neg_edges, {ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes})
pair_graph, neg_pair_graph = dgl.transforms.compact_graphs(
[pair_graph, neg_pair_graph])
# Need to remap id
pair_graph.ndata[dgl.NID] = self.g.nodes()[pair_graph.ndata[dgl.NID]]
neg_pair_graph.ndata[dgl.NID] = self.g.nodes()[
neg_pair_graph.ndata[dgl.NID]]
pair_graph.edata[dgl.EID] = induced_edges
seed_nodes = pair_graph.ndata[dgl.NID]
blocks = self.graph_sampler.sample_blocks(self.g_sampling, seed_nodes)
blocks[0].ndata['timestamp'] = torch.zeros(
blocks[0].num_nodes()).double()
input_nodes = blocks[0].edges()[1]
# update sampler
_src = self.g.nodes()[self.g.edges()[0][items]]
_dst = self.g.nodes()[self.g.edges()[1][items]]
self.graph_sampler.add_edges(_src, _dst)
return input_nodes, pair_graph, neg_pair_graph, blocks
def collator(self, items):
result = super().collate(items)
# Copy the feature from parent graph
_pop_subgraph_storage(result[1], self.g)
_pop_subgraph_storage(result[2], self.g)
_pop_storages(result[-1], self.g_sampling)
return result
# ====== Simple Mode ======
# Part of the code comes from the paper
# "APAN: Asynchronous Propagation Attention Network for Real-time Temporal Graph Embedding",
# which appeared in SIGMOD 2021; code repo: https://github.com/WangXuhongCN/APAN
class SimpleTemporalSampler(BlockSampler):
'''
The simple temporal sampler chooses the edges that happen before the current timestamp to build the subgraph of the corresponding nodes.
The sampler then uses the simplest static-graph neighborhood sampling methods.
Parameters
----------
fanouts : [int, ..., int] int list
The neighbors sampling strategy
'''
def __init__(self, g, fanouts, return_eids=False):
super().__init__(len(fanouts), return_eids)
self.fanouts = fanouts
self.ts = 0
self.frontiers = [None for _ in range(len(fanouts))]
def sample_frontier(self, block_id, g, seed_nodes):
'''
Delete the edges that happen after the current timestamp, then use simple top-k edge sampling by timestamp.
'''
fanout = self.fanouts[block_id]
# List of neighbors to sample per edge type for each GNN layer, starting from the first layer.
g = dgl.in_subgraph(g, seed_nodes)
g.remove_edges(torch.where(g.edata['timestamp'] > self.ts)[0])  # Delete the edges that happen after the current timestamp
if fanout is None: # full neighborhood sampling
frontier = g
else:
frontier = dgl.sampling.select_topk(g, fanout, 'timestamp', seed_nodes) # most recent timestamp edge sampling
self.frontiers[block_id] = frontier # save frontier
return frontier
class SimpleTemporalEdgeCollator(EdgeCollator):
'''
Temporal edge collator, which merges the edges specified by eid: items.
Parameters
----------
g : DGLGraph
The graph from which the edges are iterated in minibatches and the subgraphs
are generated.
eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs.
graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
g_sampling : DGLGraph, optional
The graph where neighborhood sampling and message passing is performed.
Note that this is not necessarily the same as :attr:`g`.
If None, assume to be the same as :attr:`g`.
exclude : str, optional
Whether and how to exclude dependencies related to the sampled edges in the
minibatch. Possible values are
* None, which excludes nothing.
* ``'reverse_id'``, which excludes the reverse edges of the sampled edges. The said
reverse edges have the same edge type as the sampled edges. Only works
on edge types whose source node type is the same as its destination node type.
* ``'reverse_types'``, which excludes the reverse edges of the sampled edges. The
said reverse edges have different edge types from the sampled edges.
If ``g_sampling`` is given, ``exclude`` is ignored and will be always ``None``.
reverse_eids : Tensor or dict[etype, Tensor], optional
The mapping from original edge ID to its reverse edge ID.
Required and only used when ``exclude`` is set to ``reverse_id``.
For heterogeneous graph this will be a dict of edge type and edge IDs. Note that
only the edge types whose source node type is the same as destination node type
are needed.
reverse_etypes : dict[etype, etype], optional
The mapping from the edge type to its reverse edge type.
Required and only used when ``exclude`` is set to ``reverse_types``.
negative_sampler : callable, optional
The negative sampler. Can be omitted if no negative sampling is needed.
The negative sampler must be a callable that takes in the following arguments:
* The original (heterogeneous) graph.
* The ID array of sampled edges in the minibatch, or the dictionary of edge
types and ID array of sampled edges in the minibatch if the graph is
heterogeneous.
It should return
* A pair of source and destination node ID arrays as negative samples,
or a dictionary of edge types and such pairs if the graph is heterogenenous.
A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`.
'''
def __init__(self, g, eids, graph_sampler, g_sampling=None, exclude=None,
reverse_eids=None, reverse_etypes=None, negative_sampler=None):
super(SimpleTemporalEdgeCollator, self).__init__(g, eids, graph_sampler,
g_sampling, exclude, reverse_eids, reverse_etypes, negative_sampler)
self.n_layer = len(self.graph_sampler.fanouts)
def collate(self,items):
'''
items: edge IDs in graph g.
We sample iteratively k times and batch the results into one single subgraph.
'''
current_ts = self.g.edata['timestamp'][items[0]]  # only sample edges before the current timestamp
self.graph_sampler.ts = current_ts  # pass the current timestamp to the graph sampler
# For link prediction, we use a negative_sampler to generate a negative graph for computing the loss.
if self.negative_sampler is None:
neg_pair_graph = None
input_nodes, pair_graph, blocks = self._collate(items)
else:
input_nodes, pair_graph, neg_pair_graph, blocks = self._collate_with_negative_sampling(items)
# We sample the k-hop subgraphs and batch them into one graph
for i in range(self.n_layer-1):
self.graph_sampler.frontiers[0].add_edges(*self.graph_sampler.frontiers[i+1].edges())
frontier = self.graph_sampler.frontiers[0]
# computing node last-update timestamp
frontier.update_all(fn.copy_e('timestamp','ts'), fn.max('ts','timestamp'))
return input_nodes, pair_graph, neg_pair_graph, [frontier]
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.base import DGLError
from dgl.ops import edge_softmax
class Identity(nn.Module):
"""A placeholder identity operator that is argument-insensitive.
(Identity has been supported since PyTorch 1.2; we will directly
import torch.nn.Identity in the future.)
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
"""Return input"""
return x
class MsgLinkPredictor(nn.Module):
"""Predict Pair wise link from pos subg and neg subg
use message passing.
Use Two layer MLP on edge to predict the link probability
Parameters
----------
embed_dim : int
dimension of each each feature's embedding
Example
----------
>>> linkpred = MsgLinkPredictor(10)
>>> pos_g = dgl.graph(([0,1,2,3,4],[1,2,3,4,0]))
>>> neg_g = dgl.graph(([0,1,2,3,4],[2,1,4,3,0]))
>>> x = torch.ones(5,10)
>>> linkpred(x,pos_g,neg_g)
(tensor([[0.0902],
[0.0902],
[0.0902],
[0.0902],
[0.0902]], grad_fn=<AddmmBackward>),
tensor([[0.0902],
[0.0902],
[0.0902],
[0.0902],
[0.0902]], grad_fn=<AddmmBackward>))
"""
def __init__(self, emb_dim):
super(MsgLinkPredictor, self).__init__()
self.src_fc = nn.Linear(emb_dim, emb_dim)
self.dst_fc = nn.Linear(emb_dim, emb_dim)
self.out_fc = nn.Linear(emb_dim, 1)
def link_pred(self, edges):
src_hid = self.src_fc(edges.src["embedding"])
dst_hid = self.dst_fc(edges.dst["embedding"])
score = F.relu(src_hid + dst_hid)
score = self.out_fc(score)
return {"score": score}
def forward(self, x, pos_g, neg_g):
# Local Scope?
pos_g.ndata["embedding"] = x
neg_g.ndata["embedding"] = x
pos_g.apply_edges(self.link_pred)
neg_g.apply_edges(self.link_pred)
pos_escore = pos_g.edata["score"]
neg_escore = neg_g.edata["score"]
return pos_escore, neg_escore
class TimeEncode(nn.Module):
"""Use finite fourier series with different phase and frequency to encode
time different between two event
..math::
\Phi(t) = [\cos(\omega_0t+\psi_0),\cos(\omega_1t+\psi_1),...,\cos(\omega_nt+\psi_n)]
Parameter
----------
dimension : int
Length of the fourier series. The longer it is ,
the more timescale information it can capture
Example
----------
>>> tecd = TimeEncode(10)
>>> t = torch.tensor([[1]])
>>> tecd(t)
tensor([[[0.5403, 0.9950, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000]]], dtype=torch.float64, grad_fn=<CosBackward>)
"""
def __init__(self, dimension):
super(TimeEncode, self).__init__()
self.dimension = dimension
self.w = torch.nn.Linear(1, dimension)
self.w.weight = torch.nn.Parameter(
(torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
.double()
.reshape(dimension, -1)
)
self.w.bias = torch.nn.Parameter(torch.zeros(dimension).double())
def forward(self, t):
t = t.unsqueeze(dim=2)
output = torch.cos(self.w(t))
return output
class MemoryModule(nn.Module):
"""Memory module as well as update interface
The memory module stores both historical representation in last_update_t
Parameters
----------
n_node : int
number of node of the entire graph
hidden_dim : int
dimension of memory of each node
Example
----------
Please refer to examples/pytorch/tgn/tgn.py;
examples/pytorch/tgn/train.py
"""
def __init__(self, n_node, hidden_dim):
super(MemoryModule, self).__init__()
self.n_node = n_node
self.hidden_dim = hidden_dim
self.reset_memory()
def reset_memory(self):
self.last_update_t = nn.Parameter(
torch.zeros(self.n_node).float(), requires_grad=False
)
self.memory = nn.Parameter(
torch.zeros((self.n_node, self.hidden_dim)).float(),
requires_grad=False,
)
def backup_memory(self):
"""
Return a copy of the memory state and last_update_t.
For the new-node test, new nodes need to use the memory accumulated up to the validation set.
After validation, the memory therefore needs to be backed up before running the test set without new nodes,
so that the backup can later be restored for the new-node test set.
"""
return self.memory.clone(), self.last_update_t.clone()
def restore_memory(self, memory_backup):
"""Restore the memory from validation set
Parameters
----------
memory_backup : (memory,last_update_t)
restore memory based on input tuple
"""
self.memory = memory_backup[0].clone()
self.last_update_t = memory_backup[1].clone()
# Used to attach memory to a subgraph
def get_memory(self, node_idxs):
return self.memory[node_idxs, :]
# Called when the memory needs to be updated
def set_memory(self, node_idxs, values):
self.memory[node_idxs, :] = values
def set_last_update_t(self, node_idxs, values):
self.last_update_t[node_idxs] = values
# For safety check
def get_last_update(self, node_idxs):
return self.last_update_t[node_idxs]
def detach_memory(self):
"""
Disconnect the memory from the computation graph to prevent gradients from being
propagated multiple times
"""
self.memory.detach_()
class MemoryOperation(nn.Module):
"""Memory update using message passing manner, update memory based on positive
pair graph of each batch with recurrent module GRU or RNN
Message function
..math::
m_i(t) = concat(memory_i(t^-),TimeEncode(t),v_i(t))
v_i is node feature at current time stamp
Aggregation function
..math::
\bar{m}_i(t) = last(m_i(t_1),...,m_i(t_b))
Update function
..math::
memory_i(t) = GRU(\bar{m}_i(t),memory_i(t-1))
Parameters
----------
updater_type : str
indicator string to specify updater
'rnn' : use Vanilla RNN as updater
'gru' : use GRU as updater
memory : MemoryModule
memory content for update
e_feat_dim : int
dimension of edge feature
temporal_encoder : torch.nn.Module
time encoder that computes the Fourier time encoding
Example
----------
Please refer to examples/pytorch/tgn/tgn.py
"""
def __init__(self, updater_type, memory, e_feat_dim, temporal_encoder):
super(MemoryOperation, self).__init__()
updater_dict = {"gru": nn.GRUCell, "rnn": nn.RNNCell}
self.memory = memory
memory_dim = self.memory.hidden_dim
self.temporal_encoder = temporal_encoder
self.message_dim = (
memory_dim
+ memory_dim
+ e_feat_dim
+ self.temporal_encoder.dimension
)
self.updater = updater_dict[updater_type](
input_size=self.message_dim, hidden_size=memory_dim
)
self.memory = memory
# Here assume g is a subgraph from each iteration
def stick_feat_to_graph(self, g):
# How can I ensure order of the node ID
g.ndata["timestamp"] = self.memory.last_update_t[g.ndata[dgl.NID]]
g.ndata["memory"] = self.memory.memory[g.ndata[dgl.NID]]
def msg_fn_cat(self, edges):
src_delta_time = edges.data["timestamp"] - edges.src["timestamp"]
time_encode = self.temporal_encoder(
src_delta_time.unsqueeze(dim=1)
).view(len(edges.data["timestamp"]), -1)
ret = torch.cat(
[
edges.src["memory"],
edges.dst["memory"],
edges.data["feats"],
time_encode,
],
dim=1,
)
return {"message": ret, "timestamp": edges.data["timestamp"]}
def agg_last(self, nodes):
timestamp, latest_idx = torch.max(nodes.mailbox["timestamp"], dim=1)
ret = (
nodes.mailbox["message"]
.gather(
1,
latest_idx.repeat(self.message_dim).view(
-1, 1, self.message_dim
),
)
.view(-1, self.message_dim)
)
return {
"message_bar": ret.reshape(-1, self.message_dim),
"timestamp": timestamp,
}
def update_memory(self, nodes):
# It should pass the feature through RNN
ret = self.updater(
nodes.data["message_bar"].float(), nodes.data["memory"].float()
)
return {"memory": ret}
def forward(self, g):
self.stick_feat_to_graph(g)
g.update_all(self.msg_fn_cat, self.agg_last, self.update_memory)
return g
class EdgeGATConv(nn.Module):
"""Edge Graph attention compute the graph attention from node and edge feature then aggregate both node and
edge feature.
Parameter
==========
node_feats : int
number of node features
edge_feats : int
number of edge features
out_feats : int
number of output features
num_heads : int
number of heads in multihead attention
feat_drop : float, optional
dropout rate on the features
attn_drop : float, optional
dropout rate on the attention weights
negative_slope : float, optional
LeakyReLU angle of negative slope.
residual : bool, optional
whether to use a residual connection
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``.
"""
def __init__(
self,
node_feats,
edge_feats,
out_feats,
num_heads,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
):
super(EdgeGATConv, self).__init__()
self._num_heads = num_heads
self._node_feats = node_feats
self._edge_feats = edge_feats
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self.fc_node = nn.Linear(
self._node_feats, self._out_feats * self._num_heads
)
self.fc_edge = nn.Linear(
self._edge_feats, self._out_feats * self._num_heads
)
self.attn_l = nn.Parameter(
torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.attn_r = nn.Parameter(
torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.attn_e = nn.Parameter(
torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope)
self.residual = residual
if residual:
if self._node_feats != self._out_feats:
self.res_fc = nn.Linear(
self._node_feats,
self._out_feats * self._num_heads,
bias=False,
)
else:
self.res_fc = Identity()
self.reset_parameters()
self.activation = activation
def reset_parameters(self):
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.fc_node.weight, gain=gain)
nn.init.xavier_normal_(self.fc_edge.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain)
nn.init.xavier_normal_(self.attn_e, gain=gain)
if self.residual and isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def msg_fn(self, edges):
ret = (
edges.data["a"].view(-1, self._num_heads, 1)
* edges.data["el_prime"]
)
return {"m": ret}
def forward(self, graph, nfeat, efeat, get_attention=False):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
nfeat = self.feat_drop(nfeat)
efeat = self.feat_drop(efeat)
node_feat = self.fc_node(nfeat).view(
-1, self._num_heads, self._out_feats
)
edge_feat = self.fc_edge(efeat).view(
-1, self._num_heads, self._out_feats
)
el = (node_feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (node_feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
ee = (edge_feat * self.attn_e).sum(dim=-1).unsqueeze(-1)
graph.ndata["ft"] = node_feat
graph.ndata["el"] = el
graph.ndata["er"] = er
graph.edata["ee"] = ee
graph.apply_edges(fn.u_add_e("el", "ee", "el_prime"))
graph.apply_edges(fn.e_add_v("el_prime", "er", "e"))
e = self.leaky_relu(graph.edata["e"])
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
graph.edata["efeat"] = edge_feat
graph.update_all(self.msg_fn, fn.sum("m", "ft"))
rst = graph.ndata["ft"]
if self.residual:
resval = self.res_fc(nfeat).view(
nfeat.shape[0], -1, self._out_feats
)
rst = rst + resval
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata["a"]
else:
return rst
class TemporalEdgePreprocess(nn.Module):
"""Preprocess layer, which finish time encoding and concatenate
the time encoding to edge feature.
Parameter
==========
edge_feats : int
number of orginal edge feature
temporal_encoder : torch.nn.Module
time encoder model
"""
def __init__(self, edge_feats, temporal_encoder):
super(TemporalEdgePreprocess, self).__init__()
self.edge_feats = edge_feats
self.temporal_encoder = temporal_encoder
def edge_fn(self, edges):
t0 = torch.zeros_like(edges.dst["timestamp"])
time_diff = edges.data["timestamp"] - edges.src["timestamp"]
time_encode = self.temporal_encoder(time_diff.unsqueeze(dim=1)).view(
t0.shape[0], -1
)
edge_feat = torch.cat([edges.data["feats"], time_encode], dim=1)
return {"efeat": edge_feat}
def forward(self, graph):
graph.apply_edges(self.edge_fn)
efeat = graph.edata["efeat"]
return efeat
class TemporalTransformerConv(nn.Module):
def __init__(
self,
edge_feats,
memory_feats,
temporal_encoder,
out_feats,
num_heads,
allow_zero_in_degree=False,
layers=1,
):
"""Temporal Transformer model for TGN and TGAT
Parameter
==========
edge_feats : int
number of edge features
memory_feats : int
dimension of memory vector
temporal_encoder : torch.nn.Module
compute fourier time encoding
out_feats : int
number of out features
num_heads : int
number of attention heads
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``.
"""
super(TemporalTransformerConv, self).__init__()
self._edge_feats = edge_feats
self._memory_feats = memory_feats
self.temporal_encoder = temporal_encoder
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self._num_heads = num_heads
self.layers = layers
self.preprocessor = TemporalEdgePreprocess(
self._edge_feats, self.temporal_encoder
)
self.layer_list = nn.ModuleList()
self.layer_list.append(
EdgeGATConv(
node_feats=self._memory_feats,
edge_feats=self._edge_feats + self.temporal_encoder.dimension,
out_feats=self._out_feats,
num_heads=self._num_heads,
feat_drop=0.6,
attn_drop=0.6,
residual=True,
allow_zero_in_degree=allow_zero_in_degree,
)
)
for i in range(self.layers - 1):
self.layer_list.append(
EdgeGATConv(
node_feats=self._out_feats * self._num_heads,
edge_feats=self._edge_feats
+ self.temporal_encoder.dimension,
out_feats=self._out_feats,
num_heads=self._num_heads,
feat_drop=0.6,
attn_drop=0.6,
residual=True,
allow_zero_in_degree=allow_zero_in_degree,
)
)
def forward(self, graph, memory, ts):
graph = graph.local_var()
graph.ndata["timestamp"] = ts
efeat = self.preprocessor(graph).float()
rst = memory
for i in range(self.layers - 1):
rst = self.layer_list[i](graph, rst, efeat).flatten(1)
rst = self.layer_list[-1](graph, rst, efeat).mean(1)
return rst
import copy
import torch.nn as nn
from modules import (
MemoryModule,
MemoryOperation,
MsgLinkPredictor,
TemporalTransformerConv,
TimeEncode,
)
import dgl
class TGN(nn.Module):
def __init__(
self,
edge_feat_dim,
memory_dim,
temporal_dim,
embedding_dim,
num_heads,
num_nodes,
n_neighbors=10,
memory_updater_type="gru",
layers=1,
):
super(TGN, self).__init__()
self.memory_dim = memory_dim
self.edge_feat_dim = edge_feat_dim
self.temporal_dim = temporal_dim
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.n_neighbors = n_neighbors
self.memory_updater_type = memory_updater_type
self.num_nodes = num_nodes
self.layers = layers
self.temporal_encoder = TimeEncode(self.temporal_dim)
self.memory = MemoryModule(self.num_nodes, self.memory_dim)
self.memory_ops = MemoryOperation(
self.memory_updater_type,
self.memory,
self.edge_feat_dim,
self.temporal_encoder,
)
self.embedding_attn = TemporalTransformerConv(
self.edge_feat_dim,
self.memory_dim,
self.temporal_encoder,
self.embedding_dim,
self.num_heads,
layers=self.layers,
allow_zero_in_degree=True,
)
self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)
def embed(self, postive_graph, negative_graph, blocks):
emb_graph = blocks[0]
emb_memory = self.memory.memory[emb_graph.ndata[dgl.NID], :]
emb_t = emb_graph.ndata["timestamp"]
embedding = self.embedding_attn(emb_graph, emb_memory, emb_t)
emb2pred = dict(
zip(emb_graph.ndata[dgl.NID].tolist(), emb_graph.nodes().tolist())
)
# The positive graph and negative graph share the same ID mapping
feat_id = [emb2pred[int(n)] for n in postive_graph.ndata[dgl.NID]]
feat = embedding[feat_id]
pred_pos, pred_neg = self.msg_linkpredictor(
feat, postive_graph, negative_graph
)
return pred_pos, pred_neg
def update_memory(self, subg):
new_g = self.memory_ops(subg)
self.memory.set_memory(new_g.ndata[dgl.NID], new_g.ndata["memory"])
self.memory.set_last_update_t(
new_g.ndata[dgl.NID], new_g.ndata["timestamp"]
)
# Some memory operation wrappers
def detach_memory(self):
self.memory.detach_memory()
def reset_memory(self):
self.memory.reset_memory()
def store_memory(self):
memory_checkpoint = {}
memory_checkpoint["memory"] = copy.deepcopy(self.memory.memory)
memory_checkpoint["last_t"] = copy.deepcopy(self.memory.last_update_t)
return memory_checkpoint
def restore_memory(self, memory_checkpoint):
self.memory.memory = memory_checkpoint["memory"]
self.memory.last_update_t = memory_checkpoint["last_t"]
import argparse
import copy
import time
import traceback
import numpy as np
import torch
from data_preprocess import (
TemporalDataset,
TemporalRedditDataset,
TemporalWikipediaDataset,
)
from dataloading import (
FastTemporalEdgeCollator,
FastTemporalSampler,
SimpleTemporalEdgeCollator,
SimpleTemporalSampler,
TemporalEdgeCollator,
TemporalEdgeDataLoader,
TemporalSampler,
)
from sklearn.metrics import average_precision_score, roc_auc_score
from tgn import TGN
import dgl
TRAIN_SPLIT = 0.7
VALID_SPLIT = 0.85
# set random Seed
np.random.seed(2021)
torch.manual_seed(2021)
def train(model, dataloader, sampler, criterion, optimizer, args):
model.train()
total_loss = 0
batch_cnt = 0
last_t = time.time()
for _, positive_pair_g, negative_pair_g, blocks in dataloader:
optimizer.zero_grad()
pred_pos, pred_neg = model.embed(
positive_pair_g, negative_pair_g, blocks
)
loss = criterion(pred_pos, torch.ones_like(pred_pos))
loss += criterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss) * args.batch_size
retain_graph = True if batch_cnt == 0 and not args.fast_mode else False
loss.backward(retain_graph=retain_graph)
optimizer.step()
model.detach_memory()
if not args.not_use_memory:
model.update_memory(positive_pair_g)
if args.fast_mode:
sampler.attach_last_update(model.memory.last_update_t)
print("Batch: ", batch_cnt, "Time: ", time.time() - last_t)
last_t = time.time()
batch_cnt += 1
return total_loss
def test_val(model, dataloader, sampler, criterion, args):
model.eval()
batch_size = args.batch_size
total_loss = 0
aps, aucs = [], []
batch_cnt = 0
with torch.no_grad():
for _, postive_pair_g, negative_pair_g, blocks in dataloader:
pred_pos, pred_neg = model.embed(
postive_pair_g, negative_pair_g, blocks
)
loss = criterion(pred_pos, torch.ones_like(pred_pos))
loss += criterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss) * batch_size
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat(
[torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))],
dim=0,
)
if not args.not_use_memory:
model.update_memory(postive_pair_g)
if args.fast_mode:
sampler.attach_last_update(model.memory.last_update_t)
aps.append(average_precision_score(y_true, y_pred))
aucs.append(roc_auc_score(y_true, y_pred))
batch_cnt += 1
return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--epochs",
type=int,
default=50,
help="epochs for training on entire dataset",
)
parser.add_argument(
"--batch_size", type=int, default=200, help="Size of each batch"
)
parser.add_argument(
"--embedding_dim",
type=int,
default=100,
help="Embedding dim for link prediction",
)
parser.add_argument(
"--memory_dim", type=int, default=100, help="dimension of memory"
)
parser.add_argument(
"--temporal_dim",
type=int,
default=100,
help="Temporal dimension for time encoding",
)
parser.add_argument(
"--memory_updater",
type=str,
default="gru",
help="Recurrent unit for memory update",
)
parser.add_argument(
"--aggregator",
type=str,
default="last",
help="Aggregation method for memory update",
)
parser.add_argument(
"--n_neighbors",
type=int,
default=10,
help="number of neighbors while doing embedding",
)
parser.add_argument(
"--sampling_method",
type=str,
default="topk",
help="In embedding how node aggregate from its neighor",
)
parser.add_argument(
"--num_heads",
type=int,
default=8,
help="Number of heads for multihead attention mechanism",
)
parser.add_argument(
"--fast_mode",
action="store_true",
default=False,
help="Fast Mode uses batch temporal sampling, history within same batch cannot be obtained",
)
parser.add_argument(
"--simple_mode",
action="store_true",
default=False,
help="Simple Mode directly delete the temporal edges from the original static graph",
)
parser.add_argument(
"--num_negative_samples",
type=int,
default=1,
help="number of negative samplers per positive samples",
)
parser.add_argument(
"--dataset",
type=str,
default="wikipedia",
help="dataset selection wikipedia/reddit",
)
parser.add_argument(
"--k_hop", type=int, default=1, help="sampling k-hop neighborhood"
)
parser.add_argument(
"--not_use_memory",
action="store_true",
default=False,
help="Enable memory for TGN Model disable memory for TGN Model",
)
args = parser.parse_args()
assert not (
args.fast_mode and args.simple_mode
), "you can only choose one sampling mode"
if args.k_hop != 1:
assert args.simple_mode, "this k-hop parameter only support simple mode"
if args.dataset == "wikipedia":
data = TemporalWikipediaDataset()
elif args.dataset == "reddit":
data = TemporalRedditDataset()
else:
print("Warning Using Untested Dataset: " + args.dataset)
data = TemporalDataset(args.dataset)
# Pre-process data: mask new nodes in the test set out of the original graph
num_nodes = data.num_nodes()
num_edges = data.num_edges()
trainval_div = int(VALID_SPLIT * num_edges)
# Select new nodes from the test set and remove them from the entire graph
test_split_ts = data.edata["timestamp"][trainval_div]
test_nodes = (
torch.cat(
[data.edges()[0][trainval_div:], data.edges()[1][trainval_div:]]
)
.unique()
.numpy()
)
test_new_nodes = np.random.choice(
test_nodes, int(0.1 * len(test_nodes)), replace=False
)
in_subg = dgl.in_subgraph(data, test_new_nodes)
out_subg = dgl.out_subgraph(data, test_new_nodes)
# Remove edges that happen before the test split to prevent the model from learning the connection info
new_node_in_eid_delete = in_subg.edata[dgl.EID][
in_subg.edata["timestamp"] < test_split_ts
]
new_node_out_eid_delete = out_subg.edata[dgl.EID][
out_subg.edata["timestamp"] < test_split_ts
]
new_node_eid_delete = torch.cat(
[new_node_in_eid_delete, new_node_out_eid_delete]
).unique()
graph_new_node = copy.deepcopy(data)
# relative order preserved
graph_new_node.remove_edges(new_node_eid_delete)
# For the graph without new nodes, all edges incident to the new nodes need to be removed
in_eid_delete = in_subg.edata[dgl.EID]
out_eid_delete = out_subg.edata[dgl.EID]
eid_delete = torch.cat([in_eid_delete, out_eid_delete]).unique()
graph_no_new_node = copy.deepcopy(data)
graph_no_new_node.remove_edges(eid_delete)
# graph_no_new_node and graph_new_node should have the same set of node IDs
# Sampler Initialization
if args.simple_mode:
fan_out = [args.n_neighbors for _ in range(args.k_hop)]
sampler = SimpleTemporalSampler(graph_no_new_node, fan_out)
new_node_sampler = SimpleTemporalSampler(data, fan_out)
edge_collator = SimpleTemporalEdgeCollator
elif args.fast_mode:
sampler = FastTemporalSampler(graph_no_new_node, k=args.n_neighbors)
new_node_sampler = FastTemporalSampler(data, k=args.n_neighbors)
edge_collator = FastTemporalEdgeCollator
else:
sampler = TemporalSampler(k=args.n_neighbors)
edge_collator = TemporalEdgeCollator
neg_sampler = dgl.dataloading.negative_sampler.Uniform(
k=args.num_negative_samples
)
# Set Train, validation, test and new node test id
train_seed = torch.arange(int(TRAIN_SPLIT * graph_no_new_node.num_edges()))
valid_seed = torch.arange(
int(TRAIN_SPLIT * graph_no_new_node.num_edges()),
trainval_div - new_node_eid_delete.size(0),
)
test_seed = torch.arange(
trainval_div - new_node_eid_delete.size(0),
graph_no_new_node.num_edges(),
)
test_new_node_seed = torch.arange(
trainval_div - new_node_eid_delete.size(0), graph_new_node.num_edges()
)
g_sampling = (
None
if args.fast_mode
else dgl.add_reverse_edges(graph_no_new_node, copy_edata=True)
)
new_node_g_sampling = (
None
if args.fast_mode
else dgl.add_reverse_edges(graph_new_node, copy_edata=True)
)
if not args.fast_mode:
new_node_g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes()
g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes()
# We highly recommend always setting num_workers=0; otherwise the sampled subgraph may not be correct.
train_dataloader = TemporalEdgeDataLoader(
graph_no_new_node,
train_seed,
sampler,
batch_size=args.batch_size,
negative_sampler=neg_sampler,
shuffle=False,
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=g_sampling,
)
valid_dataloader = TemporalEdgeDataLoader(
graph_no_new_node,
valid_seed,
sampler,
batch_size=args.batch_size,
negative_sampler=neg_sampler,
shuffle=False,
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=g_sampling,
)
test_dataloader = TemporalEdgeDataLoader(
graph_no_new_node,
test_seed,
sampler,
batch_size=args.batch_size,
negative_sampler=neg_sampler,
shuffle=False,
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=g_sampling,
)
test_new_node_dataloader = TemporalEdgeDataLoader(
graph_new_node,
test_new_node_seed,
new_node_sampler if args.fast_mode else sampler,
batch_size=args.batch_size,
negative_sampler=neg_sampler,
shuffle=False,
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=new_node_g_sampling,
)
edge_dim = data.edata["feats"].shape[1]
num_node = data.num_nodes()
model = TGN(
edge_feat_dim=edge_dim,
memory_dim=args.memory_dim,
temporal_dim=args.temporal_dim,
embedding_dim=args.embedding_dim,
num_heads=args.num_heads,
num_nodes=num_node,
n_neighbors=args.n_neighbors,
memory_updater_type=args.memory_updater,
layers=args.k_hop,
)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Set up a simple file-based logging mechanism
f = open("logging.txt", "w")
if args.fast_mode:
sampler.reset()
try:
for i in range(args.epochs):
train_loss = train(
model, train_dataloader, sampler, criterion, optimizer, args
)
val_ap, val_auc = test_val(
model, valid_dataloader, sampler, criterion, args
)
memory_checkpoint = model.store_memory()
if args.fast_mode:
new_node_sampler.sync(sampler)
test_ap, test_auc = test_val(
model, test_dataloader, sampler, criterion, args
)
model.restore_memory(memory_checkpoint)
if args.fast_mode:
sample_nn = new_node_sampler
else:
sample_nn = sampler
nn_test_ap, nn_test_auc = test_val(
model, test_new_node_dataloader, sample_nn, criterion, args
)
log_content = []
log_content.append(
"Epoch: {}; Training Loss: {} | Validation AP: {:.3f} AUC: {:.3f}\n".format(
i, train_loss, val_ap, val_auc
)
)
log_content.append(
"Epoch: {}; Test AP: {:.3f} AUC: {:.3f}\n".format(
i, test_ap, test_auc
)
)
log_content.append(
"Epoch: {}; Test New Node AP: {:.3f} AUC: {:.3f}\n".format(
i, nn_test_ap, nn_test_auc
)
)
f.writelines(log_content)
model.reset_memory()
if i < args.epochs - 1 and args.fast_mode:
sampler.reset()
print(log_content[0], log_content[1], log_content[2])
except KeyboardInterrupt:
traceback.print_exc()
error_content = "Training Interrupted!"
f.writelines(error_content)
f.close()
print("========Training is Done========")
@@ -9,6 +9,8 @@ and transforming graphs.
# This initializes Winsock and performs cleanup at termination as required
import socket
from distutils.version import LooseVersion
# setup logging before everything
from .logging import enable_verbose_logging
@@ -25,7 +27,6 @@ from . import storages
from . import dataloading
from . import ops
from . import cuda
from . import _dataloading # legacy dataloading modules
from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
"""The ``dgl.dataloading`` package contains:
* Data loader classes for iterating over a set of nodes or edges in a graph and generating
their computation dependencies via neighborhood sampling methods.
* Various sampler classes that perform neighborhood sampling for multi-layer GNNs.
* Negative samplers for link prediction.
For a holistic explanation of how the different components work together, read the user
guide :ref:`guide-minibatch`.
.. note::
This package is experimental and the interfaces may be subject to change
in future releases. It currently only has implementations in PyTorch.
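For illustration, a minimal end-to-end sketch (assuming a homogeneous graph ``g``,
a node ID tensor ``train_nid``, and the PyTorch backend):
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10])
>>> dataloader = dgl.dataloading.NodeDataLoader(
...     g, train_nid, sampler, batch_size=1024, shuffle=True, num_workers=0)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     pass  # compute the outputs of ``output_nodes`` from ``blocks``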
"""
from .neighbor import *
from .dataloader import *
from .cluster_gcn import *
from .shadow import *
from . import negative_sampler
from .. import backend as F
if F.get_preferred_backend() == 'pytorch':
from .pytorch import *
"""Cluster-GCN subgraph iterators."""
import os
import pickle
import numpy as np
from ..transforms import metis_partition_assignment
from .. import backend as F
from .dataloader import SubgraphIterator
class ClusterGCNSubgraphIterator(SubgraphIterator):
"""Subgraph sampler following that of ClusterGCN.
This sampler first partitions the graph with METIS partitioning, then it caches the nodes of
each partition to a file within the given cache directory.
This is used in conjunction with :class:`dgl.dataloading.pytorch.GraphDataLoader`.
Notes
-----
The graph must be homogeneous and on CPU.
Parameters
----------
g : DGLGraph
The original graph.
num_partitions : int
The number of partitions.
cache_directory : str
The path to the cache directory for storing the partition result.
refresh : bool
If True, recompute the partition.
Examples
--------
Assuming that you have a graph ``g``:
>>> sgiter = dgl.dataloading.ClusterGCNSubgraphIterator(
... g, num_partitions=100, cache_directory='.', refresh=True)
>>> dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=0)
>>> for subgraph_batch in dataloader:
... train_on(subgraph_batch)
"""
def __init__(self, g, num_partitions, cache_directory, refresh=False):
if os.name == 'nt':
raise NotImplementedError("METIS partitioning is not supported on Windows yet.")
super().__init__(g)
# First see if the cache is already there. If so, directly read from cache.
if not refresh and self._load_parts(cache_directory):
return
# Otherwise, build the cache.
assignment = F.asnumpy(metis_partition_assignment(g, num_partitions))
self._save_parts(assignment, cache_directory)
def _cache_file_path(self, cache_directory):
return os.path.join(cache_directory, 'cluster_gcn_cache')
def _load_parts(self, cache_directory):
path = self._cache_file_path(cache_directory)
if not os.path.exists(path):
return False
with open(path, 'rb') as file_:
self.part_indptr, self.part_indices = pickle.load(file_)
return True
def _save_parts(self, assignment, cache_directory):
os.makedirs(cache_directory, exist_ok=True)
self.part_indices = np.argsort(assignment)
num_nodes_per_part = np.bincount(assignment)
self.part_indptr = np.insert(np.cumsum(num_nodes_per_part), 0, 0)
with open(self._cache_file_path(cache_directory), 'wb') as file_:
pickle.dump((self.part_indptr, self.part_indices), file_)
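# For illustration, a hypothetical assignment and the resulting index arrays:
#   assignment   = [1, 0, 1, 0, 2]       # node i belongs to partition assignment[i]
#   part_indices = argsort(assignment)   # -> [1, 3, 0, 2, 4]
#   part_indptr  = [0, 2, 4, 5]          # partition p holds part_indices[indptr[p]:indptr[p+1]]
# e.g. partition 1 holds nodes part_indices[2:4] == [0, 2]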
def __len__(self):
return self.part_indptr.shape[0] - 1
def __getitem__(self, i):
nodes = self.part_indices[self.part_indptr[i]:self.part_indptr[i+1]]
return self.g.subgraph(nodes)
"""Data loaders"""
from collections.abc import Mapping, Sequence
from abc import ABC, abstractproperty, abstractmethod
import re
import numpy as np
from .. import transforms
from ..base import NID, EID
from .. import backend as F
from .. import utils
from ..batch import batch
from ..convert import heterograph
from ..heterograph import DGLGraph
from ..distributed.dist_graph import DistGraph
from ..utils import to_device
def _tensor_or_dict_to_numpy(ids):
if isinstance(ids, Mapping):
return {k: F.zerocopy_to_numpy(v) for k, v in ids.items()}
else:
return F.zerocopy_to_numpy(ids)
def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids):
"""Find the edges whose IDs in parent graph appeared in exclude_eids.
Note that both arguments are numpy arrays or numpy dicts.
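For illustration (homogeneous case, plain numpy inputs, PyTorch backend assumed):
>>> _locate_eids_to_exclude(np.array([10, 11, 12, 13]), np.array([11, 13]))
tensor([1, 3])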
"""
if isinstance(frontier_parent_eids, Mapping):
result = {
k: np.isin(frontier_parent_eids[k], exclude_eids[k]).nonzero()[0]
for k in frontier_parent_eids.keys() if k in exclude_eids.keys()}
return {k: F.zerocopy_from_numpy(v) for k, v in result.items()}
else:
result = np.isin(frontier_parent_eids, exclude_eids).nonzero()[0]
return F.zerocopy_from_numpy(result)
class _EidExcluder():
def __init__(self, exclude_eids):
device = None
if isinstance(exclude_eids, Mapping):
for _, v in exclude_eids.items():
if device is None:
device = F.context(v)
break
else:
device = F.context(exclude_eids)
self._exclude_eids = None
self._filter = None
if device == F.cpu():
# TODO(nv-dlasalle): Once Filter is implemented for the CPU, we
# should just use that regardless of the device.
self._exclude_eids = (
_tensor_or_dict_to_numpy(exclude_eids) if exclude_eids is not None else None)
else:
if isinstance(exclude_eids, Mapping):
self._filter = {k: utils.Filter(v) for k, v in exclude_eids.items()}
else:
self._filter = utils.Filter(exclude_eids)
def _find_indices(self, parent_eids):
""" Find the set of edge indices to remove.
"""
if self._exclude_eids is not None:
parent_eids_np = _tensor_or_dict_to_numpy(parent_eids)
return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids)
else:
assert self._filter is not None
if isinstance(parent_eids, Mapping):
located_eids = {k: self._filter[k].find_included_indices(parent_eids[k])
for k, v in parent_eids.items() if k in self._filter}
else:
located_eids = self._filter.find_included_indices(parent_eids)
return located_eids
def __call__(self, frontier):
parent_eids = frontier.edata[EID]
located_eids = self._find_indices(parent_eids)
if not isinstance(located_eids, Mapping):
# (BarclayII) If frontier already has an EID field and located_eids is empty,
# the returned graph will keep EID intact. Otherwise, EID will change
# to the mapping from the new graph to the old frontier.
# So we need to test if located_eids is empty, and do the remapping ourselves.
if len(located_eids) > 0:
frontier = transforms.remove_edges(
frontier, located_eids, store_ids=True)
frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID])
else:
# (BarclayII) remove_edges only accepts removing one type of edges,
# so I need to keep track of the edge IDs left one by one.
new_eids = parent_eids.copy()
for k, v in located_eids.items():
if len(v) > 0:
frontier = transforms.remove_edges(
frontier, v, etype=k, store_ids=True)
new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID])
frontier.edata[EID] = new_eids
return frontier
def exclude_edges(subg, exclude_eids, device):
"""Find and remove from the subgraph the edges whose IDs in the parent
graph are given.
Parameters
----------
subg : DGLGraph
The subgraph. Must have ``dgl.EID`` field containing the original
edge IDs in the parent graph.
exclude_eids : Tensor or dict
The edge IDs to exclude.
device : device
The output device of the graph.
Returns
-------
DGLGraph
The new subgraph with edges removed. The ``dgl.EID`` field contains
the original edge IDs in the same parent graph.
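For illustration, a minimal sketch on a homogeneous graph ``g`` with at least three
edges (PyTorch backend assumed; ``dgl.edge_subgraph`` stores the required ``dgl.EID``):
>>> sg = dgl.edge_subgraph(g, torch.tensor([0, 1, 2]))
>>> sg = exclude_edges(sg, torch.tensor([1]), None)
>>> sg.edata[dgl.EID]
tensor([0, 2])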
"""
if exclude_eids is None:
return subg
if device is not None:
if isinstance(exclude_eids, Mapping):
exclude_eids = {k: F.copy_to(v, device) \
for k, v in exclude_eids.items()}
else:
exclude_eids = F.copy_to(exclude_eids, device)
excluder = _EidExcluder(exclude_eids)
return subg if excluder is None else excluder(subg)
def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
if isinstance(eids, Mapping):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
exclude_eids = {
k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
for k, v in eids.items()}
else:
exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
return exclude_eids
def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
reverse_etype_map = {
g.to_canonical_etype(k): g.to_canonical_etype(v)
for k, v in reverse_etype_map.items()}
exclude_eids.update({reverse_etype_map[k]: v for k, v in exclude_eids.items()})
return exclude_eids
def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
"""Find all edge IDs to exclude according to :attr:`exclude_mode`.
Parameters
----------
g : DGLGraph
The graph.
exclude_mode : str, optional
Can be either of the following,
None (default)
Does not exclude any edge.
'self'
Exclude the given edges themselves but nothing else.
'reverse_id'
Exclude all edges specified in ``eids``, as well as their reverse edges
of the same edge type.
The mapping from each edge ID to its reverse edge ID is specified in
the keyword argument ``reverse_eid_map``.
This mode assumes that the reverse of an edge with ID ``e`` and type
``etype`` will have ID ``reverse_eid_map[e]`` and type ``etype``.
'reverse_types'
Exclude all edges specified in ``eids``, as well as their reverse
edges of the corresponding edge types.
The mapping from each edge type to its reverse edge type is specified
in the keyword argument ``reverse_etype_map``.
This mode assumes that the reverse of an edge with ID ``e`` and type ``etype``
will have ID ``e`` and type ``reverse_etype_map[etype]``.
eids : Tensor or dict[etype, Tensor]
The edge IDs.
reverse_eid_map : Tensor or dict[etype, Tensor]
The mapping from edge ID to its reverse edge ID.
reverse_etype_map : dict[etype, etype]
The mapping from edge etype to its reverse edge type.
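For illustration, a sketch of the ``'reverse_id'`` mode on a homogeneous graph ``g``
with six edges, where edge ``i`` and edge ``i + 3`` are reverses of each other:
>>> reverse_eid_map = torch.cat([torch.arange(3, 6), torch.arange(0, 3)])
>>> _find_exclude_eids(g, 'reverse_id', torch.tensor([0, 2]),
...                    reverse_eid_map=reverse_eid_map)
tensor([0, 2, 3, 5])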
"""
if exclude_mode is None:
return None
elif exclude_mode == 'self':
return eids
elif exclude_mode == 'reverse_id':
return _find_exclude_eids_with_reverse_id(g, eids, kwargs['reverse_eid_map'])
elif exclude_mode == 'reverse_types':
return _find_exclude_eids_with_reverse_types(g, eids, kwargs['reverse_etype_map'])
else:
raise ValueError('unsupported mode {}'.format(exclude_mode))
class Sampler(object):
"""An abstract class that takes in a graph and a set of seed nodes and returns a
structure representing a smaller portion of the graph for computation. It can
be either a list of bipartite graphs (i.e. :class:`BlockSampler`), or a single
subgraph.
"""
def __init__(self, output_ctx=None):
self.set_output_context(output_ctx)
def sample(self, g, seed_nodes, exclude_eids=None):
"""Sample a structure from the graph.
Parameters
----------
g : DGLGraph
The original graph.
seed_nodes : Tensor or dict[ntype, Tensor]
The destination nodes by type.
If the graph only has one node type, one can just specify a single tensor
of node IDs.
exclude_eids : Tensor or dict[etype, Tensor]
The edges to exclude from computation dependency.
Returns
-------
Tensor or dict[ntype, Tensor]
The nodes whose input features are required for computing the output
representation of :attr:`seed_nodes`.
any
Any data representing the structure.
"""
raise NotImplementedError
def set_output_context(self, ctx):
"""Set the device the generated block or subgraph will be output to.
This should only be set to a CUDA device when multiprocessing is not
used in the dataloader (e.g., num_workers is 0).
Parameters
----------
ctx : DGLContext, default None
The device context the sampled blocks will be stored on. This
should only be a CUDA context if multiprocessing is not used in
the dataloader (e.g., num_workers is 0). If this is None, the
sampled blocks will be stored on the same device as the input
graph.
"""
if ctx is not None:
self.output_device = F.to_backend_ctx(ctx)
else:
self.output_device = None
class BlockSampler(Sampler):
"""Abstract class specifying the neighborhood sampling strategy for DGL data loaders.
The main method for BlockSampler is :meth:`sample`,
which generates a list of message flow graphs (MFGs) for a multi-layer GNN given a set of
seed nodes to have their outputs computed.
The default implementation of :meth:`sample` is
to repeat :attr:`num_layers` times the following procedure from the last layer to the first
layer:
* Obtain a frontier. The frontier is defined as a graph with the same nodes as the
original graph but only the edges involved in message passing on the current layer.
Customizable via :meth:`sample_frontier`.
* Optionally, if the task is link prediction or edge classification, remove edges
connecting training node pairs. If the graph is undirected, also remove the
reverse edges. This is controlled by the argument :attr:`exclude_eids` in
:meth:`sample` method.
* Convert the frontier into a MFG.
* Optionally assign the IDs of the edges in the original graph selected in the first step
to the MFG, controlled by the argument ``return_eids`` in
:meth:`sample` method.
* Prepend the MFG to the MFG list to be returned.
All subclasses should override :meth:`sample_frontier`
method while specifying the number of layers to sample in :attr:`num_layers` argument.
Parameters
----------
num_layers : int
The number of layers to sample.
return_eids : bool, default False
Whether to return the edge IDs involved in message passing in the MFG.
If True, the edge IDs will be stored as an edge feature named ``dgl.EID``.
output_ctx : DGLContext, default None
The context the sampled blocks will be stored on. This should only be
a CUDA context if multiprocessing is not used in the dataloader (e.g.,
num_workers is 0). If this is None, the sampled blocks will be stored
on the same device as the input graph.
exclude_edges_in_frontier : bool, default False
If True, the :func:`sample_frontier` method will receive an argument
:attr:`exclude_eids` containing the edge IDs from the original graph to exclude.
The :func:`sample_frontier` method must return a graph that does not contain
the edges corresponding to the excluded edges. No additional postprocessing
will be done.
Otherwise, the edges will be removed *after* :func:`sample_frontier` returns.
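For illustration, a minimal (hypothetical) subclass that takes the full in-neighborhood
as the frontier at every layer, assuming a homogeneous graph ``g`` and a seed node
tensor ``seed_nodes``:
>>> class FullFrontierSampler(dgl.dataloading.BlockSampler):
...     def sample_frontier(self, block_id, g, seed_nodes, exclude_eids=None):
...         return dgl.in_subgraph(g, seed_nodes)
>>> sampler = FullFrontierSampler(num_layers=2)
>>> input_nodes, output_nodes, blocks = sampler.sample(g, seed_nodes)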
Notes
-----
For the concept of frontiers and MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, num_layers, return_eids=False, output_ctx=None):
super().__init__(output_ctx)
self.num_layers = num_layers
self.return_eids = return_eids
# pylint: disable=unused-argument
@staticmethod
def assign_block_eids(block, frontier):
"""Assigns edge IDs from the original graph to the message flow graph (MFG).
See also
--------
BlockSampler
"""
for etype in block.canonical_etypes:
block.edges[etype].data[EID] = frontier.edges[etype].data[EID][
block.edges[etype].data[EID]]
return block
# This is really a hack working around the lack of GPU-based neighbor sampling
# with edge exclusion.
@classmethod
def exclude_edges_in_frontier(cls, g):
"""Returns whether the sampler will exclude edges in :func:`sample_frontier`.
If this method returns True, the method :func:`sample_frontier` will receive an
argument :attr:`exclude_eids` from :func:`sample`. :func:`sample_frontier`
is then responsible for removing those edges.
If this method returns False, :func:`sample` will be responsible for
removing the edges.
When subclassing :class:`BlockSampler`, this method should return True when you
would like to remove the excluded edges in your :func:`sample_frontier` method.
By default this method returns False.
Parameters
----------
g : DGLGraph
The original graph
Returns
-------
bool
Whether :func:`sample_frontier` will receive an argument :attr:`exclude_eids`.
"""
return False
def sample_frontier(self, block_id, g, seed_nodes, exclude_eids=None):
"""Generate the frontier given the destination nodes.
The subclasses should override this function.
Parameters
----------
block_id : int
Represents which GNN layer the frontier is generated for.
g : DGLGraph
The original graph.
seed_nodes : Tensor or dict[ntype, Tensor]
The destination nodes by node type.
If the graph only has one node type, one can just specify a single tensor
of node IDs.
exclude_eids: Tensor or dict
Edge IDs to exclude during sampling neighbors for the seed nodes.
This argument can take a single ID tensor or a dictionary of edge types and ID tensors.
If a single tensor is given, the graph must only have one type of nodes.
Returns
-------
DGLGraph
The frontier generated for the current layer.
Notes
-----
For the concept of frontiers and MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
raise NotImplementedError
def sample(self, g, seed_nodes, exclude_eids=None):
"""Generate the a list of MFGs given the destination nodes.
Parameters
----------
g : DGLGraph
The original graph.
seed_nodes : Tensor or dict[ntype, Tensor]
The destination nodes by node type.
If the graph only has one node type, one can just specify a single tensor
of node IDs.
exclude_eids : Tensor or dict[etype, Tensor]
The edges to exclude from computation dependency.
Returns
-------
list[DGLGraph]
The MFGs generated for computing the multi-layer GNN output.
Notes
-----
For the concept of frontiers and MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
blocks = []
if isinstance(g, DistGraph):
# TODO:(nv-dlasalle) dist graphs may not have an associated graph,
# causing an error when trying to fetch the device, so for now,
# always assume the distributed graph's device is CPU.
graph_device = F.cpu()
else:
graph_device = g.device
for block_id in reversed(range(self.num_layers)):
seed_nodes_in = to_device(seed_nodes, graph_device)
if self.exclude_edges_in_frontier(g):
frontier = self.sample_frontier(
block_id, g, seed_nodes_in, exclude_eids=exclude_eids)
else:
frontier = self.sample_frontier(block_id, g, seed_nodes_in)
if self.output_device is not None:
frontier = frontier.to(self.output_device)
seed_nodes_out = to_device(seed_nodes, self.output_device)
else:
seed_nodes_out = seed_nodes
# Removing edges from the frontier for link prediction training falls
# into the category of frontier postprocessing
if not self.exclude_edges_in_frontier(g):
frontier = exclude_edges(frontier, exclude_eids, self.output_device)
block = transforms.to_block(frontier, seed_nodes_out)
if self.return_eids:
self.assign_block_eids(block, frontier)
seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes}
blocks.insert(0, block)
return blocks[0].srcdata[NID], blocks[-1].dstdata[NID], blocks
def sample_blocks(self, g, seed_nodes, exclude_eids=None):
"""Deprecated and identical to :meth:`sample`.
"""
return self.sample(g, seed_nodes, exclude_eids)
class Collator(ABC):
"""Abstract DGL collator for training GNNs on downstream tasks stochastically.
Provides a :attr:`dataset` object containing the collection of all nodes or edges,
as well as a :attr:`collate` method that combines a set of items from
:attr:`dataset` and obtains the message flow graphs (MFGs).
Notes
-----
For the concept of MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
@abstractproperty
def dataset(self):
"""Returns the dataset object of the collator."""
raise NotImplementedError
@abstractmethod
def collate(self, items):
"""Combines the items from the dataset object and obtains the list of MFGs.
Parameters
----------
items : list[str, int]
The list of node or edge IDs or type-ID pairs.
Notes
-----
For the concept of MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
raise NotImplementedError
class NodeCollator(Collator):
"""DGL collator to combine nodes and their computation dependencies within a minibatch for
training node classification or regression on a single graph with neighborhood sampling.
Parameters
----------
g : DGLGraph
The graph.
nids : Tensor or dict[ntype, Tensor]
The node set to compute outputs.
graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
Examples
--------
To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
a homogeneous graph where each node takes messages from all neighbors (assume
the backend is PyTorch):
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> collator = dgl.dataloading.NodeCollator(g, train_nid, sampler)
>>> dataloader = torch.utils.data.DataLoader(
... collator.dataset, collate_fn=collator.collate,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
... train_on(input_nodes, output_nodes, blocks)
Notes
-----
For the concept of MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, g, nids, graph_sampler):
self.g = g
if not isinstance(nids, Mapping):
assert len(g.ntypes) == 1, \
"nids should be a dict of node type and ids for graph with multiple node types"
self.graph_sampler = graph_sampler
self.nids = utils.prepare_tensor_or_dict(g, nids, 'nids')
self._dataset = utils.maybe_flatten_dict(self.nids)
@property
def dataset(self):
return self._dataset
def collate(self, items):
"""Find the list of MFGs necessary for computing the representation of given
nodes for a node classification/regression task.
Parameters
----------
items : list[int] or list[tuple[str, int]]
Either a list of node IDs (for homogeneous graphs), or a list of node type-ID
pairs (for heterogeneous graphs).
Returns
-------
input_nodes : Tensor or dict[ntype, Tensor]
The input nodes necessary for computation in this minibatch.
If the original graph has multiple node types, return a dictionary of
node type names and node ID tensors. Otherwise, return a single tensor.
output_nodes : Tensor or dict[ntype, Tensor]
The nodes whose representations are to be computed in this minibatch.
If the original graph has multiple node types, return a dictionary of
node type names and node ID tensors. Otherwise, return a single tensor.
MFGs : list[DGLGraph]
The list of MFGs necessary for computing the representation.
"""
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
items = utils.prepare_tensor_or_dict(self.g, items, 'items')
input_nodes, output_nodes, blocks = self.graph_sampler.sample_blocks(self.g, items)
return input_nodes, output_nodes, blocks
class EdgeCollator(Collator):
"""DGL collator to combine edges and their computation dependencies within a minibatch for
training edge classification, edge regression, or link prediction on a single graph
with neighborhood sampling.
Given a set of edges, the collate function will yield
* A tensor of input nodes necessary for computing the representation on edges, or
a dictionary of node type names and such tensors.
* A subgraph that contains only the edges in the minibatch and their incident nodes.
Note that the graph has an identical metagraph with the original graph.
* If a negative sampler is given, another graph that contains the "negative edges",
connecting the source and destination nodes yielded from the given negative sampler.
* A list of MFGs necessary for computing the representation of the incident nodes
of the edges in the minibatch.
Parameters
----------
g : DGLGraph
The graph from which the edges are iterated in minibatches and the subgraphs
are generated.
eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs.
graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
g_sampling : DGLGraph, optional
The graph where neighborhood sampling and message passing is performed.
Note that this is not necessarily the same as :attr:`g`.
If None, assume to be the same as :attr:`g`.
exclude : str, optional
Whether and how to exclude dependencies related to the sampled edges in the
minibatch. Possible values are
* None, which excludes nothing.
* ``'self'``, which excludes the sampled edges themselves but nothing else.
* ``'reverse_id'``, which excludes the reverse edges of the sampled edges. The said
reverse edges have the same edge type as the sampled edges. Only works
on edge types whose source node type is the same as its destination node type.
* ``'reverse_types'``, which excludes the reverse edges of the sampled edges. The
said reverse edges have different edge types from the sampled edges.
If ``g_sampling`` is given, ``exclude`` is ignored and will be always ``None``.
reverse_eids : Tensor or dict[etype, Tensor], optional
A tensor of reverse edge ID mapping. The i-th element indicates the ID of
the i-th edge's reverse edge.
If the graph is heterogeneous, this argument requires a dictionary of edge
types and the reverse edge ID mapping tensors.
Required and only used when ``exclude`` is set to ``reverse_id``.
For heterogeneous graphs, this will be a dict of edge types and edge ID tensors. Note that
only the edge types whose source node type is the same as their destination node type
are needed.
reverse_etypes : dict[etype, etype], optional
The mapping from the edge type to its reverse edge type.
Required and only used when ``exclude`` is set to ``reverse_types``.
negative_sampler : callable, optional
The negative sampler. Can be omitted if no negative sampling is needed.
The negative sampler must be a callable that takes in the following arguments:
* The original (heterogeneous) graph.
* The ID array of sampled edges in the minibatch, or the dictionary of edge
types and ID array of sampled edges in the minibatch if the graph is
heterogeneous.
It should return
* A pair of source and destination node ID arrays as negative samples,
or a dictionary of edge types and such pairs if the graph is heterogeneous.
A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`.
Examples
--------
The following example shows how to train a 3-layer GNN for edge classification on a
set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes
messages from all neighbors.
Say that you have an array of source node IDs ``src`` and another array of destination
node IDs ``dst``. One can make it bidirectional by adding another set of edges
that connects from ``dst`` to ``src``:
>>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))
The ID of an edge and the ID of its reverse edge then differ by ``|E|``,
where ``|E|`` is the length of your source/destination array. The reverse edge
mapping can be obtained by
>>> E = len(src)
>>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])
Note that the sampled edges as well as their reverse edges are removed from
computation dependencies of the incident nodes. This is a common trick to avoid
information leakage.
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> collator = dgl.dataloading.EdgeCollator(
... g, train_eid, sampler, exclude='reverse_id',
... reverse_eids=reverse_eids)
>>> dataloader = torch.utils.data.DataLoader(
... collator.dataset, collate_fn=collator.collate,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, blocks)
To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` on a
homogeneous graph where each node takes messages from all neighbors (assume the
backend is PyTorch), with 5 uniformly chosen negative samples per edge:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> collator = dgl.dataloading.EdgeCollator(
... g, train_eid, sampler, exclude='reverse_id',
... reverse_eids=reverse_eids, negative_sampler=neg_sampler)
>>> dataloader = torch.utils.data.DataLoader(
... collator.dataset, collate_fn=collator.collate,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
... train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
For heterogeneous graphs, the reverse of an edge may have a different edge type
from the original edge. For instance, consider that you have an array of
user-item clicks, represented by a user array ``user`` and an item array ``item``.
You may want to build a heterogeneous graph with a user-click-item relation and an
item-clicked-by-user relation.
>>> g = dgl.heterograph({
... ('user', 'click', 'item'): (user, item),
... ('item', 'clicked-by', 'user'): (item, user)})
To train a 3-layer GNN for edge classification on a set of edges ``train_eid`` with
type ``click``, you can write
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> collator = dgl.dataloading.EdgeCollator(
... g, {'click': train_eid}, sampler, exclude='reverse_types',
... reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'})
>>> dataloader = torch.utils.data.DataLoader(
... collator.dataset, collate_fn=collator.collate,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, blocks)
To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` with type
``click``, you can write
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> collator = dgl.dataloading.EdgeCollator(
... g, train_eid, sampler, exclude='reverse_types',
... reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
... negative_sampler=neg_sampler)
>>> dataloader = torch.utils.data.DataLoader(
... collator.dataset, collate_fn=collator.collate,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
... train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
Notes
-----
For the concept of MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, g, eids, graph_sampler, g_sampling=None, exclude=None,
reverse_eids=None, reverse_etypes=None, negative_sampler=None):
self.g = g
if not isinstance(eids, Mapping):
assert len(g.etypes) == 1, \
"eids should be a dict of etype and ids for graph with multiple etypes"
self.graph_sampler = graph_sampler
# One may wish to iterate over the edges in one graph while performing sampling in
# another graph. This is the case, for instance, when iterating over the validation and
# test edge sets while performing neighborhood sampling on the graph formed by only
# the training edge set.
# See GCMC for an example usage.
if g_sampling is not None:
self.g_sampling = g_sampling
self.exclude = None
else:
self.g_sampling = self.g
self.exclude = exclude
self.reverse_eids = reverse_eids
self.reverse_etypes = reverse_etypes
self.negative_sampler = negative_sampler
self.eids = utils.prepare_tensor_or_dict(g, eids, 'eids')
self._dataset = utils.maybe_flatten_dict(self.eids)
@property
def dataset(self):
return self._dataset
def _collate(self, items):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by edge types into a dict
items = utils.group_as_dict(items)
items = utils.prepare_tensor_or_dict(self.g_sampling, items, 'items')
pair_graph = self.g.edge_subgraph(items)
seed_nodes = pair_graph.ndata[NID]
exclude_eids = _find_exclude_eids(
self.g_sampling,
self.exclude,
items,
reverse_eid_map=self.reverse_eids,
reverse_etype_map=self.reverse_etypes)
input_nodes, _, blocks = self.graph_sampler.sample_blocks(
self.g_sampling, seed_nodes, exclude_eids=exclude_eids)
return input_nodes, pair_graph, blocks
def _collate_with_negative_sampling(self, items):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by edge types into a dict
items = utils.group_as_dict(items)
items = utils.prepare_tensor_or_dict(self.g_sampling, items, 'items')
pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[EID]
neg_srcdst = self.negative_sampler(self.g, items)
if not isinstance(neg_srcdst, Mapping):
assert len(self.g.etypes) == 1, \
'graph has multiple or no edge types; '\
'please return a dict in negative sampler.'
neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst}
# Get dtype from a tuple of tensors
dtype = F.dtype(list(neg_srcdst.values())[0][0])
ctx = F.context(pair_graph)
neg_edges = {
etype: neg_srcdst.get(etype, (F.copy_to(F.tensor([], dtype), ctx),
F.copy_to(F.tensor([], dtype), ctx)))
for etype in self.g.canonical_etypes}
neg_pair_graph = heterograph(
neg_edges, {ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes})
pair_graph, neg_pair_graph = transforms.compact_graphs([pair_graph, neg_pair_graph])
pair_graph.edata[EID] = induced_edges
seed_nodes = pair_graph.ndata[NID]
exclude_eids = _find_exclude_eids(
self.g_sampling,
self.exclude,
items,
reverse_eid_map=self.reverse_eids,
reverse_etype_map=self.reverse_etypes)
input_nodes, _, blocks = self.graph_sampler.sample_blocks(
self.g_sampling, seed_nodes, exclude_eids=exclude_eids)
return input_nodes, pair_graph, neg_pair_graph, blocks
def collate(self, items):
"""Combines the sampled edges into a minibatch for edge classification, edge
regression, and link prediction tasks.
Parameters
----------
items : list[int] or list[tuple[str, int]]
Either a list of edge IDs (for homogeneous graphs), or a list of edge type-ID
pairs (for heterogeneous graphs).
Returns
-------
Either ``(input_nodes, pair_graph, blocks)``, or
``(input_nodes, pair_graph, negative_pair_graph, blocks)`` if negative sampling is
enabled.
input_nodes : Tensor or dict[ntype, Tensor]
The input nodes necessary for computation in this minibatch.
If the original graph has multiple node types, return a dictionary of
node type names and node ID tensors. Otherwise, return a single tensor.
pair_graph : DGLGraph
The graph that contains only the edges in the minibatch as well as their incident
nodes.
Note that the metagraph of this graph will be identical to that of the original
graph.
negative_pair_graph : DGLGraph
The graph that contains only the edges connecting the source and destination nodes
yielded from the given negative sampler, if negative sampling is enabled.
Note that the metagraph of this graph will be identical to that of the original
graph.
blocks : list[DGLGraph]
The list of MFGs necessary for computing the representation of the edges.
"""
if self.negative_sampler is None:
return self._collate(items)
else:
return self._collate_with_negative_sampling(items)
class GraphCollator(object):
"""Given a set of graphs as well as their graph-level data, the collate function will batch the
graphs into a batched graph and stack the tensors into a single bigger tensor. If each
example is a container (such as a sequence or mapping), the collate function preserves
its structure and collates each of its elements recursively.
If the set of graphs has no graph-level data, the collate function will yield a batched graph.
Examples
--------
To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
the backend is PyTorch):
>>> dataloader = dgl.dataloading.GraphDataLoader(
... dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
... train_on(batched_graph, labels)
"""
def __init__(self):
self.graph_collate_err_msg_format = (
"graph_collate: batch must contain DGLGraph, tensors, numpy arrays, "
"numbers, dicts or lists; found {}")
self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
# This implementation is based on torch.utils.data._utils.collate.default_collate
def collate(self, items):
"""This function is similar to ``torch.utils.data._utils.collate.default_collate``.
It combines the sampled graphs and corresponding graph-level data
into a batched graph and tensors.
Parameters
----------
items : list of data points or tuples
Elements in the list are expected to have the same length.
Each sub-element will be batched as a batched graph, or a
batched tensor correspondingly.
Returns
-------
A tuple of the batching results.
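For illustration, a direct call on (graph, label) pairs (hypothetical graphs ``g1`` and
``g2``, PyTorch backend assumed):
>>> g1 = dgl.graph(([0], [1]))
>>> g2 = dgl.graph(([0, 1], [1, 2]))
>>> batched_graph, labels = GraphCollator().collate(
...     [(g1, torch.tensor(0)), (g2, torch.tensor(1))])
>>> labels
tensor([0, 1])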
"""
elem = items[0]
elem_type = type(elem)
if isinstance(elem, DGLGraph):
batched_graphs = batch(items)
return batched_graphs
elif F.is_tensor(elem):
return F.stack(items, 0)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
# array of string classes and object
if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(self.graph_collate_err_msg_format.format(elem.dtype))
return self.collate([F.tensor(b) for b in items])
elif elem.shape == (): # scalars
return F.tensor(items)
elif isinstance(elem, float):
return F.tensor(items, dtype=F.float64)
elif isinstance(elem, int):
return F.tensor(items)
elif isinstance(elem, (str, bytes)):
return items
elif isinstance(elem, Mapping):
return {key: self.collate([d[key] for d in items]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(self.collate(samples) for samples in zip(*items)))
elif isinstance(elem, Sequence):
# check to make sure that the elements in batch have consistent size
item_iter = iter(items)
elem_size = len(next(item_iter))
if not all(len(elem) == elem_size for elem in item_iter):
raise RuntimeError('each element in list of batch should be of equal size')
transposed = zip(*items)
return [self.collate(samples) for samples in transposed]
raise TypeError(self.graph_collate_err_msg_format.format(elem_type))
class SubgraphIterator(object):
"""Abstract class representing an iterator that yields a subgraph given a graph.
"""
def __init__(self, g):
self.g = g
"""Negative samplers"""
from collections.abc import Mapping
from .. import backend as F
from ..sampling import global_uniform_negative_sampling
class _BaseNegativeSampler(object):
def _generate(self, g, eids, canonical_etype):
raise NotImplementedError
def __call__(self, g, eids):
"""Returns negative samples.
Parameters
----------
g : DGLGraph
The graph.
eids : Tensor or dict[etype, Tensor]
The sampled edges in the minibatch.
Returns
-------
tuple[Tensor, Tensor] or dict[etype, tuple[Tensor, Tensor]]
The returned source-destination pairs as negative samples.
"""
if isinstance(eids, Mapping):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
neg_pair = {k: self._generate(g, v, k) for k, v in eids.items()}
else:
assert len(g.etypes) == 1, \
'please specify a dict of etypes and ids for graphs with multiple edge types'
neg_pair = self._generate(g, eids, g.canonical_etypes[0])
return neg_pair
class PerSourceUniform(_BaseNegativeSampler):
"""Negative sampler that randomly chooses negative destination nodes
for each source node according to a uniform distribution.
For each edge ``(u, v)`` of type ``(srctype, etype, dsttype)``, DGL generates
:attr:`k` pairs of negative edges ``(u, v')``, where ``v'`` is chosen
uniformly from all the nodes of type ``dsttype``. The resulting edges will
also have type ``(srctype, etype, dsttype)``.
Parameters
----------
k : int
The number of negative samples per edge.
Examples
--------
>>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))
>>> neg_sampler = dgl.dataloading.negative_sampler.PerSourceUniform(2)
>>> neg_sampler(g, torch.tensor([0, 1]))
(tensor([0, 0, 1, 1]), tensor([1, 0, 2, 3]))
"""
def __init__(self, k):
self.k = k
def _generate(self, g, eids, canonical_etype):
_, _, vtype = canonical_etype
shape = F.shape(eids)
dtype = F.dtype(eids)
ctx = F.context(eids)
shape = (shape[0] * self.k,)
src, _ = g.find_edges(eids, etype=canonical_etype)
src = F.repeat(src, self.k, 0)
dst = F.randint(shape, dtype, ctx, 0, g.number_of_nodes(vtype))
return src, dst
# Alias
Uniform = PerSourceUniform
class GlobalUniform(_BaseNegativeSampler):
"""Negative sampler that randomly chooses negative source-destination pairs according
to a uniform distribution.
For each edge ``(u, v)`` of type ``(srctype, etype, dsttype)``, DGL generates at most
:attr:`k` pairs of negative edges ``(u', v')``, where ``u'`` is chosen uniformly from
all the nodes of type ``srctype`` and ``v'`` is chosen uniformly from all the nodes
of type ``dsttype``. The resulting edges will also have type
``(srctype, etype, dsttype)``. DGL guarantees that the sampled pairs will not have
edges in between.
Parameters
----------
k : int
The desired number of negative samples to generate per edge.
exclude_self_loops : bool, optional
Whether to exclude self-loops from negative samples. (Default: True)
replace : bool, optional
Whether to sample with replacement. Setting it to True will make things
faster. (Default: True)
redundancy : float, optional
Indicates how much more negative samples to actually generate during rejection sampling
before finding the unique pairs.
Increasing it will increase the likelihood of getting :attr:`k` negative samples
per edge, but will also take more time and memory.
(Default: automatically determined by the density of graph)
Notes
-----
This negative sampler will try to generate as many negative samples as possible, but
it may rarely return less than :attr:`k` negative samples per edge.
This is more likely to happen if a graph is so small or dense that not many unique
negative samples exist.
Examples
--------
>>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))
>>> neg_sampler = dgl.dataloading.negative_sampler.GlobalUniform(2, True)
>>> neg_sampler(g, torch.LongTensor([0, 1]))
(tensor([0, 1, 3, 2]), tensor([2, 0, 2, 1]))
"""
def __init__(self, k, exclude_self_loops=True, replace=False, redundancy=None):
self.k = k
self.exclude_self_loops = exclude_self_loops
self.replace = replace
self.redundancy = redundancy
def _generate(self, g, eids, canonical_etype):
return global_uniform_negative_sampling(
g, len(eids) * self.k, self.exclude_self_loops, self.replace,
canonical_etype, self.redundancy)
"""Data loading components for neighbor sampling"""
from .dataloader import BlockSampler
from .. import sampling, distributed
from .. import ndarray as nd
from .. import backend as F
from ..base import DGLError, ETYPE
class NeighborSamplingMixin(object):
"""Mixin object containing common optimizing routines that caches fanout and probability
arrays.
The mixin requires the object to have the following attributes:
- :attr:`prob`: The edge feature name that stores the (unnormalized) probability.
- :attr:`fanouts`: The list of fanouts (either an integer or a dictionary of edge
types and integers).
The mixin will generate the following attributes:
- :attr:`prob_arrays`: List of DGL NDArrays containing the unnormalized probabilities
for every edge type.
- :attr:`fanout_arrays`: List of DGL NDArrays containing the fanouts for every edge
type at every layer.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) # forward to base classes
self.fanout_arrays = []
self.prob_arrays = None
def _build_prob_arrays(self, g):
if self.prob is not None:
self.prob_arrays = [F.to_dgl_nd(g.edges[etype].data[self.prob]) for etype in g.etypes]
elif self.prob_arrays is None:
# build prob_arrays only once
self.prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes)
def _build_fanout(self, block_id, g):
assert self.fanouts is not None, \
"_build_fanout() should only be called when fanouts is not None"
# build fanout_arrays only once for each layer
while block_id >= len(self.fanout_arrays):
for i in range(len(self.fanouts)):
fanout = self.fanouts[i]
if not isinstance(fanout, dict):
fanout_array = [int(fanout)] * len(g.etypes)
else:
if len(fanout) != len(g.etypes):
raise DGLError('Fan-out must be specified for each edge type '
'if a dict is provided.')
fanout_array = [None] * len(g.etypes)
for etype, value in fanout.items():
fanout_array[g.get_etype_id(etype)] = value
self.fanout_arrays.append(
F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64)))
class MultiLayerNeighborSampler(NeighborSamplingMixin, BlockSampler):
"""Sampler that builds computational dependency of node representations via
neighbor sampling for multilayer GNN.
This sampler will make every node gather messages from a fixed number of neighbors
per edge type. The neighbors are picked uniformly.
Parameters
----------
fanouts : list[int] or list[dict[etype, int]]
List of neighbors to sample per edge type for each GNN layer, with the i-th
element being the fanout for the i-th GNN layer.
If only a single integer is provided, DGL assumes that every edge type
will have the same fanout.
If -1 is provided for one edge type on one layer, then all inbound edges
of that edge type will be included.
replace : bool, default False
Whether to sample with replacement
return_eids : bool, default False
Whether to return the edge IDs involved in message passing in the MFG.
If True, the edge IDs will be stored as an edge feature named ``dgl.EID``.
prob : str, optional
If given, the probability of each neighbor being sampled is proportional
to the edge feature value with the given name in ``g.edata``. The feature must be
a scalar on each edge.
Examples
--------
To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for
the first, second, and third layer respectively (assuming the backend is PyTorch):
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10, 15])
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, train_nid, sampler,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
... train_on(blocks)
If training on a heterogeneous graph and you want different number of neighbors for each
edge type, one should instead provide a list of dicts. Each dict would specify the
number of neighbors to pick per edge type.
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([
... {('user', 'follows', 'user'): 5,
... ('user', 'plays', 'game'): 4,
... ('game', 'played-by', 'user'): 3}] * 3)
If you would like non-uniform neighbor sampling:
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10, 15], prob='p')
Notes
-----
For the concept of MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, fanouts, replace=False, return_eids=False, prob=None):
super().__init__(len(fanouts), return_eids)
self.fanouts = fanouts
self.replace = replace
# fanout_arrays and prob_arrays (set up by NeighborSamplingMixin) cache the per-layer
# fan-outs and sampling probabilities as lists of dgl.nd.NDArray
self.prob = prob
@classmethod
def exclude_edges_in_frontier(cls, g):
return not isinstance(g, distributed.DistGraph) and g.device == F.cpu() \
and not g.is_pinned()
def sample_frontier(self, block_id, g, seed_nodes, exclude_eids=None):
fanout = self.fanouts[block_id]
if isinstance(g, distributed.DistGraph):
if len(g.etypes) > 1: # heterogeneous distributed graph
frontier = distributed.sample_etype_neighbors(
g, seed_nodes, ETYPE, fanout, replace=self.replace)
else:
frontier = distributed.sample_neighbors(
g, seed_nodes, fanout, replace=self.replace)
else:
self._build_fanout(block_id, g)
self._build_prob_arrays(g)
frontier = sampling.sample_neighbors(
g, seed_nodes, self.fanout_arrays[block_id],
replace=self.replace, prob=self.prob_arrays, exclude_edges=exclude_eids)
return frontier
class MultiLayerFullNeighborSampler(MultiLayerNeighborSampler):
"""Sampler that builds computational dependency of node representations by taking messages
from all neighbors for multilayer GNN.
This sampler will make every node gather messages from every single neighbor per edge type.
Parameters
----------
n_layers : int
The number of GNN layers to sample.
return_eids : bool, default False
Whether to return the edge IDs involved in message passing in the MFG.
If True, the edge IDs will be stored as an edge feature named ``dgl.EID``.
Examples
--------
To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
a homogeneous graph where each node takes messages from all neighbors for the first,
second, and third layer respectively (assuming the backend is PyTorch):
>>> sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, train_nid, sampler,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
... train_on(blocks)
Notes
-----
For the concept of MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, n_layers, return_eids=False):
super().__init__([-1] * n_layers, return_eids=return_eids)
@classmethod
def exclude_edges_in_frontier(cls, g):
return False
"""DGL PyTorch DataLoader module."""
from .dataloader import *
"""DGL PyTorch DataLoaders"""
import inspect
import math
import threading
import queue
from distutils.version import LooseVersion
import torch as th
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from ..dataloader import NodeCollator, EdgeCollator, GraphCollator, SubgraphIterator
from ...distributed import DistGraph
from ...ndarray import NDArray as DGLNDArray
from ... import backend as F
from ...base import DGLError, dgl_warning
from ...utils import to_dgl_context, check_device
__all__ = ['NodeDataLoader', 'EdgeDataLoader', 'GraphDataLoader',
# Temporary exposure.
'_pop_subgraph_storage', '_pop_storages',
'_restore_subgraph_storage', '_restore_storages']
PYTORCH_VER = LooseVersion(th.__version__)
PYTORCH_16 = PYTORCH_VER >= LooseVersion("1.6.0")
PYTORCH_17 = PYTORCH_VER >= LooseVersion("1.7.0")
def _check_graph_type(g):
if isinstance(g, DistGraph):
raise TypeError("Please use DistNodeDataLoader or DistEdgeDataLoader for DistGraph")
def _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed):
# Note: will change the content of dataloader_kwargs
dist_sampler_kwargs = {'shuffle': dataloader_kwargs['shuffle']}
dataloader_kwargs['shuffle'] = False
if PYTORCH_16:
dist_sampler_kwargs['seed'] = ddp_seed
if PYTORCH_17:
dist_sampler_kwargs['drop_last'] = dataloader_kwargs['drop_last']
dataloader_kwargs['drop_last'] = False
return DistributedSampler(dataset, **dist_sampler_kwargs)
class _ScalarDataBatcherIter:
def __init__(self, dataset, batch_size, drop_last):
self.dataset = dataset
self.batch_size = batch_size
self.index = 0
self.drop_last = drop_last
# Make this an iterator for PyTorch Lightning compatibility
def __iter__(self):
return self
def __next__(self):
num_items = self.dataset.shape[0]
if self.index >= num_items:
raise StopIteration
end_idx = self.index + self.batch_size
if end_idx > num_items:
if self.drop_last:
raise StopIteration
end_idx = num_items
batch = self.dataset[self.index:end_idx]
self.index += self.batch_size
return batch
class _ScalarDataBatcher(th.utils.data.IterableDataset):
"""Custom Dataset wrapper to return mini-batches as tensors, rather than as
lists. When the dataset is on the GPU, this significantly reduces
the overhead. For the case of a batch size of 1024, instead of giving a
list of 1024 tensors to the collator, a single tensor with 1024 elements
is passed in.
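For illustration, a minimal sketch outside of a DataLoader (PyTorch imported as ``th``):
>>> batcher = _ScalarDataBatcher(th.arange(10), batch_size=4)
>>> [b.tolist() for b in iter(batcher)]
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]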
"""
def __init__(self, dataset, shuffle=False, batch_size=1,
drop_last=False, use_ddp=False, ddp_seed=0):
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.use_ddp = use_ddp
if use_ddp:
self.rank = dist.get_rank()
self.num_replicas = dist.get_world_size()
self.seed = ddp_seed
self.epoch = 0
# The following code (and the idea of cross-process shuffling with the same seed)
# comes from PyTorch. See torch/utils/data/distributed.py for details.
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any sample, since the dataset will be split evenly.
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
# `type:ignore` is required because Dataset cannot provide a default __len__
# see NOTE in pytorch/torch/utils/data/sampler.py
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
if self.use_ddp:
return self._iter_ddp()
else:
return self._iter_non_ddp()
def _divide_by_worker(self, dataset):
worker_info = th.utils.data.get_worker_info()
if worker_info:
# worker gets only a fraction of the dataset
chunk_size = dataset.shape[0] // worker_info.num_workers
left_over = dataset.shape[0] % worker_info.num_workers
start = (chunk_size*worker_info.id) + min(left_over, worker_info.id)
end = start + chunk_size + (worker_info.id < left_over)
assert worker_info.id < worker_info.num_workers-1 or \
end == dataset.shape[0]
dataset = dataset[start:end]
return dataset
def _iter_non_ddp(self):
dataset = self._divide_by_worker(self.dataset)
if self.shuffle:
# permute the dataset
perm = th.randperm(dataset.shape[0], device=dataset.device)
dataset = dataset[perm]
return _ScalarDataBatcherIter(dataset, self.batch_size, self.drop_last)
def _iter_ddp(self):
# The following code (and the idea of cross-process shuffling with the same seed)
# comes from PyTorch. See torch/utils/data/distributed.py for details.
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = th.Generator()
g.manual_seed(self.seed + self.epoch)
indices = th.randperm(len(self.dataset), generator=g)
else:
indices = th.arange(len(self.dataset))
if not self.drop_last:
# add extra samples to make it evenly divisible
indices = th.cat([indices, indices[:(self.total_size - indices.shape[0])]])
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert indices.shape[0] == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert indices.shape[0] == self.num_samples
# Dividing by worker is our own stuff.
dataset = self._divide_by_worker(self.dataset[indices])
return _ScalarDataBatcherIter(dataset, self.batch_size, self.drop_last)
def __len__(self):
num_samples = self.num_samples if self.use_ddp else self.dataset.shape[0]
return (num_samples + (0 if self.drop_last else self.batch_size - 1)) // self.batch_size
def set_epoch(self, epoch):
"""Set epoch number for distributed training."""
self.epoch = epoch
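# The following demo is a minimal usage sketch (not part of the original module): it shows how
# `_ScalarDataBatcher` turns a flat node-ID tensor into tensor mini-batches, and how the
# rank-strided slicing in `_iter_ddp` partitions a padded index permutation across replicas.
# The function name `_demo_scalar_batcher` is hypothetical and exists only for illustration.
def _demo_scalar_batcher():
    nids = th.arange(10)                                   # toy node IDs
    batcher = _ScalarDataBatcher(nids, shuffle=False, batch_size=4, drop_last=False)
    batches = list(batcher)                                # [tensor([0..3]), tensor([4..7]), tensor([8, 9])]
    # The same padding and rank-strided slicing as `_iter_ddp`, in plain torch for 2 replicas:
    num_replicas, rank = 2, 0
    num_samples = math.ceil(len(nids) / num_replicas)
    total_size = num_samples * num_replicas
    indices = th.arange(len(nids))
    indices = th.cat([indices, indices[:total_size - indices.shape[0]]])  # pad to be divisible
    this_rank = indices[rank:total_size:num_replicas]      # every `num_replicas`-th index
    return batches, this_rank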
# The following code is a fix to the PyTorch-specific issue in
# https://github.com/dmlc/dgl/issues/2137
#
# Basically the sampled MFGs/subgraphs contain the features extracted from the
# parent graph. In DGL, the MFGs/subgraphs will hold a reference to the parent
# graph feature tensor and an index tensor, so that the features could be extracted upon
# request. However, in the context of multiprocessed sampling, we do not need to
# transmit the parent graph feature tensor from the subprocess to the main process,
# since they are exactly the same tensor, and transmitting a tensor from a subprocess
# to the main process is costly in PyTorch as it uses shared memory. We work around
# it with the following trick:
#
# In the collator running in the sampler processes:
# For each frame in the MFG, we check each column and the column with the same name
# in the corresponding parent frame. If the storage of the former column is the
# same object as the latter column, we are sure that the former column is a
# subcolumn of the latter, and set the storage of the former column as None.
#
# In the iterator of the main process:
# For each frame in the MFG, we check each column and the column with the same name
# in the corresponding parent frame. If the storage of the former column is None,
# we replace it with the storage of the latter column.
def _pop_subframe_storage(subframe, frame):
for key, col in subframe._columns.items():
if key in frame._columns and col.storage is frame._columns[key].storage:
col.storage = None
def _pop_subgraph_storage(subg, g):
for ntype in subg.ntypes:
if ntype not in g.ntypes:
continue
subframe = subg._node_frames[subg.get_ntype_id(ntype)]
frame = g._node_frames[g.get_ntype_id(ntype)]
_pop_subframe_storage(subframe, frame)
for etype in subg.canonical_etypes:
if etype not in g.canonical_etypes:
continue
subframe = subg._edge_frames[subg.get_etype_id(etype)]
frame = g._edge_frames[g.get_etype_id(etype)]
_pop_subframe_storage(subframe, frame)
def _pop_block_storage(block, g):
for ntype in block.srctypes:
if ntype not in g.ntypes:
continue
subframe = block._node_frames[block.get_ntype_id_from_src(ntype)]
frame = g._node_frames[g.get_ntype_id(ntype)]
_pop_subframe_storage(subframe, frame)
for ntype in block.dsttypes:
if ntype not in g.ntypes:
continue
subframe = block._node_frames[block.get_ntype_id_from_dst(ntype)]
frame = g._node_frames[g.get_ntype_id(ntype)]
_pop_subframe_storage(subframe, frame)
for etype in block.canonical_etypes:
if etype not in g.canonical_etypes:
continue
subframe = block._edge_frames[block.get_etype_id(etype)]
frame = g._edge_frames[g.get_etype_id(etype)]
_pop_subframe_storage(subframe, frame)
def _pop_storages(subgs, g):
for subg in subgs:
if subg.is_block:
_pop_block_storage(subg, g)
else:
_pop_subgraph_storage(subg, g)
def _restore_subframe_storage(subframe, frame):
for key, col in subframe._columns.items():
if col.storage is None:
col.storage = frame._columns[key].storage
def _restore_subgraph_storage(subg, g):
for ntype in subg.ntypes:
if ntype not in g.ntypes:
continue
subframe = subg._node_frames[subg.get_ntype_id(ntype)]
frame = g._node_frames[g.get_ntype_id(ntype)]
_restore_subframe_storage(subframe, frame)
for etype in subg.canonical_etypes:
if etype not in g.canonical_etypes:
continue
subframe = subg._edge_frames[subg.get_etype_id(etype)]
frame = g._edge_frames[g.get_etype_id(etype)]
_restore_subframe_storage(subframe, frame)
def _restore_block_storage(block, g):
for ntype in block.srctypes:
if ntype not in g.ntypes:
continue
subframe = block._node_frames[block.get_ntype_id_from_src(ntype)]
frame = g._node_frames[g.get_ntype_id(ntype)]
_restore_subframe_storage(subframe, frame)
for ntype in block.dsttypes:
if ntype not in g.ntypes:
continue
subframe = block._node_frames[block.get_ntype_id_from_dst(ntype)]
frame = g._node_frames[g.get_ntype_id(ntype)]
_restore_subframe_storage(subframe, frame)
for etype in block.canonical_etypes:
if etype not in g.canonical_etypes:
continue
subframe = block._edge_frames[block.get_etype_id(etype)]
frame = g._edge_frames[g.get_etype_id(etype)]
_restore_subframe_storage(subframe, frame)
def _restore_storages(subgs, g):
for subg in subgs:
if subg.is_block:
_restore_block_storage(subg, g)
else:
_restore_subgraph_storage(subg, g)
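# Minimal illustration (not part of the original module) of the pop/restore trick described in
# the comment block above, using stand-in column/frame objects; `_FakeColumn` and `_FakeFrame`
# are hypothetical names that only mimic the `.storage` / `._columns` attributes the helpers use.
class _FakeColumn:
    def __init__(self, storage):
        self.storage = storage
class _FakeFrame:
    def __init__(self, columns):
        self._columns = columns
def _demo_pop_restore():
    shared = th.arange(5)                              # storage shared between parent and subframe
    parent = _FakeFrame({'feat': _FakeColumn(shared)})
    sub = _FakeFrame({'feat': _FakeColumn(shared)})
    _pop_subframe_storage(sub, parent)                 # dropped before crossing process boundaries
    assert sub._columns['feat'].storage is None
    _restore_subframe_storage(sub, parent)             # re-attached in the main process
    assert sub._columns['feat'].storage is shared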
class _NodeCollator(NodeCollator):
def collate(self, items): # pylint: disable=missing-docstring
# input_nodes, output_nodes, blocks
result = super().collate(items)
_pop_storages(result[-1], self.g)
return result
class _EdgeCollator(EdgeCollator):
def collate(self, items): # pylint: disable=missing-docstring
if self.negative_sampler is None:
# input_nodes, pair_graph, blocks
result = super().collate(items)
_pop_subgraph_storage(result[1], self.g)
_pop_storages(result[-1], self.g_sampling)
return result
else:
# input_nodes, pair_graph, neg_pair_graph, blocks
result = super().collate(items)
_pop_subgraph_storage(result[1], self.g)
_pop_subgraph_storage(result[2], self.g)
_pop_storages(result[-1], self.g_sampling)
return result
class _GraphCollator(GraphCollator):
def __init__(self, subgraph_iterator, **kwargs):
super().__init__(**kwargs)
self.subgraph_iterator = subgraph_iterator
def collate(self, items):
result = super().collate(items)
if self.subgraph_iterator is not None:
_pop_storages([result], self.subgraph_iterator.g)
return result
def _to_device(data, device):
if isinstance(data, dict):
for k, v in data.items():
data[k] = v.to(device)
elif isinstance(data, list):
data = [item.to(device) for item in data]
else:
data = data.to(device)
return data
def _index_select(in_tensor, idx, pin_memory):
idx = idx.to(in_tensor.device)
shape = list(in_tensor.shape)
shape[0] = len(idx)
out_tensor = th.empty(*shape, dtype=in_tensor.dtype, pin_memory=pin_memory)
th.index_select(in_tensor, 0, idx, out=out_tensor)
return out_tensor
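# Minimal sketch (not part of the original module): slicing rows of a CPU feature tensor into a
# pinned buffer with `_index_select`, so the later host-to-device copy can run asynchronously.
# The function name `_demo_index_select` is hypothetical.
def _demo_index_select():
    feats = th.randn(100, 16)                          # CPU feature tensor
    idx = th.tensor([3, 7, 42])
    sliced = _index_select(feats, idx, pin_memory=th.cuda.is_available())
    return sliced.shape                                # torch.Size([3, 16])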
def _next(dl_iter, graph, device, load_input, load_output, stream=None):
# input_nodes, output_nodes, blocks
input_nodes, output_nodes, blocks = next(dl_iter)
_restore_storages(blocks, graph)
input_data = {}
for tag, data in load_input.items():
sliced = _index_select(data, input_nodes, data.device != device)
input_data[tag] = sliced
output_data = {}
for tag, data in load_output.items():
sliced = _index_select(data, output_nodes, data.device != device)
output_data[tag] = sliced
result_ = (input_nodes, output_nodes, blocks, input_data, output_data)
if stream is not None:
with th.cuda.stream(stream):
result = [_to_device(data, device)
for data in result_], result_, stream.record_event()
else:
result = [_to_device(data, device) for data in result_]
return result
def _background_node_dataloader(dl_iter, g, device, results, load_input, load_output):
dev = None
if device.type == 'cuda':
dev = device
elif g.device.type == 'cuda':
dev = g.device
stream = th.cuda.Stream(device=dev)
try:
while True:
results.put(_next(dl_iter, g, device, load_input, load_output, stream))
except StopIteration:
results.put((None, None, None))
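# Schematic sketch (not part of the original module) of the queue-plus-sentinel handshake used
# by the background loader above, with plain integers standing in for prefetched batches and no
# CUDA stream involved. The function name `_demo_prefetch_handshake` is hypothetical.
def _demo_prefetch_handshake():
    results = queue.Queue(1)
    def _producer(it):
        try:
            while True:
                results.put((next(it), None, None))
        except StopIteration:
            results.put((None, None, None))            # sentinel: no more batches
    threading.Thread(target=_producer, args=(iter(range(3)),), daemon=True).start()
    consumed = []
    while True:
        item, _, _ = results.get()
        if item is None:
            break
        consumed.append(item)
    return consumed                                     # [0, 1, 2]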
class _NodeDataLoaderIter:
def __init__(self, node_dataloader, iter_):
self.device = node_dataloader.device
self.node_dataloader = node_dataloader
self.iter_ = iter_
self.async_load = node_dataloader.async_load and (
F.device_type(self.device) == 'cuda')
if self.async_load:
self.results = queue.Queue(1)
threading.Thread(target=_background_node_dataloader, args=(
self.iter_, self.node_dataloader.collator.g, self.device,
self.results, node_dataloader.load_input, node_dataloader.load_output
), daemon=True).start()
# Make this an iterator for PyTorch Lightning compatibility
def __iter__(self):
return self
def __next__(self):
res = ()
if self.async_load:
res, _, event = self.results.get()
if res is None:
raise StopIteration
event.wait(th.cuda.default_stream())
else:
res = _next(self.iter_, self.node_dataloader.collator.g, self.device,
self.node_dataloader.load_input, self.node_dataloader.load_output)
input_nodes, output_nodes, blocks, input_data, output_data = res
if input_data:
for tag, data in input_data.items():
blocks[0].srcdata[tag] = data
if output_data:
for tag, data in output_data.items():
blocks[-1].dstdata[tag] = data
return input_nodes, output_nodes, blocks
class _EdgeDataLoaderIter:
def __init__(self, edge_dataloader, iter_):
self.device = edge_dataloader.device
self.edge_dataloader = edge_dataloader
self.iter_ = iter_
# Make this an iterator for PyTorch Lightning compatibility
def __iter__(self):
return self
def __next__(self):
result_ = next(self.iter_)
if self.edge_dataloader.collator.negative_sampler is not None:
# With a negative sampler, the result is
# input_nodes, pair_graph, neg_pair_graph, blocks;
# otherwise it is input_nodes, pair_graph, blocks.
_restore_subgraph_storage(result_[2], self.edge_dataloader.collator.g)
_restore_subgraph_storage(result_[1], self.edge_dataloader.collator.g)
_restore_storages(result_[-1], self.edge_dataloader.collator.g_sampling)
result = [_to_device(data, self.device) for data in result_]
return result
class _GraphDataLoaderIter:
def __init__(self, graph_dataloader, iter_):
self.dataloader = graph_dataloader
self.iter_ = iter_
def __iter__(self):
return self
def __next__(self):
result = next(self.iter_)
if self.dataloader.is_subgraph_loader:
_restore_storages([result], self.dataloader.subgraph_iterator.g)
return result
def _init_dataloader(collator, device, dataloader_kwargs, use_ddp, ddp_seed):
dataset = collator.dataset
use_scalar_batcher = False
scalar_batcher = None
if device.type == 'cuda' and dataloader_kwargs.get('num_workers', 0) == 0:
batch_size = dataloader_kwargs.get('batch_size', 1)
if batch_size > 1:
if isinstance(dataset, DGLNDArray):
# the dataset needs to be a torch tensor for the
# _ScalarDataBatcher
dataset = F.zerocopy_from_dgl_ndarray(dataset)
if isinstance(dataset, th.Tensor):
shuffle = dataloader_kwargs.get('shuffle', False)
drop_last = dataloader_kwargs.get('drop_last', False)
# manually batch into tensors
dataset = _ScalarDataBatcher(dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
use_ddp=use_ddp,
ddp_seed=ddp_seed)
# need to overwrite things that will be handled by the batcher
dataloader_kwargs['batch_size'] = None
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
use_scalar_batcher = True
scalar_batcher = dataset
if use_ddp and not use_scalar_batcher:
dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
dataloader_kwargs['sampler'] = dist_sampler
else:
dist_sampler = None
return use_scalar_batcher, scalar_batcher, dataset, collator, dist_sampler
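# Illustrative sketch (not part of the original module): when the scalar batcher takes over,
# the batching-related kwargs are moved into the batcher and disabled on the outer DataLoader,
# mirroring the overwrites performed in `_init_dataloader`. `_demo_kwarg_handoff` is a
# hypothetical name used only for this illustration.
def _demo_kwarg_handoff():
    dataloader_kwargs = {'batch_size': 1024, 'shuffle': True, 'drop_last': False}
    batcher = _ScalarDataBatcher(th.arange(8),
                                 batch_size=dataloader_kwargs['batch_size'],
                                 shuffle=dataloader_kwargs['shuffle'],
                                 drop_last=dataloader_kwargs['drop_last'])
    dataloader_kwargs.update(batch_size=None, shuffle=False, drop_last=False)
    return batcher, dataloader_kwargs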
class NodeDataLoader(DataLoader):
"""PyTorch dataloader for batch-iterating over a set of nodes, generating the list
of message flow graphs (MFGs) as computation dependency of the said minibatch.
Parameters
----------
g : DGLGraph
The graph.
nids : Tensor or dict[ntype, Tensor]
The node set to compute outputs.
graph_sampler : dgl.dataloading.Sampler
The neighborhood sampler.
device : device context, optional
The device of the generated MFGs in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
By default this value is the same as the device of :attr:`g`.
use_ddp : boolean, optional
If True, tells the DataLoader to split the training set for each
participating process appropriately using
:class:`torch.utils.data.distributed.DistributedSampler`.
Note that :func:`~dgl.dataloading.NodeDataLoader.set_epoch` must be called
at the beginning of every epoch if :attr:`use_ddp` is True.
Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
ddp_seed : int, optional
The seed for shuffling the dataset in
:class:`torch.utils.data.distributed.DistributedSampler`.
Only effective when :attr:`use_ddp` is True.
load_input : dict[tag, Tensor], optional
The tensors will be sliced according to ``blocks[0].srcdata[dgl.NID]``
and will be attached to ``blocks[0].srcdata``.
load_output : dict[tag, Tensor], optional
The tensors will be sliced according to ``blocks[-1].dstdata[dgl.NID]``
and will be attached to ``blocks[-1].dstdata``.
async_load : boolean, optional
If True, data, including the graph and the sliced tensors, will be transferred
between devices asynchronously. This is transparent to end users. This
feature can speed up model training, especially when a large amount of data
needs to be transferred. As a disadvantage, the underlying ``to_block`` on GPU
is disabled, which could lead to decreased performance. This is a
trade-off that requires profiling to decide whether to enable it.
kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
Examples
--------
To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
a homogeneous graph where each node takes messages from all neighbors (assume
the backend is PyTorch):
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, train_nid, sampler,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
... train_on(input_nodes, output_nodes, blocks)
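To also attach sliced node features and labels to the sampled MFGs, pass them through
:attr:`load_input` and :attr:`load_output` (a sketch that assumes ``g.ndata['feat']`` and
``g.ndata['label']`` exist on a homogeneous graph):
>>> dataloader = dgl.dataloading.NodeDataLoader(
...     g, train_nid, sampler,
...     load_input={'feat': g.ndata['feat']},
...     load_output={'label': g.ndata['label']},
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=0)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     x = blocks[0].srcdata['feat']
...     y = blocks[-1].dstdata['label']
...     train_on(x, y, blocks)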
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by turning
on the `use_ddp` option:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, train_nid, sampler, use_ddp=True,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... dataloader.set_epoch(epoch)
... for input_nodes, output_nodes, blocks in dataloader:
... train_on(input_nodes, output_nodes, blocks)
Notes
-----
Please refer to
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`
and :ref:`User Guide Section 6 <guide-minibatch>` for usage.
**Tips for selecting the proper device**
* If the input graph :attr:`g` is on GPU, the output device :attr:`device` must be the same GPU
and :attr:`num_workers` must be zero. In this case, the sampling and subgraph construction
will take place on the GPU. This is the recommended setting when using a single GPU and
the whole graph fits in GPU memory.
* If the input graph :attr:`g` is on CPU while the output device :attr:`device` is GPU, then
depending on the value of :attr:`num_workers`:
- If :attr:`num_workers` is set to 0, the sampling will happen on the CPU, and then the
subgraphs will be constructed directly on the GPU. This hybrid mode is deprecated and
will be removed in the next release. Use UVA sampling instead, especially in
multi-GPU configurations.
- Otherwise, if :attr:`num_workers` is greater than 0, both the sampling and subgraph
construction will take place on the CPU. This is the recommended setting when using a
single GPU and the whole graph does not fit in GPU memory.
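**A sketch of UVA sampling** (an assumption-laden example: it assumes a CUDA device is
available and that the graph can be pinned in place with ``pin_memory_``):
>>> g = g.to('cpu')
>>> g.pin_memory_()
>>> dataloader = dgl.dataloading.NodeDataLoader(
...     g, train_nid.to('cuda'), sampler, device='cuda',
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=0)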
"""
collator_arglist = inspect.getfullargspec(NodeCollator).args
def __init__(self, g, nids, graph_sampler, device=None, use_ddp=False, ddp_seed=0,
load_input=None, load_output=None, async_load=False, **kwargs):
_check_graph_type(g)
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
if k in self.collator_arglist:
collator_kwargs[k] = v
else:
dataloader_kwargs[k] = v
# default to the same device the graph is on
device = th.device(g.device if device is None else device)
num_workers = dataloader_kwargs.get('num_workers', 0)
if g.device.type == 'cuda' or g.is_pinned():
sampling_type = 'UVA sampling' if g.is_pinned() else 'GPU sampling'
assert device.type == 'cuda', \
f"'device' must be a cuda device to enable {sampling_type}, got {device}."
assert check_device(nids, device), \
f"'nids' must be on {device} to use {sampling_type}."
assert num_workers == 0, \
f"'num_workers' must be 0 to use {sampling_type}."
# g is on CPU
elif device.type == 'cuda' and num_workers == 0:
dgl_warning('CPU-GPU hybrid sampling is deprecated and will be removed '
'in the next release. Use pure GPU sampling if your graph can '
'fit onto the GPU memory, or UVA sampling in other cases.')
if not g.is_homogeneous:
if load_input or load_output:
raise DGLError('load_input/load_output not supported for heterograph yet.')
self.load_input = {} if load_input is None else load_input
self.load_output = {} if load_output is None else load_output
self.async_load = async_load
# If the sampler supports it, tell it to output to the specified device.
# But if async_load is enabled, set_output_context should be skipped, as
# we'd like to avoid any graph/data transfer across devices inside the
# sampler. Such transfers are handled by the dataloader instead.
if ((not async_load) and
callable(getattr(graph_sampler, "set_output_context", None)) and
num_workers == 0):
graph_sampler.set_output_context(to_dgl_context(device))
self.collator = _NodeCollator(g, nids, graph_sampler, **collator_kwargs)
self.use_scalar_batcher, self.scalar_batcher, dataset, collator, self.dist_sampler = \
    _init_dataloader(self.collator, device, dataloader_kwargs, use_ddp, ddp_seed)
super().__init__(dataset, collate_fn=collator.collate, **dataloader_kwargs)
self.use_ddp = use_ddp
self.is_distributed = False
# Precompute the CSR and CSC representations so each subprocess does not
# duplicate.
if num_workers > 0:
g.create_formats_()
self.device = device
def __iter__(self):
return _NodeDataLoaderIter(self, super().__iter__())
def set_epoch(self, epoch):
"""Sets the epoch number for the underlying sampler which ensures all replicas
to use a different ordering for each epoch.
Only available when :attr:`use_ddp` is True.
Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.
Parameters
----------
epoch : int
The epoch number.
"""
if self.use_ddp:
if self.use_scalar_batcher:
self.scalar_batcher.set_epoch(epoch)
else:
self.dist_sampler.set_epoch(epoch)
else:
raise DGLError('set_epoch is only available when use_ddp is True.')
class EdgeDataLoader(DataLoader):
"""PyTorch dataloader for batch-iterating over a set of edges, generating the list
of message flow graphs (MFGs) as computation dependency of the said minibatch for
edge classification, edge regression, and link prediction.
For each iteration, the object will yield
* A tensor of input nodes necessary for computing the representation on edges, or
a dictionary of node type names and such tensors.
* A subgraph that contains only the edges in the minibatch and their incident nodes.
Note that the graph has an identical metagraph with the original graph.
* If a negative sampler is given, another graph that contains the "negative edges",
connecting the source and destination nodes yielded from the given negative sampler.
* A list of MFGs necessary for computing the representation of the incident nodes
of the edges in the minibatch.
For more details, please refer to :ref:`guide-minibatch-edge-classification-sampler`
and :ref:`guide-minibatch-link-classification-sampler`.
Parameters
----------
g : DGLGraph
The graph.
eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs.
graph_sampler : dgl.dataloading.Sampler
The neighborhood sampler.
device : device context, optional
The device of the generated MFGs and graphs in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
By default this value is the same as the device of :attr:`g`.
g_sampling : DGLGraph, optional
The graph where neighborhood sampling is performed.
One may wish to iterate over the edges in one graph while performing sampling in
another graph. This may be the case when iterating over the validation or test
edge set while performing neighborhood sampling on the graph formed by only
the training edge set; see the sketch at the end of the Examples section below.
If None, it is assumed to be the same as ``g``.
exclude : str, optional
Whether and how to exclude dependencies related to the sampled edges in the
minibatch. Possible values are
* None,
* ``self``,
* ``reverse_id``,
* ``reverse_types``
See the description of the argument with the same name in the docstring of
:class:`~dgl.dataloading.EdgeCollator` for more details.
reverse_eids : Tensor or dict[etype, Tensor], optional
A tensor of reverse edge ID mapping. The i-th element indicates the ID of
the i-th edge's reverse edge.
If the graph is heterogeneous, this argument requires a dictionary of edge
types and the reverse edge ID mapping tensors.
See the description of the argument with the same name in the docstring of
:class:`~dgl.dataloading.EdgeCollator` for more details.
reverse_etypes : dict[etype, etype], optional
The mapping from the original edge types to their reverse edge types.
See the description of the argument with the same name in the docstring of
:class:`~dgl.dataloading.EdgeCollator` for more details.
negative_sampler : callable, optional
The negative sampler.
See the description of the argument with the same name in the docstring of
:class:`~dgl.dataloading.EdgeCollator` for more details.
use_ddp : boolean, optional
If True, tells the DataLoader to split the training set for each
participating process appropriately using
:mod:`torch.utils.data.distributed.DistributedSampler`.
Note that :func:`~dgl.dataloading.NodeDataLoader.set_epoch` must be called
at the beginning of every epoch if :attr:`use_ddp` is True.
The dataloader will have a :attr:`dist_sampler` attribute to set the
epoch number, as recommended by PyTorch.
Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
ddp_seed : int, optional
The seed for shuffling the dataset in
:class:`torch.utils.data.distributed.DistributedSampler`.
Only effective when :attr:`use_ddp` is True.
kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
Examples
--------
The following example shows how to train a 3-layer GNN for edge classification on a
set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes
messages from all neighbors.
Say that you have an array of source node IDs ``src`` and another array of destination
node IDs ``dst``. One can make it bidirectional by adding another set of edges
that connects from ``dst`` to ``src``:
>>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))
The ID difference between an edge and its reverse edge is then ``|E|``,
where ``|E|`` is the length of your source/destination array. The reverse edge
mapping can be obtained by
>>> E = len(src)
>>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])
Note that the sampled edges as well as their reverse edges are removed from the
computation dependencies of the incident nodes. That is, these edges will not be
involved in neighbor sampling and message aggregation. This is a common trick
to avoid information leakage.
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, train_eid, sampler, exclude='reverse_id',
... reverse_eids=reverse_eids,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, blocks)
To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` on a
homogeneous graph where each node takes messages from all neighbors (assume the
backend is PyTorch), with 5 uniformly chosen negative samples per edge:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, train_eid, sampler, exclude='reverse_id',
... reverse_eids=reverse_eids, negative_sampler=neg_sampler,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
... train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
For heterogeneous graphs, the reverse of an edge may have a different edge type
from the original edge. For instance, consider that you have an array of
user-item clicks, represented by a user array ``user`` and an item array ``item``.
You may want to build a heterogeneous graph with a user-click-item relation and an
item-clicked-by-user relation.
>>> g = dgl.heterograph({
... ('user', 'click', 'item'): (user, item),
... ('item', 'clicked-by', 'user'): (item, user)})
To train a 3-layer GNN for edge classification on a set of edges ``train_eid`` with
type ``click``, you can write
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, {'click': train_eid}, sampler, exclude='reverse_types',
... reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, blocks)
To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` with type
``click``, you can write
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, train_eid, sampler, exclude='reverse_types',
... reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
... negative_sampler=neg_sampler,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
... train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
turning on the :attr:`use_ddp` option:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
... g, train_eid, sampler, use_ddp=True, exclude='reverse_id',
... reverse_eids=reverse_eids,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... dataloader.set_epoch(epoch)
... for input_nodes, pair_graph, blocks in dataloader:
... train_on(input_nodes, pair_graph, blocks)
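**Sampling on a different graph**
A hedged sketch of the :attr:`g_sampling` option described above (``g_train``,
``valid_eid`` and ``validate_on`` are hypothetical names): iterate over the validation
edges of ``g`` while sampling neighborhoods only from the training graph:
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
...     g, valid_eid, sampler, g_sampling=g_train,
...     batch_size=1024, shuffle=False, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     validate_on(input_nodes, pair_graph, blocks)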
See also
--------
dgl.dataloading.dataloader.EdgeCollator
Notes
-----
Please refer to
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`
and :ref:`User Guide Section 6 <guide-minibatch>` for usage.
For end-to-end usages, please refer to the following tutorial/examples:
* Edge classification on heterogeneous graph: GCMC
* Link prediction on homogeneous graph: GraphSAGE for unsupervised learning
* Link prediction on heterogeneous graph: RGCN for link prediction.
"""
collator_arglist = inspect.getfullargspec(EdgeCollator).args
def __init__(self, g, eids, graph_sampler, device=None, use_ddp=False, ddp_seed=0, **kwargs):
_check_graph_type(g)
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
if k in self.collator_arglist:
collator_kwargs[k] = v
else:
dataloader_kwargs[k] = v
# default to the same device the graph is on
device = th.device(g.device if device is None else device)
num_workers = dataloader_kwargs.get('num_workers', 0)
if g.device.type == 'cuda' or g.is_pinned():
sampling_type = 'UVA sampling' if g.is_pinned() else 'GPU sampling'
assert device.type == 'cuda', \
f"'device' must be a cuda device to enable {sampling_type}, got {device}."
assert check_device(eids, device), \
f"'eids' must be on {device} to use {sampling_type}."
assert num_workers == 0, \
f"'num_workers' must be 0 to use {sampling_type}."
# g is on CPU
elif device.type == 'cuda' and num_workers == 0:
dgl_warning('CPU-GPU hybrid sampling is deprecated and will be removed '
'in the next release. Use pure GPU sampling if your graph can '
'fit onto the GPU memory, or UVA sampling in other cases.')
# if the sampler supports it, tell it to output to the
# specified device
if callable(getattr(graph_sampler, "set_output_context", None)) and num_workers == 0:
graph_sampler.set_output_context(to_dgl_context(device))
self.collator = _EdgeCollator(g, eids, graph_sampler, **collator_kwargs)
self.use_scalar_batcher, self.scalar_batcher, dataset, collator, self.dist_sampler = \
_init_dataloader(self.collator, device, dataloader_kwargs, use_ddp, ddp_seed)
self.use_ddp = use_ddp
super().__init__(dataset, collate_fn=collator.collate, **dataloader_kwargs)
# Precompute the CSR and CSC representations so each subprocess does not duplicate.
if num_workers > 0:
g.create_formats_()
self.device = device
def __iter__(self):
return _EdgeDataLoaderIter(self, super().__iter__())
def set_epoch(self, epoch):
"""Sets the epoch number for the underlying sampler which ensures all replicas
to use a different ordering for each epoch.
Only available when :attr:`use_ddp` is True.
Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.
Parameters
----------
epoch : int
The epoch number.
"""
if self.use_ddp:
if self.use_scalar_batcher:
self.scalar_batcher.set_epoch(epoch)
else:
self.dist_sampler.set_epoch(epoch)
else:
raise DGLError('set_epoch is only available when use_ddp is True.')
class GraphDataLoader(DataLoader):
"""PyTorch dataloader for batch-iterating over a set of graphs, generating the batched
graph and corresponding label tensor (if provided) of the said minibatch.
Parameters
----------
collate_fn : Function, default is None
The customized collate function. Will use the default collate
function if not given.
use_ddp : boolean, optional
If True, tells the DataLoader to split the training set for each
participating process appropriately using
:class:`torch.utils.data.distributed.DistributedSampler`.
Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
ddp_seed : int, optional
The seed for shuffling the dataset in
:class:`torch.utils.data.distributed.DistributedSampler`.
Only effective when :attr:`use_ddp` is True.
kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
Examples
--------
To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
the backend is PyTorch):
>>> dataloader = dgl.dataloading.GraphDataLoader(
... dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
... train_on(batched_graph, labels)
**Using with Distributed Data Parallel**
If you are using PyTorch's distributed training (e.g. when using
:mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
turning on the :attr:`use_ddp` option:
>>> dataloader = dgl.dataloading.GraphDataLoader(
... dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
... dataloader.set_epoch(epoch)
... for batched_graph, labels in dataloader:
... train_on(batched_graph, labels)
"""
collator_arglist = inspect.getfullargspec(GraphCollator).args
def __init__(self, dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
if k in self.collator_arglist:
collator_kwargs[k] = v
else:
dataloader_kwargs[k] = v
# If the dataset is an infinite SubgraphIterator (i.e. without __len__) over a
# larger graph, convert it to an IterableDataset.
if isinstance(dataset, SubgraphIterator) and not hasattr(dataset, '__len__'):
class _Dataset(IterableDataset):
def __init__(self, iter_):
self._it = iter_
def __iter__(self):
return iter(self._it)
self.subgraph_iterator = dataset
dataset = _Dataset(dataset)
self.is_subgraph_loader = True
else:
self.is_subgraph_loader = False
self.subgraph_iterator = None
if collate_fn is None:
self.collate = _GraphCollator(self.subgraph_iterator, **collator_kwargs).collate
else:
self.collate = collate_fn
self.use_ddp = use_ddp
if use_ddp:
self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
dataloader_kwargs['sampler'] = self.dist_sampler
super().__init__(dataset, collate_fn=self.collate, **dataloader_kwargs)
def __iter__(self):
"""Return the iterator of the data loader."""
return _GraphDataLoaderIter(self, super().__iter__())
def __len__(self):
"""Return the number of batches of the data loader."""
return super().__len__()
def set_epoch(self, epoch):
"""Sets the epoch number for the underlying sampler which ensures all replicas
to use a different ordering for each epoch.
Only available when :attr:`use_ddp` is True.
Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.
Parameters
----------
epoch : int
The epoch number.
"""
if self.use_ddp:
self.dist_sampler.set_epoch(epoch)
else:
raise DGLError('set_epoch is only available when use_ddp is True.')
"""ShaDow-GNN subgraph samplers."""
from ..utils import prepare_tensor_or_dict
from ..base import NID
from .. import transforms
from ..sampling import sample_neighbors
from .neighbor import NeighborSamplingMixin
from .dataloader import exclude_edges, Sampler
class ShaDowKHopSampler(NeighborSamplingMixin, Sampler):
"""K-hop subgraph sampler used by
`ShaDow-GNN <https://arxiv.org/abs/2012.01380>`__.
It performs node-wise neighbor sampling but instead of returning a list of
MFGs, it returns a single subgraph induced by all the sampled nodes. The
seed nodes from which the neighbors are sampled will appear first among the
induced nodes of the subgraph.
This is used in conjunction with :class:`dgl.dataloading.pytorch.NodeDataLoader`
and :class:`dgl.dataloading.pytorch.EdgeDataLoader`.
Parameters
----------
fanouts : list[int] or list[dict[etype, int]]
List of neighbors to sample per edge type for each GNN layer, with the i-th
element being the fanout for the i-th GNN layer.
If only a single integer is provided, DGL assumes that every edge type
will have the same fanout.
If -1 is provided for one edge type on one layer, then all inbound edges
of that edge type will be included.
replace : bool, default False
Whether to sample with replacement.
prob : str, optional
If given, the probability of each neighbor being sampled is proportional
to the edge feature value with the given name in ``g.edata``. The feature must be
a scalar on each edge.
Examples
--------
To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for
the first, second, and third layer respectively (assuming the backend is PyTorch):
>>> g = dgl.data.CoraFullDataset()[0]
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, torch.arange(g.num_nodes()), sampler,
... batch_size=5, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, (subgraph,) in dataloader:
... print(subgraph)
... assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
... assert torch.equal(input_nodes[:output_nodes.shape[0]], output_nodes)
... break
Graph(num_nodes=529, num_edges=3796,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64),
'feat': Scheme(shape=(8710,), dtype=torch.float32),
'_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
If you are training on a heterogeneous graph and want a different number of neighbors
for each edge type, provide a list of dicts instead. Each dict specifies the
number of neighbors to pick per edge type.
>>> sampler = dgl.dataloading.ShaDowKHopSampler([
... {('user', 'follows', 'user'): 5,
... ('user', 'plays', 'game'): 4,
... ('game', 'played-by', 'user'): 3}] * 3)
If you would like non-uniform neighbor sampling:
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')
"""
def __init__(self, fanouts, replace=False, prob=None, output_ctx=None):
super().__init__(output_ctx)
self.fanouts = fanouts
self.replace = replace
self.prob = prob
self.set_output_context(output_ctx)
def sample(self, g, seed_nodes, exclude_eids=None):
self._build_fanout(len(self.fanouts), g)
self._build_prob_arrays(g)
seed_nodes = prepare_tensor_or_dict(g, seed_nodes, 'seed nodes')
output_nodes = seed_nodes
for i in range(len(self.fanouts)):
fanout = self.fanouts[i]
frontier = sample_neighbors(
g, seed_nodes, fanout, replace=self.replace, prob=self.prob_arrays)
block = transforms.to_block(frontier, seed_nodes)
seed_nodes = block.srcdata[NID]
subg = g.subgraph(seed_nodes, relabel_nodes=True)
subg = exclude_edges(subg, exclude_eids, self.output_device)
return seed_nodes, output_nodes, [subg]