Unverified commit bdbc0178 authored by Muhammed Fatih BALIN, committed by GitHub

[Example] Labor sampling (#4718)


Co-authored-by: Hongzhi (Steve) Chen <chenhongzhi.nkcs@gmail.com>
parent d7410cf4
@@ -5,7 +5,10 @@ The folder contains example implementations of selected research papers related
* For examples working with a certain release, check out `https://github.com/dmlc/dgl/tree/<release_version>/examples` (E.g., https://github.com/dmlc/dgl/tree/0.5.x/examples)
To quickly locate the examples of your interest, search for the tagged keywords or use the search tool on [dgl.ai](https://www.dgl.ai/).
## 2022
- <a name="labor"></a> Balin et al. Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs. [Paper link](https://arxiv.org/abs/2210.13339)
- Example code: [PyTorch](../examples/labor/train_lightning.py)
- Tags: node classification, weighted graphs, sampling
## 2021
- <a name="rnaglib"></a> Mallet et al. Learning Protein and Small Molecule binding sites in RNA molecules with 2.5D graphs. [Paper link](https://academic.oup.com/bioinformatics/article/38/5/1458/6462185?login=true)
- Example code: [PyTorch](https://jwgitlab.cs.mcgill.ca/cgoliver/rnaglib)
# Based on the following implementation: https://github.com/BarclayII/dgl/blob/ladies/examples/pytorch/ladies/ladies2.py
import dgl
import dgl.function as fn
import torch
def find_indices_in(a, b):
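    """Return, for each element of @a, its position in @b.

    Only meaningful for elements of @a that actually occur in @b; for the
    rest an arbitrary valid index comes back (generate_block() relies on
    this to filter edges). For example:

        >>> find_indices_in(torch.tensor([5, 3]), torch.tensor([3, 4, 5]))
        tensor([2, 0])
    """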
b_sorted, indices = torch.sort(b)
sorted_indices = torch.searchsorted(b_sorted, a)
sorted_indices[sorted_indices >= indices.shape[0]] = 0
return indices[sorted_indices]
def union(*arrays):
return torch.unique(torch.cat(arrays))
def normalized_edata(g, weight=None):
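    """Return row-normalized edge weights: each edge's weight divided by the
    total weight entering its destination node. For an unweighted graph
    (weight=None) this is 1 / in-degree of the destination.
    """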
with g.local_scope():
if weight is None:
weight = "W"
g.edata[weight] = torch.ones(g.number_of_edges(), device=g.device)
g.update_all(fn.copy_e(weight, weight), fn.sum(weight, "v"))
        # row-normalize: w_e = a_e / (total incoming weight at the destination)
        g.apply_edges(lambda edges: {"w": edges.data[weight] / edges.dst["v"]})
return g.edata["w"]
class LadiesSampler(dgl.dataloading.BlockSampler):
def __init__(
self,
nodes_per_layer,
importance_sampling=True,
weight="w",
out_weight="edge_weights",
replace=False,
):
super().__init__()
self.nodes_per_layer = nodes_per_layer
self.importance_sampling = importance_sampling
self.edge_weight = weight
self.output_weight = out_weight
self.replace = replace
def compute_prob(self, g, seed_nodes, weight, num):
"""
g : the whole graph
seed_nodes : the output nodes for the current layer
weight : the weight of the edges
return : the unnormalized probability of the candidate nodes, as well as the subgraph
containing all the edges from the candidate nodes to the output nodes.
"""
insg = dgl.in_subgraph(g, seed_nodes)
insg = dgl.compact_graphs(insg, seed_nodes)
if self.importance_sampling:
out_frontier = dgl.reverse(insg, copy_edata=True)
weight = weight[out_frontier.edata[dgl.EID].long()]
prob = dgl.ops.copy_e_sum(out_frontier, weight**2)
# prob = torch.sqrt(prob)
else:
            prob = torch.ones(insg.num_nodes(), device=insg.device)
prob[insg.out_degrees() == 0] = 0
return prob, insg
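    # The LADIES probability of a candidate node v is the aggregated square of
    # its row-normalized edge weights into the seed set:
    #   p(v) ∝ sum over edges (v -> u), u in seed_nodes, of w(v, u)^2,
    # which is what copy_e_sum() over the reversed frontier computes above.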
def select_neighbors(self, prob, num):
"""
seed_nodes : output nodes
cand_nodes : candidate nodes. Must contain all output nodes in @seed_nodes
prob : unnormalized probability of each candidate node
num : number of neighbors to sample
return : the set of input nodes in terms of their indices in @cand_nodes, and also the indices of
seed nodes in the selected nodes.
"""
# The returned nodes should be a union of seed_nodes plus @num nodes from cand_nodes.
# Because compute_prob returns a compacted subgraph and a list of probabilities,
# we need to find the corresponding local IDs of the resulting union in the subgraph
# so that we can compute the edge weights of the block.
# This is why we need a find_indices_in() function.
neighbor_nodes_idx = torch.multinomial(
prob, min(num, prob.shape[0]), replacement=self.replace
)
return neighbor_nodes_idx
def generate_block(self, insg, neighbor_nodes_idx, seed_nodes, P_sg, W_sg):
"""
insg : the subgraph yielded by compute_prob()
neighbor_nodes_idx : the sampled nodes from the subgraph @insg, yielded by select_neighbors()
        seed_nodes : the global node IDs of the seed (output) nodes for the current layer
P_sg : unnormalized probability of each node being sampled, yielded by compute_prob()
W_sg : edge weights of @insg
return : the block.
"""
seed_nodes_idx = find_indices_in(seed_nodes, insg.ndata[dgl.NID])
u_nodes = union(neighbor_nodes_idx, seed_nodes_idx)
sg = insg.subgraph(u_nodes.type(insg.idtype))
u, v = sg.edges()
lu = sg.ndata[dgl.NID][u.long()]
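        # Keep only the edges whose source node was actually sampled. For a
        # source that is not in @neighbor_nodes_idx, find_indices_in() returns
        # an arbitrary position, so the equality test below filters it out.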
s = find_indices_in(lu, neighbor_nodes_idx)
eg = dgl.edge_subgraph(
sg, lu == neighbor_nodes_idx[s], relabel_nodes=False
)
eg.ndata[dgl.NID] = sg.ndata[dgl.NID][: eg.num_nodes()]
eg.edata[dgl.EID] = sg.edata[dgl.EID][eg.edata[dgl.EID].long()]
sg = eg
nids = insg.ndata[dgl.NID][sg.ndata[dgl.NID].long()]
P = P_sg[u_nodes.long()]
W = W_sg[sg.edata[dgl.EID].long()]
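        # Importance-sampling debias: divide each edge weight by the sampling
        # weight of its source node, then rescale per destination so that the
        # weights entering each node sum to its in-degree.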
W_tilde = dgl.ops.e_div_u(sg, W, P)
W_tilde_sum = dgl.ops.copy_e_sum(sg, W_tilde)
d = sg.in_degrees()
W_tilde = dgl.ops.e_mul_v(sg, W_tilde, d / W_tilde_sum)
        # @u_nodes is sorted, so a node's position in @u_nodes is its ID in
        # @sg; to_block() expects destination nodes in @sg's own numbering.
        seed_nodes_sg_idx = torch.searchsorted(u_nodes, seed_nodes_idx.type(u_nodes.dtype))
        block = dgl.to_block(sg, seed_nodes_sg_idx.type(sg.idtype))
block.edata[self.output_weight] = W_tilde
# correct node ID mapping
block.srcdata[dgl.NID] = nids[block.srcdata[dgl.NID].long()]
block.dstdata[dgl.NID] = nids[block.dstdata[dgl.NID].long()]
sg_eids = insg.edata[dgl.EID][sg.edata[dgl.EID].long()]
block.edata[dgl.EID] = sg_eids[block.edata[dgl.EID].long()]
return block
def sample_blocks(self, g, seed_nodes, exclude_eids=None):
output_nodes = seed_nodes
blocks = []
for block_id in reversed(range(len(self.nodes_per_layer))):
num_nodes_to_sample = self.nodes_per_layer[block_id]
W = g.edata[self.edge_weight]
prob, insg = self.compute_prob(
g, seed_nodes, W, num_nodes_to_sample
)
neighbor_nodes_idx = self.select_neighbors(
prob, num_nodes_to_sample
)
block = self.generate_block(
insg,
neighbor_nodes_idx.type(g.idtype),
seed_nodes.type(g.idtype),
prob,
W[insg.edata[dgl.EID].long()],
)
seed_nodes = block.srcdata[dgl.NID]
blocks.insert(0, block)
return seed_nodes, output_nodes, blocks
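# A minimal usage sketch (illustrative names; assumes DGL >= 0.8, where a
# custom BlockSampler can be passed straight to dgl.dataloading.DataLoader):
#
#   g.edata["w"] = normalized_edata(g)  # the samplers read edata["w"] by default
#   sampler = LadiesSampler(nodes_per_layer=[512, 512])
#   loader = dgl.dataloading.DataLoader(
#       g, train_nids, sampler, batch_size=1024, shuffle=True
#   )
#   for input_nodes, output_nodes, blocks in loader:
#       ...  # each block carries edata["edge_weights"] for debiased aggregation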
class PoissonLadiesSampler(LadiesSampler):
def __init__(
self,
nodes_per_layer,
importance_sampling=True,
weight="w",
out_weight="edge_weights",
skip=False,
):
super().__init__(
nodes_per_layer, importance_sampling, weight, out_weight
)
self.eps = 0.9999
self.skip = skip
def compute_prob(self, g, seed_nodes, weight, num):
"""
g : the whole graph
seed_nodes : the output nodes for the current layer
weight : the weight of the edges
return : the unnormalized probability of the candidate nodes, as well as the subgraph
containing all the edges from the candidate nodes to the output nodes.
"""
prob, insg = super().compute_prob(g, seed_nodes, weight, num)
one = torch.ones_like(prob)
if prob.shape[0] <= num:
return one, insg
c = 1.0
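        # Find a scale c such that the expected sample size, sum(min(c * prob, 1)),
        # is within a factor of @eps of @num; a few multiplicative corrections
        # converge quickly, and 50 iterations is a safe cap.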
for i in range(50):
S = torch.sum(torch.minimum(prob * c, one).to(torch.float64)).item()
if min(S, num) / max(S, num) >= self.eps:
break
else:
c *= num / S
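        # With @skip enabled, seed nodes get infinite weight so that
        # min(prob * c, 1) below pins their inclusion probability to 1.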
if self.skip:
skip_nodes = find_indices_in(seed_nodes, insg.ndata[dgl.NID])
prob[skip_nodes] = float("inf")
return torch.minimum(prob * c, one), insg
def select_neighbors(self, prob, num):
"""
seed_nodes : output nodes
cand_nodes : candidate nodes. Must contain all output nodes in @seed_nodes
prob : unnormalized probability of each candidate node
num : number of neighbors to sample
return : the set of input nodes in terms of their indices in @cand_nodes, and also the indices of
seed nodes in the selected nodes.
"""
# The returned nodes should be a union of seed_nodes plus @num nodes from cand_nodes.
# Because compute_prob returns a compacted subgraph and a list of probabilities,
# we need to find the corresponding local IDs of the resulting union in the subgraph
# so that we can compute the edge weights of the block.
# This is why we need a find_indices_in() function.
neighbor_nodes_idx = torch.arange(prob.shape[0], device=prob.device)[
torch.bernoulli(prob) == 1
]
return neighbor_nodes_idx
import dgl
import torch as th
def load_data(data):
g = data[0]
g.ndata["features"] = g.ndata.pop("feat")
g.ndata["labels"] = g.ndata.pop("label")
return g, data.num_classes
def load_dgl(name):
from dgl.data import (
CiteseerGraphDataset,
CoraGraphDataset,
FlickrDataset,
PubmedGraphDataset,
RedditDataset,
YelpDataset,
)
d = {
"cora": CoraGraphDataset,
"citeseer": CiteseerGraphDataset,
"pubmed": PubmedGraphDataset,
"reddit": RedditDataset,
"yelp": YelpDataset,
"flickr": FlickrDataset,
}
return load_data(d[name]())
def load_reddit(self_loop=True):
from dgl.data import RedditDataset
# load reddit data
data = RedditDataset(self_loop=self_loop)
return load_data(data)
def load_mag240m(root="dataset"):
from os.path import join
import numpy as np
from ogb.lsc import MAG240MDataset
dataset = MAG240MDataset(root=root)
print("Loading graph")
(g,), _ = dgl.load_graphs(join(root, "mag240m_kddcup2021/graph.dgl"))
print("Loading features")
paper_offset = dataset.num_authors + dataset.num_institutions
num_nodes = paper_offset + dataset.num_papers
num_features = dataset.num_paper_features
feats = th.from_numpy(
np.memmap(
join(root, "mag240m_kddcup2021/full.npy"),
mode="r",
dtype="float16",
shape=(num_nodes, num_features),
)
).float()
g.ndata["features"] = feats
train_nid = th.LongTensor(dataset.get_idx_split("train")) + paper_offset
val_nid = th.LongTensor(dataset.get_idx_split("valid")) + paper_offset
test_nid = th.LongTensor(dataset.get_idx_split("test-dev")) + paper_offset
train_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)
train_mask[train_nid] = True
val_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)
val_mask[val_nid] = True
test_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)
test_mask[test_nid] = True
g.ndata["train_mask"] = train_mask
g.ndata["val_mask"] = val_mask
g.ndata["test_mask"] = test_mask
labels = th.tensor(dataset.paper_label)
num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
g.ndata["labels"] = -th.ones(g.number_of_nodes(), dtype=th.int64)
g.ndata["labels"][train_nid] = labels[train_nid - paper_offset].long()
g.ndata["labels"][val_nid] = labels[val_nid - paper_offset].long()
return g, num_labels
def load_ogb(name, root="dataset"):
if name == "ogbn-mag240M":
return load_mag240m(root)
from ogb.nodeproppred import DglNodePropPredDataset
print("load", name)
data = DglNodePropPredDataset(name=name, root=root)
print("finish loading", name)
splitted_idx = data.get_idx_split()
graph, labels = data[0]
labels = labels[:, 0]
graph.ndata["features"] = graph.ndata.pop("feat")
num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
graph.ndata["labels"] = labels.type(th.LongTensor)
in_feats = graph.ndata["features"].shape[1]
# Find the node IDs in the training, validation, and test set.
train_nid, val_nid, test_nid = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
train_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)
train_mask[train_nid] = True
val_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)
val_mask[val_nid] = True
test_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)
test_mask[test_nid] = True
graph.ndata["train_mask"] = train_mask
graph.ndata["val_mask"] = val_mask
graph.ndata["test_mask"] = test_mask
print("finish constructing", name)
return graph, num_labels
def load_dataset(dataset_name):
multilabel = False
if dataset_name in [
"reddit",
"cora",
"citeseer",
"pubmed",
"yelp",
"flickr",
]:
g, n_classes = load_dgl(dataset_name)
multilabel = dataset_name in ["yelp"]
if multilabel:
g.ndata["labels"] = g.ndata["labels"].to(dtype=th.float32)
elif dataset_name in [
"ogbn-products",
"ogbn-arxiv",
"ogbn-papers100M",
"ogbn-mag240M",
]:
g, n_classes = load_ogb(dataset_name)
else:
raise ValueError("unknown dataset")
return g, n_classes, multilabel
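# Example (illustrative):
#   g, n_classes, multilabel = load_dataset("cora")
#   print(g.num_nodes(), n_classes, multilabel)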
import dgl
import dgl.nn as dglnn
import sklearn.linear_model as lm
import sklearn.metrics as skm
import torch as th
import torch.nn.functional as F
import torch.nn as nn
from dgl.nn import GATv2Conv
class GATv2(nn.Module):
def __init__(
self,
num_layers,
in_dim,
num_hidden,
num_classes,
heads,
activation,
feat_drop,
attn_drop,
negative_slope,
residual,
):
super(GATv2, self).__init__()
self.num_layers = num_layers
self.gatv2_layers = nn.ModuleList()
self.activation = activation
# input projection (no residual)
self.gatv2_layers.append(
GATv2Conv(
in_dim,
num_hidden,
heads[0],
feat_drop,
attn_drop,
negative_slope,
False,
self.activation,
True,
bias=False,
share_weights=True,
)
)
# hidden layers
for l in range(1, num_layers - 1):
# due to multi-head, the in_dim = num_hidden * num_heads
self.gatv2_layers.append(
GATv2Conv(
num_hidden * heads[l - 1],
num_hidden,
heads[l],
feat_drop,
attn_drop,
negative_slope,
residual,
self.activation,
True,
bias=False,
share_weights=True,
)
)
# output projection
self.gatv2_layers.append(
GATv2Conv(
num_hidden * heads[-2],
num_classes,
heads[-1],
feat_drop,
attn_drop,
negative_slope,
residual,
None,
True,
bias=False,
share_weights=True,
)
)
def forward(self, mfgs, h):
for l, mfg in enumerate(mfgs):
h = self.gatv2_layers[l](mfg, h)
h = h.flatten(1) if l < self.num_layers - 1 else h.mean(1)
return h
class SAGE(nn.Module):
def __init__(
self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
super().__init__()
self.init(in_feats, n_hidden, n_classes, n_layers, activation, dropout)
def init(
self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
if n_layers > 1:
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
else:
self.layers.append(dglnn.SAGEConv(in_feats, n_classes, "mean"))
self.dropout = nn.Dropout(dropout)
self.activation = activation
def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
h = layer(
block,
h,
edge_weight=block.edata["edge_weights"]
if "edge_weights" in block.edata
else None,
)
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
return h
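# Sketch of wiring SAGE to blocks produced by the samplers in this example
# (hypothetical sizes; "edge_weights" matches the samplers' default out_weight,
# which SAGE.forward() picks up automatically):
#
#   model = SAGE(in_feats=500, n_hidden=64, n_classes=7, n_layers=2,
#                activation=th.relu, dropout=0.5)
#   logits = model(blocks, feats[input_nodes])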
class RGAT(nn.Module):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
num_etypes,
num_layers,
num_heads,
dropout,
pred_ntype,
):
super().__init__()
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.skips = nn.ModuleList()
self.convs.append(
nn.ModuleList(
[
dglnn.GATConv(
in_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(
nn.ModuleList(
[
dglnn.GATConv(
hidden_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(hidden_channels, hidden_channels))
self.mlp = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels),
nn.BatchNorm1d(hidden_channels),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_channels, out_channels),
)
self.dropout = nn.Dropout(dropout)
self.hidden_channels = hidden_channels
self.pred_ntype = pred_ntype
self.num_etypes = num_etypes
def forward(self, mfgs, x):
        for i, mfg in enumerate(mfgs):
            # in a block (MFG), the first num_dst_nodes() source nodes are
            # exactly the destination nodes
            x_dst = x[: mfg.num_dst_nodes()]
for data in [mfg.srcdata, mfg.dstdata]:
for k in list(data.keys()):
if k not in ["features", "labels"]:
data.pop(k)
mfg = dgl.block_to_graph(mfg)
x_skip = self.skips[i](x_dst)
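            # run one GATConv per edge type on its edge-induced subgraph and
            # accumulate the per-type messages on top of the skip connection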
for j in range(self.num_etypes):
subg = mfg.edge_subgraph(
mfg.edata["etype"] == j, relabel_nodes=False
)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(
-1, self.hidden_channels
)
x = self.norms[i](x_skip)
x = th.nn.functional.elu(x)
x = self.dropout(x)
return self.mlp(x)