Unverified commit 436de3d1, authored by peizhou001, committed by GitHub

[API Deprecation] Remove _dataloading and tgcn example (#5118)

# Temporal Graph Neural Network (TGN)
## DGL Implementation of the TGN Paper
This DGL example implements the GNN model proposed in the paper [Temporal Graph Networks](https://arxiv.org/abs/2006.10637.pdf).
## TGN Implementer
This example was implemented by [Ericcsr](https://github.com/Ericcsr) during his SDE internship at the AWS Shanghai AI Lab.
## Graph Dataset
Jodie Wikipedia Temporal dataset. Dataset summary:
- Num Nodes: 9,227
- Num Edges: 157,474
- Num Edge Features: 172
- Edge Feature type: LIWC
- Time Span: 30 days
- Chronological Split: Train: 70% Valid: 15% Test: 15%
Jodie Reddit Temporal dataset. Dataset summary:
- Num Nodes: 11,000
- Num Edges: 672,447
- Num Edge Features: 172
- Edge Feature type: LIWC
- Time Span: 30 days
- Chronological Split: Train: 70% Valid: 15% Test: 15%
## How to Run the Example Files
In the tgn folder, run (**please use `train.py`**):
```bash
python train.py --dataset wikipedia
```
To run in fast mode:
```bash
python train.py --dataset wikipedia --fast_mode
```
To run in simple mode:
```bash
python train.py --dataset wikipedia --simple_mode
```
To change the memory updating module:
```bash
python train.py --dataset wikipedia --memory_updater [rnn/gru]
```
To use TGAT:
```bash
python train.py --dataset wikipedia --not_use_memory --k_hop 2
```
## Performance
#### Without New Node in test set
| Models/Datasets | Wikipedia | Reddit |
| --------------- | ------------------ | ---------------- |
| TGN simple mode | AP: 98.5 AUC: 98.9 | AP: N/A AUC: N/A |
| TGN fast mode | AP: 98.2 AUC: 98.6 | AP: N/A AUC: N/A |
| TGN | AP: 98.9 AUC: 98.5 | AP: N/A AUC: N/A |
#### With New Node in test set
| Models/Datasets | Wikipedia | Reddit |
| --------------- | ------------------- | ---------------- |
| TGN simple mode | AP: 98.2 AUC: 98.6 | AP: N/A AUC: N/A |
| TGN fast mode | AP: 98.0 AUC: 98.4 | AP: N/A AUC: N/A |
| TGN | AP: 98.2 AUC: 98.1 | AP: N/A AUC: N/A |
## Training Speed per Batch
Hardware: Intel E5 (2 cores), Tesla K80
| Models/Datasets | Wikipedia | Reddit |
| --------------- | --------- | -------- |
| TGN simple mode | 0.3s | N/A |
| TGN fast mode | 0.28s | N/A |
| TGN | 1.3s | N/A |
### Details Explained
**What is Simple Mode**
The simple temporal sampler selects only the edges that occur before the current timestamp and builds the subgraph over the corresponding nodes.
It then applies static-graph neighborhood sampling methods to that subgraph.
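The filtering step above can be sketched in a few lines of NumPy. This is an illustration only, with names of our own choosing, not code from the example itself:

```python
import numpy as np

# Keep only the edges observed strictly before the current timestamp;
# the surviving edges form the static subgraph that is then sampled normally.
def simple_temporal_filter(src, dst, ts, current_t):
    """Return only the edges with timestamp < current_t."""
    mask = ts < current_t
    return src[mask], dst[mask], ts[mask]

src = np.array([0, 1, 2, 3])
dst = np.array([1, 2, 3, 0])
ts = np.array([1.0, 2.0, 3.0, 4.0])
s, d, t = simple_temporal_filter(src, dst, ts, current_t=3.0)
# Only the interactions before t=3.0 remain: (0->1, t=1.0) and (1->2, t=2.0)
```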
**What is Fast Mode**
Normally, temporal encoding uses each interaction's own timestamp as the current time. When two nodes interact multiple times within the same batch, the model must maintain multiple embedding features per node to avoid feature duplication, which slows down batching. Fast mode enables fast batching by using each node's last memory update time from the previous batch as the temporal encoding baseline. Within each batch, all interactions between two nodes are then predicted using the same set of embedding features.
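The contrast between the two time baselines can be illustrated with a toy example (variable names are ours, not from the example code):

```python
import numpy as np

event_ts = np.array([10.0, 12.0, 15.0])  # three interactions of one node within a batch
last_update_t = 8.0                      # the node's memory update time from the previous batch

# Normal mode: each interaction uses its own timestamp as the current time,
# so this node needs a distinct time encoding (and embedding) per interaction.
per_event_delta = event_ts - last_update_t
n_embeddings_normal = len(np.unique(per_event_delta))

# Fast mode: the last update time from the previous batch is the single
# temporal-encoding baseline for the node, so one embedding is reused for
# all of its interactions in the batch.
n_embeddings_fast = 1
```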
**What is the New Node Test**
This test checks whether the model can predict links between unseen nodes based on the neighborhood information of seen nodes. The script deliberately selects 10% of the nodes in the test graph and masks them out during training.
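The masking step can be sketched as follows; this is a hypothetical illustration (names are ours), not the exact code used in `train.py`:

```python
import numpy as np

# Pick 10% of the nodes seen in the test split as "new" nodes,
# then drop every training edge that touches one of them.
rng = np.random.default_rng(2021)
test_nodes = np.arange(100)  # nodes appearing in the test split
new_nodes = rng.choice(test_nodes, int(0.1 * len(test_nodes)), replace=False)

is_new = np.zeros(len(test_nodes), dtype=bool)
is_new[new_nodes] = True

# A training edge (u, v) survives only if neither endpoint is masked.
u = np.array([1, 2, 3])
v = np.array([4, 5, 6])
keep = ~(is_new[u] | is_new[v])
```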
**Why the attention module is not exactly the same as in the original TGN paper**
The attention module used in this model is adapted from DGL's GATConv, extended to consider edge features and time encoding. It is more memory efficient and faster to compute than the attention module proposed in the paper, while according to our tests it matches the paper's module in accuracy.
import os
import ssl
import numpy as np
import pandas as pd
import torch
from six.moves import urllib
import dgl
# === The data preprocessing code below is based on
# https://github.com/twitter-research/tgn
# Preprocess the raw data and split out each feature
def preprocess(data_name):
u_list, i_list, ts_list, label_list = [], [], [], []
feat_l = []
idx_list = []
with open(data_name) as f:
s = next(f)
for idx, line in enumerate(f):
e = line.strip().split(",")
u = int(e[0])
i = int(e[1])
ts = float(e[2])
label = float(e[3]) # int(e[3])
feat = np.array([float(x) for x in e[4:]])
u_list.append(u)
i_list.append(i)
ts_list.append(ts)
label_list.append(label)
idx_list.append(idx)
feat_l.append(feat)
return pd.DataFrame(
{
"u": u_list,
"i": i_list,
"ts": ts_list,
"label": label_list,
"idx": idx_list,
}
), np.array(feat_l)
# Re-index nodes for DGL convenience
def reindex(df, bipartite=True):
new_df = df.copy()
if bipartite:
assert df.u.max() - df.u.min() + 1 == len(df.u.unique())
assert df.i.max() - df.i.min() + 1 == len(df.i.unique())
upper_u = df.u.max() + 1
new_i = df.i + upper_u
new_df.i = new_i
new_df.u += 1
new_df.i += 1
new_df.idx += 1
else:
new_df.u += 1
new_df.i += 1
new_df.idx += 1
return new_df
# Save the edge list and the features in separate files for easy processing
def run(data_name, bipartite=True):
PATH = "./data/{}.csv".format(data_name)
OUT_DF = "./data/ml_{}.csv".format(data_name)
OUT_FEAT = "./data/ml_{}.npy".format(data_name)
OUT_NODE_FEAT = "./data/ml_{}_node.npy".format(data_name)
df, feat = preprocess(PATH)
new_df = reindex(df, bipartite)
empty = np.zeros(feat.shape[1])[np.newaxis, :]
feat = np.vstack([empty, feat])
max_idx = max(new_df.u.max(), new_df.i.max())
rand_feat = np.zeros((max_idx + 1, 172))
new_df.to_csv(OUT_DF)
np.save(OUT_FEAT, feat)
np.save(OUT_NODE_FEAT, rand_feat)
# === code from twitter-research-tgn end ===
# If you have a new dataset that follows the same format as Jodie,
# you can retrieve it directly by name
def TemporalDataset(dataset):
if not os.path.exists("./data/{}.bin".format(dataset)):
if not os.path.exists("./data/{}.csv".format(dataset)):
if not os.path.exists("./data"):
os.mkdir("./data")
url = "https://snap.stanford.edu/jodie/{}.csv".format(dataset)
print("Start Downloading File....")
context = ssl._create_unverified_context()
data = urllib.request.urlopen(url, context=context)
with open("./data/{}.csv".format(dataset), "wb") as handle:
handle.write(data.read())
print("Start Processing Data ...")
run(dataset)
raw_connection = pd.read_csv("./data/ml_{}.csv".format(dataset))
raw_feature = np.load("./data/ml_{}.npy".format(dataset))
# Subtract 1 to re-index the nodes from 0
src = raw_connection["u"].to_numpy() - 1
dst = raw_connection["i"].to_numpy() - 1
# Create directed graph
g = dgl.graph((src, dst))
g.edata["timestamp"] = torch.from_numpy(raw_connection["ts"].to_numpy())
g.edata["label"] = torch.from_numpy(raw_connection["label"].to_numpy())
g.edata["feats"] = torch.from_numpy(raw_feature[1:, :]).float()
dgl.save_graphs("./data/{}.bin".format(dataset), [g])
else:
print("Data exists, loading directly.")
gs, _ = dgl.load_graphs("./data/{}.bin".format(dataset))
g = gs[0]
return g
def TemporalWikipediaDataset():
# Download the dataset
return TemporalDataset("wikipedia")
def TemporalRedditDataset():
return TemporalDataset("reddit")
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.base import DGLError
from dgl.ops import edge_softmax
class Identity(nn.Module):
"""A placeholder identity operator that is argument-insensitive.
(Identity has been supported since PyTorch 1.2; we will directly
use torch.nn.Identity in the future)
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
"""Return input"""
return x
class MsgLinkPredictor(nn.Module):
"""Predict pairwise links from the positive and negative subgraphs
using message passing.
A two-layer MLP on each edge predicts the link probability
Parameters
----------
embed_dim : int
dimension of each feature's embedding
Example
----------
>>> linkpred = MsgLinkPredictor(10)
>>> pos_g = dgl.graph(([0,1,2,3,4],[1,2,3,4,0]))
>>> neg_g = dgl.graph(([0,1,2,3,4],[2,1,4,3,0]))
>>> x = torch.ones(5,10)
>>> linkpred(x,pos_g,neg_g)
(tensor([[0.0902],
[0.0902],
[0.0902],
[0.0902],
[0.0902]], grad_fn=<AddmmBackward>),
tensor([[0.0902],
[0.0902],
[0.0902],
[0.0902],
[0.0902]], grad_fn=<AddmmBackward>))
"""
def __init__(self, emb_dim):
super(MsgLinkPredictor, self).__init__()
self.src_fc = nn.Linear(emb_dim, emb_dim)
self.dst_fc = nn.Linear(emb_dim, emb_dim)
self.out_fc = nn.Linear(emb_dim, 1)
def link_pred(self, edges):
src_hid = self.src_fc(edges.src["embedding"])
dst_hid = self.dst_fc(edges.dst["embedding"])
score = F.relu(src_hid + dst_hid)
score = self.out_fc(score)
return {"score": score}
def forward(self, x, pos_g, neg_g):
# Note: could use graph.local_scope() here to avoid mutating the input graphs
pos_g.ndata["embedding"] = x
neg_g.ndata["embedding"] = x
pos_g.apply_edges(self.link_pred)
neg_g.apply_edges(self.link_pred)
pos_escore = pos_g.edata["score"]
neg_escore = neg_g.edata["score"]
return pos_escore, neg_escore
class TimeEncode(nn.Module):
"""Use a finite Fourier series with different phases and frequencies to encode
the time difference between two events
..math::
\Phi(t) = [\cos(\omega_0t+\psi_0),\cos(\omega_1t+\psi_1),...,\cos(\omega_nt+\psi_n)]
Parameters
----------
dimension : int
Length of the Fourier series. The longer it is,
the more timescale information it can capture
Example
----------
>>> tecd = TimeEncode(10)
>>> t = torch.tensor([[1]])
>>> tecd(t)
tensor([[[0.5403, 0.9950, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000]]], dtype=torch.float64, grad_fn=<CosBackward>)
"""
def __init__(self, dimension):
super(TimeEncode, self).__init__()
self.dimension = dimension
self.w = torch.nn.Linear(1, dimension)
self.w.weight = torch.nn.Parameter(
(torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
.double()
.reshape(dimension, -1)
)
self.w.bias = torch.nn.Parameter(torch.zeros(dimension).double())
def forward(self, t):
t = t.unsqueeze(dim=2)
output = torch.cos(self.w(t))
return output
class MemoryModule(nn.Module):
"""Memory module as well as the update interface
The memory module stores the historical node representations together with the last update time last_update_t
Parameters
----------
n_node : int
number of node of the entire graph
hidden_dim : int
dimension of memory of each node
Example
----------
Please refer to examples/pytorch/tgn/tgn.py;
examples/pytorch/tgn/train.py
"""
def __init__(self, n_node, hidden_dim):
super(MemoryModule, self).__init__()
self.n_node = n_node
self.hidden_dim = hidden_dim
self.reset_memory()
def reset_memory(self):
self.last_update_t = nn.Parameter(
torch.zeros(self.n_node).float(), requires_grad=False
)
self.memory = nn.Parameter(
torch.zeros((self.n_node, self.hidden_dim)).float(),
requires_grad=False,
)
def backup_memory(self):
"""
Return a deep copy of the memory state and last_update_t.
For the new-node test, new nodes need memory accumulated up to the validation set.
After validation, the memory is backed up before running the test set without new
nodes, so the backup can later be restored for the new-node test set.
"""
return self.memory.clone(), self.last_update_t.clone()
def restore_memory(self, memory_backup):
"""Restore the memory from validation set
Parameters
----------
memory_backup : (memory,last_update_t)
restore memory based on input tuple
"""
self.memory = memory_backup[0].clone()
self.last_update_t = memory_backup[1].clone()
# Used to attach memory to a subgraph
def get_memory(self, node_idxs):
return self.memory[node_idxs, :]
# Called when the memory needs to be updated
def set_memory(self, node_idxs, values):
self.memory[node_idxs, :] = values
def set_last_update_t(self, node_idxs, values):
self.last_update_t[node_idxs] = values
# For safety check
def get_last_update(self, node_idxs):
return self.last_update_t[node_idxs]
def detach_memory(self):
"""
Disconnect the memory from the computation graph to prevent gradients
from being propagated multiple times
"""
self.memory.detach_()
class MemoryOperation(nn.Module):
"""Memory update in a message-passing manner: update the memory based on the positive
pair graph of each batch with a recurrent module, GRU or RNN
Message function
..math::
m_i(t) = concat(memory_i(t^-),TimeEncode(t),v_i(t))
v_i is node feature at current time stamp
Aggregation function
..math::
\bar{m}_i(t) = last(m_i(t_1),...,m_i(t_b))
Update function
..math::
memory_i(t) = GRU(\bar{m}_i(t),memory_i(t-1))
Parameters
----------
updater_type : str
indicator string to specify updater
'rnn' : use Vanilla RNN as updater
'gru' : use GRU as updater
memory : MemoryModule
memory content for update
e_feat_dim : int
dimension of edge feature
temporal_dim : int
length of fourier series for time encoding
Example
----------
Please refer to examples/pytorch/tgn/tgn.py
"""
def __init__(self, updater_type, memory, e_feat_dim, temporal_encoder):
super(MemoryOperation, self).__init__()
updater_dict = {"gru": nn.GRUCell, "rnn": nn.RNNCell}
self.memory = memory
memory_dim = self.memory.hidden_dim
self.temporal_encoder = temporal_encoder
self.message_dim = (
memory_dim
+ memory_dim
+ e_feat_dim
+ self.temporal_encoder.dimension
)
self.updater = updater_dict[updater_type](
input_size=self.message_dim, hidden_size=memory_dim
)
# Here assume g is a subgraph from each iteration
def stick_feat_to_graph(self, g):
# Features are looked up by the subgraph's original node IDs (dgl.NID)
g.ndata["timestamp"] = self.memory.last_update_t[g.ndata[dgl.NID]]
g.ndata["memory"] = self.memory.memory[g.ndata[dgl.NID]]
def msg_fn_cat(self, edges):
src_delta_time = edges.data["timestamp"] - edges.src["timestamp"]
time_encode = self.temporal_encoder(
src_delta_time.unsqueeze(dim=1)
).view(len(edges.data["timestamp"]), -1)
ret = torch.cat(
[
edges.src["memory"],
edges.dst["memory"],
edges.data["feats"],
time_encode,
],
dim=1,
)
return {"message": ret, "timestamp": edges.data["timestamp"]}
def agg_last(self, nodes):
timestamp, latest_idx = torch.max(nodes.mailbox["timestamp"], dim=1)
ret = (
nodes.mailbox["message"]
.gather(
1,
latest_idx.repeat(self.message_dim).view(
-1, 1, self.message_dim
),
)
.view(-1, self.message_dim)
)
return {
"message_bar": ret.reshape(-1, self.message_dim),
"timestamp": timestamp,
}
def update_memory(self, nodes):
# It should pass the feature through RNN
ret = self.updater(
nodes.data["message_bar"].float(), nodes.data["memory"].float()
)
return {"memory": ret}
def forward(self, g):
self.stick_feat_to_graph(g)
g.update_all(self.msg_fn_cat, self.agg_last, self.update_memory)
return g
class EdgeGATConv(nn.Module):
"""Edge graph attention: compute the graph attention from node and edge features,
then aggregate both node and edge features.
Parameters
==========
node_feats : int
number of node features
edge_feats : int
number of edge features
out_feats : int
number of output features
num_heads : int
number of heads in multihead attention
feat_drop : float, optional
drop out rate on the feature
attn_drop : float, optional
drop out rate on the attention weight
negative_slope : float, optional
LeakyReLU angle of negative slope.
residual : bool, optional
whether use residual connection
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``.
"""
def __init__(
self,
node_feats,
edge_feats,
out_feats,
num_heads,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
):
super(EdgeGATConv, self).__init__()
self._num_heads = num_heads
self._node_feats = node_feats
self._edge_feats = edge_feats
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self.fc_node = nn.Linear(
self._node_feats, self._out_feats * self._num_heads
)
self.fc_edge = nn.Linear(
self._edge_feats, self._out_feats * self._num_heads
)
self.attn_l = nn.Parameter(
torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.attn_r = nn.Parameter(
torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.attn_e = nn.Parameter(
torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope)
self.residual = residual
if residual:
if self._node_feats != self._out_feats:
self.res_fc = nn.Linear(
self._node_feats,
self._out_feats * self._num_heads,
bias=False,
)
else:
self.res_fc = Identity()
self.reset_parameters()
self.activation = activation
def reset_parameters(self):
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.fc_node.weight, gain=gain)
nn.init.xavier_normal_(self.fc_edge.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain)
nn.init.xavier_normal_(self.attn_e, gain=gain)
if self.residual and isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def msg_fn(self, edges):
ret = (
edges.data["a"].view(-1, self._num_heads, 1)
* edges.data["el_prime"]
)
return {"m": ret}
def forward(self, graph, nfeat, efeat, get_attention=False):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
nfeat = self.feat_drop(nfeat)
efeat = self.feat_drop(efeat)
node_feat = self.fc_node(nfeat).view(
-1, self._num_heads, self._out_feats
)
edge_feat = self.fc_edge(efeat).view(
-1, self._num_heads, self._out_feats
)
el = (node_feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (node_feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
ee = (edge_feat * self.attn_e).sum(dim=-1).unsqueeze(-1)
graph.ndata["ft"] = node_feat
graph.ndata["el"] = el
graph.ndata["er"] = er
graph.edata["ee"] = ee
graph.apply_edges(fn.u_add_e("el", "ee", "el_prime"))
graph.apply_edges(fn.e_add_v("el_prime", "er", "e"))
e = self.leaky_relu(graph.edata["e"])
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
graph.edata["efeat"] = edge_feat
graph.update_all(self.msg_fn, fn.sum("m", "ft"))
rst = graph.ndata["ft"]
if self.residual:
resval = self.res_fc(nfeat).view(
nfeat.shape[0], -1, self._out_feats
)
rst = rst + resval
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata["a"]
else:
return rst
class TemporalEdgePreprocess(nn.Module):
"""Preprocess layer, which computes the time encoding and concatenates
it to the edge features.
Parameters
==========
edge_feats : int
number of original edge features
temporal_encoder : torch.nn.Module
time encoder model
"""
def __init__(self, edge_feats, temporal_encoder):
super(TemporalEdgePreprocess, self).__init__()
self.edge_feats = edge_feats
self.temporal_encoder = temporal_encoder
def edge_fn(self, edges):
t0 = torch.zeros_like(edges.dst["timestamp"])
time_diff = edges.data["timestamp"] - edges.src["timestamp"]
time_encode = self.temporal_encoder(time_diff.unsqueeze(dim=1)).view(
t0.shape[0], -1
)
edge_feat = torch.cat([edges.data["feats"], time_encode], dim=1)
return {"efeat": edge_feat}
def forward(self, graph):
graph.apply_edges(self.edge_fn)
efeat = graph.edata["efeat"]
return efeat
class TemporalTransformerConv(nn.Module):
def __init__(
self,
edge_feats,
memory_feats,
temporal_encoder,
out_feats,
num_heads,
allow_zero_in_degree=False,
layers=1,
):
"""Temporal Transformer model for TGN and TGAT
Parameters
==========
edge_feats : int
number of edge features
memory_feats : int
dimension of memory vector
temporal_encoder : torch.nn.Module
compute fourier time encoding
out_feats : int
number of out features
num_heads : int
number of attention head
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``.
"""
super(TemporalTransformerConv, self).__init__()
self._edge_feats = edge_feats
self._memory_feats = memory_feats
self.temporal_encoder = temporal_encoder
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self._num_heads = num_heads
self.layers = layers
self.preprocessor = TemporalEdgePreprocess(
self._edge_feats, self.temporal_encoder
)
self.layer_list = nn.ModuleList()
self.layer_list.append(
EdgeGATConv(
node_feats=self._memory_feats,
edge_feats=self._edge_feats + self.temporal_encoder.dimension,
out_feats=self._out_feats,
num_heads=self._num_heads,
feat_drop=0.6,
attn_drop=0.6,
residual=True,
allow_zero_in_degree=allow_zero_in_degree,
)
)
for i in range(self.layers - 1):
self.layer_list.append(
EdgeGATConv(
node_feats=self._out_feats * self._num_heads,
edge_feats=self._edge_feats
+ self.temporal_encoder.dimension,
out_feats=self._out_feats,
num_heads=self._num_heads,
feat_drop=0.6,
attn_drop=0.6,
residual=True,
allow_zero_in_degree=allow_zero_in_degree,
)
)
def forward(self, graph, memory, ts):
graph = graph.local_var()
graph.ndata["timestamp"] = ts
efeat = self.preprocessor(graph).float()
rst = memory
for i in range(self.layers - 1):
rst = self.layer_list[i](graph, rst, efeat).flatten(1)
rst = self.layer_list[-1](graph, rst, efeat).mean(1)
return rst
import copy
import torch.nn as nn
from modules import (
MemoryModule,
MemoryOperation,
MsgLinkPredictor,
TemporalTransformerConv,
TimeEncode,
)
import dgl
class TGN(nn.Module):
def __init__(
self,
edge_feat_dim,
memory_dim,
temporal_dim,
embedding_dim,
num_heads,
num_nodes,
n_neighbors=10,
memory_updater_type="gru",
layers=1,
):
super(TGN, self).__init__()
self.memory_dim = memory_dim
self.edge_feat_dim = edge_feat_dim
self.temporal_dim = temporal_dim
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.n_neighbors = n_neighbors
self.memory_updater_type = memory_updater_type
self.num_nodes = num_nodes
self.layers = layers
self.temporal_encoder = TimeEncode(self.temporal_dim)
self.memory = MemoryModule(self.num_nodes, self.memory_dim)
self.memory_ops = MemoryOperation(
self.memory_updater_type,
self.memory,
self.edge_feat_dim,
self.temporal_encoder,
)
self.embedding_attn = TemporalTransformerConv(
self.edge_feat_dim,
self.memory_dim,
self.temporal_encoder,
self.embedding_dim,
self.num_heads,
layers=self.layers,
allow_zero_in_degree=True,
)
self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)
def embed(self, postive_graph, negative_graph, blocks):
emb_graph = blocks[0]
emb_memory = self.memory.memory[emb_graph.ndata[dgl.NID], :]
emb_t = emb_graph.ndata["timestamp"]
embedding = self.embedding_attn(emb_graph, emb_memory, emb_t)
emb2pred = dict(
zip(emb_graph.ndata[dgl.NID].tolist(), emb_graph.nodes().tolist())
)
# The positive and negative graphs share the same ID mapping
feat_id = [emb2pred[int(n)] for n in postive_graph.ndata[dgl.NID]]
feat = embedding[feat_id]
pred_pos, pred_neg = self.msg_linkpredictor(
feat, postive_graph, negative_graph
)
return pred_pos, pred_neg
def update_memory(self, subg):
new_g = self.memory_ops(subg)
self.memory.set_memory(new_g.ndata[dgl.NID], new_g.ndata["memory"])
self.memory.set_last_update_t(
new_g.ndata[dgl.NID], new_g.ndata["timestamp"]
)
# Some memory operation wrappers
def detach_memory(self):
self.memory.detach_memory()
def reset_memory(self):
self.memory.reset_memory()
def store_memory(self):
memory_checkpoint = {}
memory_checkpoint["memory"] = copy.deepcopy(self.memory.memory)
memory_checkpoint["last_t"] = copy.deepcopy(self.memory.last_update_t)
return memory_checkpoint
def restore_memory(self, memory_checkpoint):
self.memory.memory = memory_checkpoint["memory"]
self.memory.last_update_t = memory_checkpoint["last_t"]
import argparse
import copy
import time
import traceback
import numpy as np
import torch
from data_preprocess import (
TemporalDataset,
TemporalRedditDataset,
TemporalWikipediaDataset,
)
from dataloading import (
FastTemporalEdgeCollator,
FastTemporalSampler,
SimpleTemporalEdgeCollator,
SimpleTemporalSampler,
TemporalEdgeCollator,
TemporalEdgeDataLoader,
TemporalSampler,
)
from sklearn.metrics import average_precision_score, roc_auc_score
from tgn import TGN
import dgl
TRAIN_SPLIT = 0.7
VALID_SPLIT = 0.85
# Set random seed
np.random.seed(2021)
torch.manual_seed(2021)
def train(model, dataloader, sampler, criterion, optimizer, args):
model.train()
total_loss = 0
batch_cnt = 0
last_t = time.time()
for _, positive_pair_g, negative_pair_g, blocks in dataloader:
optimizer.zero_grad()
pred_pos, pred_neg = model.embed(
positive_pair_g, negative_pair_g, blocks
)
loss = criterion(pred_pos, torch.ones_like(pred_pos))
loss += criterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss) * args.batch_size
retain_graph = True if batch_cnt == 0 and not args.fast_mode else False
loss.backward(retain_graph=retain_graph)
optimizer.step()
model.detach_memory()
if not args.not_use_memory:
model.update_memory(positive_pair_g)
if args.fast_mode:
sampler.attach_last_update(model.memory.last_update_t)
print("Batch: ", batch_cnt, "Time: ", time.time() - last_t)
last_t = time.time()
batch_cnt += 1
return total_loss
def test_val(model, dataloader, sampler, criterion, args):
model.eval()
batch_size = args.batch_size
total_loss = 0
aps, aucs = [], []
batch_cnt = 0
with torch.no_grad():
for _, postive_pair_g, negative_pair_g, blocks in dataloader:
pred_pos, pred_neg = model.embed(
postive_pair_g, negative_pair_g, blocks
)
loss = criterion(pred_pos, torch.ones_like(pred_pos))
loss += criterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss) * batch_size
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat(
[torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))],
dim=0,
)
if not args.not_use_memory:
model.update_memory(postive_pair_g)
if args.fast_mode:
sampler.attach_last_update(model.memory.last_update_t)
aps.append(average_precision_score(y_true, y_pred))
aucs.append(roc_auc_score(y_true, y_pred))
batch_cnt += 1
return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--epochs",
type=int,
default=50,
help="epochs for training on entire dataset",
)
parser.add_argument(
"--batch_size", type=int, default=200, help="Size of each batch"
)
parser.add_argument(
"--embedding_dim",
type=int,
default=100,
help="Embedding dim for link prediction",
)
parser.add_argument(
"--memory_dim", type=int, default=100, help="dimension of memory"
)
parser.add_argument(
"--temporal_dim",
type=int,
default=100,
help="Temporal dimension for time encoding",
)
parser.add_argument(
"--memory_updater",
type=str,
default="gru",
help="Recurrent unit for memory update",
)
parser.add_argument(
"--aggregator",
type=str,
default="last",
help="Aggregation method for memory update",
)
parser.add_argument(
"--n_neighbors",
type=int,
default=10,
help="number of neighbors while doing embedding",
)
parser.add_argument(
"--sampling_method",
type=str,
default="topk",
help="How a node aggregates from its neighbors during embedding",
)
parser.add_argument(
"--num_heads",
type=int,
default=8,
help="Number of heads for multihead attention mechanism",
)
parser.add_argument(
"--fast_mode",
action="store_true",
default=False,
help="Fast Mode uses batch temporal sampling, history within same batch cannot be obtained",
)
parser.add_argument(
"--simple_mode",
action="store_true",
default=False,
help="Simple Mode directly deletes the temporal edges from the original static graph",
)
parser.add_argument(
"--num_negative_samples",
type=int,
default=1,
help="Number of negative samples per positive sample",
)
parser.add_argument(
"--dataset",
type=str,
default="wikipedia",
help="dataset selection wikipedia/reddit",
)
parser.add_argument(
"--k_hop", type=int, default=1, help="sampling k-hop neighborhood"
)
parser.add_argument(
"--not_use_memory",
action="store_true",
default=False,
help="Disable the memory module of the TGN model",
)
args = parser.parse_args()
assert not (
args.fast_mode and args.simple_mode
), "you can only choose one sampling mode"
if args.k_hop != 1:
assert args.simple_mode, "this k-hop parameter only support simple mode"
if args.dataset == "wikipedia":
data = TemporalWikipediaDataset()
elif args.dataset == "reddit":
data = TemporalRedditDataset()
else:
print("Warning: using untested dataset: " + args.dataset)
data = TemporalDataset(args.dataset)
# Pre-process data, mask new node in test set from original graph
num_nodes = data.num_nodes()
num_edges = data.num_edges()
trainval_div = int(VALID_SPLIT * num_edges)
# Select new node from test set and remove them from entire graph
test_split_ts = data.edata["timestamp"][trainval_div]
test_nodes = (
torch.cat(
[data.edges()[0][trainval_div:], data.edges()[1][trainval_div:]]
)
.unique()
.numpy()
)
test_new_nodes = np.random.choice(
test_nodes, int(0.1 * len(test_nodes)), replace=False
)
in_subg = dgl.in_subgraph(data, test_new_nodes)
out_subg = dgl.out_subgraph(data, test_new_nodes)
# Remove edges that happen before the test split to prevent the model from learning their connection info
new_node_in_eid_delete = in_subg.edata[dgl.EID][
in_subg.edata["timestamp"] < test_split_ts
]
new_node_out_eid_delete = out_subg.edata[dgl.EID][
out_subg.edata["timestamp"] < test_split_ts
]
new_node_eid_delete = torch.cat(
[new_node_in_eid_delete, new_node_out_eid_delete]
).unique()
graph_new_node = copy.deepcopy(data)
# relative order preserved
graph_new_node.remove_edges(new_node_eid_delete)
# Now for no new node graph, all edge id need to be removed
in_eid_delete = in_subg.edata[dgl.EID]
out_eid_delete = out_subg.edata[dgl.EID]
eid_delete = torch.cat([in_eid_delete, out_eid_delete]).unique()
graph_no_new_node = copy.deepcopy(data)
graph_no_new_node.remove_edges(eid_delete)
# graph_no_new_node and graph_new_node should have same set of nid
# Sampler Initialization
if args.simple_mode:
fan_out = [args.n_neighbors for _ in range(args.k_hop)]
sampler = SimpleTemporalSampler(graph_no_new_node, fan_out)
new_node_sampler = SimpleTemporalSampler(data, fan_out)
edge_collator = SimpleTemporalEdgeCollator
elif args.fast_mode:
sampler = FastTemporalSampler(graph_no_new_node, k=args.n_neighbors)
new_node_sampler = FastTemporalSampler(data, k=args.n_neighbors)
edge_collator = FastTemporalEdgeCollator
else:
sampler = TemporalSampler(k=args.n_neighbors)
edge_collator = TemporalEdgeCollator
neg_sampler = dgl.dataloading.negative_sampler.Uniform(
k=args.num_negative_samples
)
# Set Train, validation, test and new node test id
train_seed = torch.arange(int(TRAIN_SPLIT * graph_no_new_node.num_edges()))
valid_seed = torch.arange(
int(TRAIN_SPLIT * graph_no_new_node.num_edges()),
trainval_div - new_node_eid_delete.size(0),
)
test_seed = torch.arange(
trainval_div - new_node_eid_delete.size(0),
graph_no_new_node.num_edges(),
)
test_new_node_seed = torch.arange(
trainval_div - new_node_eid_delete.size(0), graph_new_node.num_edges()
)
g_sampling = (
None
if args.fast_mode
else dgl.add_reverse_edges(graph_no_new_node, copy_edata=True)
)
new_node_g_sampling = (
None
if args.fast_mode
else dgl.add_reverse_edges(graph_new_node, copy_edata=True)
)
if not args.fast_mode:
new_node_g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes()
g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes()
# We highly recommend always setting num_workers=0; otherwise the sampled subgraph may not be correct.
train_dataloader = TemporalEdgeDataLoader(
graph_no_new_node,
train_seed,
sampler,
batch_size=args.batch_size,
negative_sampler=neg_sampler,
shuffle=False,
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=g_sampling,
)
valid_dataloader = TemporalEdgeDataLoader(
graph_no_new_node,
valid_seed,
sampler,
batch_size=args.batch_size,
negative_sampler=neg_sampler,
shuffle=False,
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=g_sampling,
)
test_dataloader = TemporalEdgeDataLoader(
graph_no_new_node,
test_seed,
sampler,
batch_size=args.batch_size,
negative_sampler=neg_sampler,
shuffle=False,
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=g_sampling,
)
test_new_node_dataloader = TemporalEdgeDataLoader(
graph_new_node,
test_new_node_seed,
new_node_sampler if args.fast_mode else sampler,
batch_size=args.batch_size,
negative_sampler=neg_sampler,
shuffle=False,
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=new_node_g_sampling,
)
edge_dim = data.edata["feats"].shape[1]
num_node = data.num_nodes()
model = TGN(
edge_feat_dim=edge_dim,
memory_dim=args.memory_dim,
temporal_dim=args.temporal_dim,
embedding_dim=args.embedding_dim,
num_heads=args.num_heads,
num_nodes=num_node,
n_neighbors=args.n_neighbors,
memory_updater_type=args.memory_updater,
layers=args.k_hop,
)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Implement Logging mechanism
f = open("logging.txt", "w")
if args.fast_mode:
sampler.reset()
try:
for i in range(args.epochs):
train_loss = train(
model, train_dataloader, sampler, criterion, optimizer, args
)
val_ap, val_auc = test_val(
model, valid_dataloader, sampler, criterion, args
)
memory_checkpoint = model.store_memory()
if args.fast_mode:
new_node_sampler.sync(sampler)
test_ap, test_auc = test_val(
model, test_dataloader, sampler, criterion, args
)
model.restore_memory(memory_checkpoint)
if args.fast_mode:
sample_nn = new_node_sampler
else:
sample_nn = sampler
nn_test_ap, nn_test_auc = test_val(
model, test_new_node_dataloader, sample_nn, criterion, args
)
log_content = []
log_content.append(
"Epoch: {}; Training Loss: {} | Validation AP: {:.3f} AUC: {:.3f}\n".format(
i, train_loss, val_ap, val_auc
)
)
log_content.append(
"Epoch: {}; Test AP: {:.3f} AUC: {:.3f}\n".format(
i, test_ap, test_auc
)
)
log_content.append(
"Epoch: {}; Test New Node AP: {:.3f} AUC: {:.3f}\n".format(
i, nn_test_ap, nn_test_auc
)
)
f.writelines(log_content)
model.reset_memory()
if i < args.epochs - 1 and args.fast_mode:
sampler.reset()
print(log_content[0], log_content[1], log_content[2])
except KeyboardInterrupt:
traceback.print_exc()
error_content = "Training Interrupted!"
f.writelines(error_content)
f.close()
print("========Training is Done========")
@@ -9,6 +9,8 @@ and transforming graphs.
# This initializes Winsock and performs cleanup at termination as required
import socket
from distutils.version import LooseVersion
# setup logging before everything
from .logging import enable_verbose_logging
@@ -25,7 +27,6 @@ from . import storages
from . import dataloading
from . import ops
from . import cuda
from . import _dataloading  # legacy dataloading modules
from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
...
"""The ``dgl.dataloading`` package contains:
* Data loader classes for iterating over a set of nodes or edges in a graph and generating
  their computation dependencies via neighborhood sampling methods.
* Various sampler classes that perform neighborhood sampling for multi-layer GNNs.
* Negative samplers for link prediction.
For a holistic explanation of how the different components work together,
read the user guide :ref:`guide-minibatch`.
.. note::
This package is experimental and the interfaces may be subject
to changes in future releases. It currently only has implementations in PyTorch.
"""
from .neighbor import *
from .dataloader import *
from .cluster_gcn import *
from .shadow import *
from . import negative_sampler
from .. import backend as F
if F.get_preferred_backend() == 'pytorch':
from .pytorch import *
"""Cluster-GCN subgraph iterators."""
import os
import pickle
import numpy as np
from ..transforms import metis_partition_assignment
from .. import backend as F
from .dataloader import SubgraphIterator
class ClusterGCNSubgraphIterator(SubgraphIterator):
"""Subgraph sampler following that of ClusterGCN.
This sampler first partitions the graph with METIS partitioning, then it caches the nodes of
each partition to a file within the given cache directory.
This is used in conjunction with :class:`dgl.dataloading.pytorch.GraphDataLoader`.
Notes
-----
The graph must be homogeneous and on CPU.
Parameters
----------
g : DGLGraph
The original graph.
num_partitions : int
The number of partitions.
cache_directory : str
The path to the cache directory for storing the partition result.
refresh : bool
If True, recompute the partition.
Examples
--------
Assuming that you have a graph ``g``:
>>> sgiter = dgl.dataloading.ClusterGCNSubgraphIterator(
... g, num_partitions=100, cache_directory='.', refresh=True)
>>> dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=0)
>>> for subgraph_batch in dataloader:
... train_on(subgraph_batch)
"""
def __init__(self, g, num_partitions, cache_directory, refresh=False):
if os.name == 'nt':
raise NotImplementedError("METIS partitioning is not supported on Windows yet.")
super().__init__(g)
# First see if the cache is already there. If so, directly read from cache.
if not refresh and self._load_parts(cache_directory):
return
# Otherwise, build the cache.
assignment = F.asnumpy(metis_partition_assignment(g, num_partitions))
self._save_parts(assignment, cache_directory)
def _cache_file_path(self, cache_directory):
return os.path.join(cache_directory, 'cluster_gcn_cache')
def _load_parts(self, cache_directory):
path = self._cache_file_path(cache_directory)
if not os.path.exists(path):
return False
with open(path, 'rb') as file_:
self.part_indptr, self.part_indices = pickle.load(file_)
return True
def _save_parts(self, assignment, cache_directory):
os.makedirs(cache_directory, exist_ok=True)
self.part_indices = np.argsort(assignment)
num_nodes_per_part = np.bincount(assignment)
self.part_indptr = np.insert(np.cumsum(num_nodes_per_part), 0, 0)
with open(self._cache_file_path(cache_directory), 'wb') as file_:
pickle.dump((self.part_indptr, self.part_indices), file_)
def __len__(self):
return self.part_indptr.shape[0] - 1
def __getitem__(self, i):
nodes = self.part_indices[self.part_indptr[i]:self.part_indptr[i+1]]
return self.g.subgraph(nodes)
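`_save_parts` stores the METIS assignment in a CSR-like `(indptr, indices)` layout via `np.argsort`/`np.bincount`/`np.cumsum`, and `__getitem__` slices it back out. The same bookkeeping can be sketched without NumPy; `build_parts` and `part_nodes` below are illustrative helpers, not the class's actual code path:

```python
# Group node IDs by partition assignment into a CSR-like
# (indptr, indices) pair, mirroring what _save_parts computes.
def build_parts(assignment, num_parts):
    counts = [0] * num_parts
    for p in assignment:
        counts[p] += 1
    indptr = [0]
    for c in counts:
        indptr.append(indptr[-1] + c)
    # Stable argsort by partition id groups each partition's nodes together.
    indices = sorted(range(len(assignment)), key=lambda n: assignment[n])
    return indptr, indices

def part_nodes(indptr, indices, i):
    # Nodes of partition i, as __getitem__ slices them.
    return indices[indptr[i]:indptr[i + 1]]
```

Partition `i` then induces the subgraph over `part_nodes(indptr, indices, i)`.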
"""Negative samplers"""
from collections.abc import Mapping
from .. import backend as F
from ..sampling import global_uniform_negative_sampling
class _BaseNegativeSampler(object):
def _generate(self, g, eids, canonical_etype):
raise NotImplementedError
def __call__(self, g, eids):
"""Returns negative samples.
Parameters
----------
g : DGLGraph
The graph.
eids : Tensor or dict[etype, Tensor]
The sampled edges in the minibatch.
Returns
-------
tuple[Tensor, Tensor] or dict[etype, tuple[Tensor, Tensor]]
The returned source-destination pairs as negative samples.
"""
if isinstance(eids, Mapping):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
neg_pair = {k: self._generate(g, v, k) for k, v in eids.items()}
else:
assert len(g.etypes) == 1, \
'please specify a dict of etypes and ids for graphs with multiple edge types'
neg_pair = self._generate(g, eids, g.canonical_etypes[0])
return neg_pair
class PerSourceUniform(_BaseNegativeSampler):
"""Negative sampler that randomly chooses negative destination nodes
for each source node according to a uniform distribution.
For each edge ``(u, v)`` of type ``(srctype, etype, dsttype)``, DGL generates
:attr:`k` pairs of negative edges ``(u, v')``, where ``v'`` is chosen
uniformly from all the nodes of type ``dsttype``. The resulting edges will
also have type ``(srctype, etype, dsttype)``.
Parameters
----------
k : int
The number of negative samples per edge.
Examples
--------
>>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))
>>> neg_sampler = dgl.dataloading.negative_sampler.PerSourceUniform(2)
>>> neg_sampler(g, torch.tensor([0, 1]))
(tensor([0, 0, 1, 1]), tensor([1, 0, 2, 3]))
"""
def __init__(self, k):
self.k = k
def _generate(self, g, eids, canonical_etype):
_, _, vtype = canonical_etype
shape = F.shape(eids)
dtype = F.dtype(eids)
ctx = F.context(eids)
shape = (shape[0] * self.k,)
src, _ = g.find_edges(eids, etype=canonical_etype)
src = F.repeat(src, self.k, 0)
dst = F.randint(shape, dtype, ctx, 0, g.number_of_nodes(vtype))
return src, dst
# Alias
Uniform = PerSourceUniform
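`_generate` repeats each positive source `k` times (`F.repeat`) and draws `k` destinations uniformly (`F.randint`). The same idea in framework-free Python with the `random` stdlib; `per_source_uniform` is a hypothetical sketch, not DGL's implementation:

```python
import random

# Per-source uniform negatives: each positive source u is repeated k
# times and paired with k uniformly drawn destination node IDs.
def per_source_uniform(src_nodes, k, num_dst_nodes, rng=random):
    neg_src = [u for u in src_nodes for _ in range(k)]
    neg_dst = [rng.randrange(num_dst_nodes) for _ in neg_src]
    return neg_src, neg_dst
```

Note that, as with `PerSourceUniform`, nothing prevents a drawn pair from coinciding with an existing edge; `GlobalUniform` below is the variant that excludes them.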
class GlobalUniform(_BaseNegativeSampler):
"""Negative sampler that randomly chooses negative source-destination pairs according
to a uniform distribution.
For each edge ``(u, v)`` of type ``(srctype, etype, dsttype)``, DGL generates at most
:attr:`k` pairs of negative edges ``(u', v')``, where ``u'`` is chosen uniformly from
all the nodes of type ``srctype`` and ``v'`` is chosen uniformly from all the nodes
of type ``dsttype``. The resulting edges will also have type
``(srctype, etype, dsttype)``. DGL guarantees that the sampled pairs will not have
edges in between.
Parameters
----------
k : int
The desired number of negative samples to generate per edge.
exclude_self_loops : bool, optional
Whether to exclude self-loops from negative samples. (Default: True)
replace : bool, optional
Whether to sample with replacement. Setting it to True will make things
faster. (Default: True)
redundancy : float, optional
Indicates how much more negative samples to actually generate during rejection sampling
before finding the unique pairs.
Increasing it will increase the likelihood of getting :attr:`k` negative samples
per edge, but will also take more time and memory.
(Default: automatically determined by the density of graph)
Notes
-----
This negative sampler will try to generate as many negative samples as possible, but
it may rarely return fewer than :attr:`k` negative samples per edge.
This is more likely to happen if a graph is so small or dense that not many unique
negative samples exist.
Examples
--------
>>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))
>>> neg_sampler = dgl.dataloading.negative_sampler.GlobalUniform(2, True)
>>> neg_sampler(g, torch.LongTensor([0, 1]))
(tensor([0, 1, 3, 2]), tensor([2, 0, 2, 1]))
"""
def __init__(self, k, exclude_self_loops=True, replace=False, redundancy=None):
self.k = k
self.exclude_self_loops = exclude_self_loops
self.replace = replace
self.redundancy = redundancy
def _generate(self, g, eids, canonical_etype):
return global_uniform_negative_sampling(
g, len(eids) * self.k, self.exclude_self_loops, self.replace,
canonical_etype, self.redundancy)
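`GlobalUniform` delegates to `global_uniform_negative_sampling`, which uses rejection sampling to guarantee the returned pairs are not existing edges. A hedged pure-Python sketch of that idea (not DGL's actual C++ implementation; the function name and bounded-retry loop are assumptions for illustration):

```python
import random

# Global uniform negative sampling with rejection: draw (u', v') pairs
# uniformly, discard existing edges and (optionally) self-loops, and
# stop once enough unique pairs are found or attempts run out.
def global_uniform_negatives(edges, num_src, num_dst, num_samples,
                             exclude_self_loops=True, max_tries=1000,
                             rng=random):
    existing = set(edges)
    found = set()
    for _ in range(max_tries):
        if len(found) >= num_samples:
            break
        u = rng.randrange(num_src)
        v = rng.randrange(num_dst)
        if (u, v) in existing:
            continue
        if exclude_self_loops and u == v:
            continue
        found.add((u, v))
    # May return fewer than num_samples pairs, as the docstring notes.
    return sorted(found)
```

The bounded retry budget is why the docstring warns that small or dense graphs may yield fewer than `k` negatives per edge.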
"""Data loading components for neighbor sampling"""
from .dataloader import BlockSampler
from .. import sampling, distributed
from .. import ndarray as nd
from .. import backend as F
from ..base import DGLError, ETYPE
class NeighborSamplingMixin(object):
"""Mixin object containing common optimizing routines that caches fanout and probability
arrays.
The mixin requires the object to have the following attributes:
- :attr:`prob`: The edge feature name that stores the (unnormalized) probability.
- :attr:`fanouts`: The list of fanouts (either an integer or a dictionary of edge
types and integers).
The mixin will generate the following attributes:
- :attr:`prob_arrays`: List of DGL NDArrays containing the unnormalized probabilities
for every edge type.
- :attr:`fanout_arrays`: List of DGL NDArrays containing the fanouts for every edge
type at every layer.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) # forward to base classes
self.fanout_arrays = []
self.prob_arrays = None
def _build_prob_arrays(self, g):
if self.prob is not None:
self.prob_arrays = [F.to_dgl_nd(g.edges[etype].data[self.prob]) for etype in g.etypes]
elif self.prob_arrays is None:
# build prob_arrays only once
self.prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes)
def _build_fanout(self, block_id, g):
assert self.fanouts is not None, \
"_build_fanout() should only be called when fanouts is not None"
# build fanout_arrays only once for each layer
while block_id >= len(self.fanout_arrays):
for i in range(len(self.fanouts)):
fanout = self.fanouts[i]
if not isinstance(fanout, dict):
fanout_array = [int(fanout)] * len(g.etypes)
else:
if len(fanout) != len(g.etypes):
raise DGLError('Fan-out must be specified for each edge type '
'if a dict is provided.')
fanout_array = [None] * len(g.etypes)
for etype, value in fanout.items():
fanout_array[g.get_etype_id(etype)] = value
self.fanout_arrays.append(
F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64)))
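The core of `_build_fanout` is normalizing a per-layer fanout spec (a single int, or a dict keyed by edge type) into one integer per edge type before packing it into a DGL NDArray. That normalization in isolation; `normalize_fanout` is a hypothetical standalone helper and `etypes` an assumed ordered list of edge type names:

```python
# Normalize a per-layer fanout into one integer per edge type,
# as _build_fanout does before converting to a DGL NDArray.
def normalize_fanout(fanout, etypes):
    if not isinstance(fanout, dict):
        # A single integer applies to every edge type.
        return [int(fanout)] * len(etypes)
    if len(fanout) != len(etypes):
        raise ValueError('Fan-out must be specified for each edge type '
                         'if a dict is provided.')
    # Order the values by the graph's edge type ordering.
    return [fanout[etype] for etype in etypes]
```

The caching in `_build_fanout` exists so this conversion and the NDArray allocation happen once per layer rather than on every minibatch.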
class MultiLayerNeighborSampler(NeighborSamplingMixin, BlockSampler):
"""Sampler that builds computational dependency of node representations via
neighbor sampling for multilayer GNN.
This sampler will make every node gather messages from a fixed number of neighbors
per edge type. The neighbors are picked uniformly.
Parameters
----------
fanouts : list[int] or list[dict[etype, int]]
List of neighbors to sample per edge type for each GNN layer, with the i-th
element being the fanout for the i-th GNN layer.
If only a single integer is provided, DGL assumes that every edge type
will have the same fanout.
If -1 is provided for one edge type on one layer, then all inbound edges
of that edge type will be included.
replace : bool, default False
Whether to sample with replacement
return_eids : bool, default False
Whether to return the edge IDs involved in message passing in the MFG.
If True, the edge IDs will be stored as an edge feature named ``dgl.EID``.
prob : str, optional
If given, the probability of each neighbor being sampled is proportional
to the edge feature value with the given name in ``g.edata``. The feature must be
a scalar on each edge.
Examples
--------
To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for
the first, second, and third layer respectively (assuming the backend is PyTorch):
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10, 15])
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, train_nid, sampler,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
... train_on(blocks)
If training on a heterogeneous graph and you want different number of neighbors for each
edge type, one should instead provide a list of dicts. Each dict would specify the
number of neighbors to pick per edge type.
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([
... {('user', 'follows', 'user'): 5,
... ('user', 'plays', 'game'): 4,
... ('game', 'played-by', 'user'): 3}] * 3)
If you would like non-uniform neighbor sampling:
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10, 15], prob='p')
Notes
-----
For the concept of MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, fanouts, replace=False, return_eids=False, prob=None):
super().__init__(len(fanouts), return_eids)
self.fanouts = fanouts
self.replace = replace
# used to cache computations and memory allocations
# list[dgl.nd.NDArray]; each array stores the fan-outs of all edge types
self.prob = prob
@classmethod
def exclude_edges_in_frontier(cls, g):
return not isinstance(g, distributed.DistGraph) and g.device == F.cpu() \
and not g.is_pinned()
def sample_frontier(self, block_id, g, seed_nodes, exclude_eids=None):
fanout = self.fanouts[block_id]
if isinstance(g, distributed.DistGraph):
if len(g.etypes) > 1: # heterogeneous distributed graph
frontier = distributed.sample_etype_neighbors(
g, seed_nodes, ETYPE, fanout, replace=self.replace)
else:
frontier = distributed.sample_neighbors(
g, seed_nodes, fanout, replace=self.replace)
else:
self._build_fanout(block_id, g)
self._build_prob_arrays(g)
frontier = sampling.sample_neighbors(
g, seed_nodes, self.fanout_arrays[block_id],
replace=self.replace, prob=self.prob_arrays, exclude_edges=exclude_eids)
return frontier
class MultiLayerFullNeighborSampler(MultiLayerNeighborSampler):
"""Sampler that builds computational dependency of node representations by taking messages
from all neighbors for multilayer GNN.
This sampler will make every node gather messages from every single neighbor per edge type.
Parameters
----------
n_layers : int
The number of GNN layers to sample.
return_eids : bool, default False
Whether to return the edge IDs involved in message passing in the MFG.
If True, the edge IDs will be stored as an edge feature named ``dgl.EID``.
Examples
--------
To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
a homogeneous graph where each node takes messages from all neighbors for the first,
second, and third layer respectively (assuming the backend is PyTorch):
>>> sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, train_nid, sampler,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
... train_on(blocks)
Notes
-----
For the concept of MFGs, please refer to
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, n_layers, return_eids=False):
super().__init__([-1] * n_layers, return_eids=return_eids)
@classmethod
def exclude_edges_in_frontier(cls, g):
return False
"""DGL PyTorch DataLoader module."""
from .dataloader import *
"""ShaDow-GNN subgraph samplers."""
from ..utils import prepare_tensor_or_dict
from ..base import NID
from .. import transforms
from ..sampling import sample_neighbors
from .neighbor import NeighborSamplingMixin
from .dataloader import exclude_edges, Sampler
class ShaDowKHopSampler(NeighborSamplingMixin, Sampler):
"""K-hop subgraph sampler used by
`ShaDow-GNN <https://arxiv.org/abs/2012.01380>`__.
It performs node-wise neighbor sampling but instead of returning a list of
MFGs, it returns a single subgraph induced by all the sampled nodes. The
seed nodes from which the neighbors are sampled appear first among the
induced nodes of the subgraph.
This is used in conjunction with :class:`dgl.dataloading.pytorch.NodeDataLoader`
and :class:`dgl.dataloading.pytorch.EdgeDataLoader`.
Parameters
----------
fanouts : list[int] or list[dict[etype, int]]
List of neighbors to sample per edge type for each GNN layer, with the i-th
element being the fanout for the i-th GNN layer.
If only a single integer is provided, DGL assumes that every edge type
will have the same fanout.
If -1 is provided for one edge type on one layer, then all inbound edges
of that edge type will be included.
replace : bool, default False
Whether to sample with replacement.
prob : str, optional
If given, the probability of each neighbor being sampled is proportional
to the edge feature value with the given name in ``g.edata``. The feature must be
a scalar on each edge.
Examples
--------
To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for
the first, second, and third layer respectively (assuming the backend is PyTorch):
>>> g = dgl.data.CoraFullDataset()[0]
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
>>> dataloader = dgl.dataloading.NodeDataLoader(
... g, torch.arange(g.num_nodes()), sampler,
... batch_size=5, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, (subgraph,) in dataloader:
... print(subgraph)
... assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
... assert torch.equal(input_nodes[:output_nodes.shape[0]], output_nodes)
... break
Graph(num_nodes=529, num_edges=3796,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64),
'feat': Scheme(shape=(8710,), dtype=torch.float32),
'_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
If training on a heterogeneous graph and you want different number of neighbors for each
edge type, one should instead provide a list of dicts. Each dict would specify the
number of neighbors to pick per edge type.
>>> sampler = dgl.dataloading.ShaDowKHopSampler([
... {('user', 'follows', 'user'): 5,
... ('user', 'plays', 'game'): 4,
... ('game', 'played-by', 'user'): 3}] * 3)
If you would like non-uniform neighbor sampling:
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')
"""
def __init__(self, fanouts, replace=False, prob=None, output_ctx=None):
super().__init__(output_ctx)
self.fanouts = fanouts
self.replace = replace
self.prob = prob
self.set_output_context(output_ctx)
def sample(self, g, seed_nodes, exclude_eids=None):
self._build_fanout(len(self.fanouts), g)
self._build_prob_arrays(g)
seed_nodes = prepare_tensor_or_dict(g, seed_nodes, 'seed nodes')
output_nodes = seed_nodes
for i in range(len(self.fanouts)):
fanout = self.fanouts[i]
frontier = sample_neighbors(
g, seed_nodes, fanout, replace=self.replace, prob=self.prob_arrays)
block = transforms.to_block(frontier, seed_nodes)
seed_nodes = block.srcdata[NID]
subg = g.subgraph(seed_nodes, relabel_nodes=True)
subg = exclude_edges(subg, exclude_eids, self.output_device)
return seed_nodes, output_nodes, [subg]
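The `sample` loop above expands the seed set once per fanout and then induces a single subgraph over everything gathered, with the original seeds first. A framework-free sketch of that node-set expansion over an adjacency dict; `shadow_khop_nodes` and the toy graph are illustrative assumptions, not the sampler's code:

```python
import random

# ShaDow-style k-hop node set: per layer, sample up to `fanout`
# neighbors of every node gathered so far; the induced subgraph over
# the final node list is what the sampler would return.
def shadow_khop_nodes(neighbors, seeds, fanouts, rng=random):
    nodes = list(seeds)          # seeds come first in the induced subgraph
    seen = set(seeds)
    for fanout in fanouts:
        # Snapshot so this layer only expands nodes known at its start.
        for u in list(nodes):
            nbrs = neighbors.get(u, [])
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            for v in picked:
                if v not in seen:
                    seen.add(v)
                    nodes.append(v)
    return nodes
```

Unlike MFG-based samplers, all layers of the GNN then run on this one subgraph, which is the point of the ShaDow-GNN design.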