Unverified Commit 59a7d0d1 authored by Kay Liu, committed by GitHub

[Model] add model example GCN-based Anti-Spam (#3145)



* add model example GCN-based Anti-Spam

* update example index

* add usage info

* improvements as per comments

* fix image visibility problem

* add image file
Co-authored-by: zhjwy9343 <6593865@qq.com>
parent c0719ec5
@@ -159,6 +159,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
- <a name='gtn'></a> Yun S, Jeong M, et al. Graph transformer networks. [Paper link](https://arxiv.org/abs/1911.06455).
  - Example code: [OpenHGNN](https://github.com/BUPT-GAMMA/OpenHGNN/tree/main/openhgnn/output/GTN)
  - Tags: Heterogeneous graph, Graph neural network, Graph structure
- <a name='gas'></a> Li A, Qin Z, et al. Spam Review Detection with Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1908.10679).
  - Example code: [PyTorch](../examples/pytorch/gas)
  - Tags: Fraud detection, Heterogeneous graph, Edge classification, Graph attention

## 2018
# DGL Implementation of the GAS Paper
This DGL example implements the Heterogeneous GCN part of the model proposed in the paper [Spam Review Detection with Graph Convolutional Networks](https://arxiv.org/abs/1908.10679).
Example implementor
----------------------
This example was implemented by [Kay Liu](https://github.com/kayzliu) during his SDE intern work at the AWS Shanghai AI Lab.
Dependencies
----------------------
- Python 3.7.10
- PyTorch 1.8.1
- dgl 0.7.0
- scikit-learn 0.23.2
Dataset
---------------------------------------
The datasets used for edge classification are variants of DGL's built-in [fake news datasets](https://github.com/dmlc/dgl/blob/master/python/dgl/data/fakenews.py). The conversion from a tree-structured graph to a bipartite graph is shown in the figure below.
![variant](variant.png)
**NOTE**: Like the original fake news dataset, this variant is for academic use only, and commercial use is prohibited. The statistics are summarized below; a minimal loading sketch follows them.
**Politifact**
- Nodes:
- user (u): 276,277
- news (v): 581
- Edges:
- forward: 399,016
- backward: 399,016
- Number of Classes: 2
- Node feature size: 300
- Edge feature size: 300
**Gossipcop**
- Nodes:
- user (u): 565,660
- news (v): 10,333
- Edges:
- forward: 1,254,469
- backward: 1,254,469
- Number of Classes: 2
- Node feature size: 300
- Edge feature size: 300
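
The dataset is packaged in DGL's built-in dataset style (see `dataloader.py` below). A minimal loading sketch, assuming the archive can be fetched from the DGL data server on first use:

```
from dataloader import GASDataset

dataset = GASDataset('pol')  # 'pol' (Politifact) or 'gos' (Gossipcop)
graph = dataset[0]           # the dataset holds a single bipartite heterograph

print(graph)                                       # node/edge types and counts
print(graph.nodes['u'].data['feat'].shape)         # user features, dim 300
print(graph.edges['forward'].data['label'].shape)  # one binary label per forward edge
```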
How to run
--------------------------------
In the gas folder, run
```
python main.py
```
To use a GPU, run
```
python main.py --gpu 0
```
To run mini-batch (sampling-based) training on a GPU, run
```
python main_sampling.py --gpu 0
```
Performance
-------------------------
|Dataset | Xianyu Graph (paper reported) | Fake News Politifact | Fake News Gossipcop |
| -------------------- | ----------------- | -------------------- | ------------------- |
| F1 | 0.8143 | 0.9994 | 0.9942 |
| AUC | 0.9860 | 1.0000 | 0.9991 |
| Recall@90% precision | 0.6702 | 0.9999 | 0.9976 |
import os
import dgl
import torch as th
import numpy as np
import scipy.io as sio
from dgl.data import DGLBuiltinDataset
from dgl.data.utils import save_graphs, load_graphs, _get_dgl_url
class GASDataset(DGLBuiltinDataset):
file_urls = {
'pol': 'dataset/GASPOL.zip',
'gos': 'dataset/GASGOS.zip'
}
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7, val_size=0.1):
assert name in ['gos', 'pol'], "Only supports 'gos' or 'pol'."
self.seed = random_seed
self.train_size = train_size
self.val_size = val_size
url = _get_dgl_url(self.file_urls[name])
super(GASDataset, self).__init__(name=name,
url=url,
raw_dir=raw_dir)
def process(self):
"""process raw data to graph, labels and masks"""
data = sio.loadmat(os.path.join(self.raw_path, f'{self.name}_retweet_graph.mat'))
adj = data['graph'].tocoo()
num_edges = len(adj.row)
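        # the .mat adjacency stores each edge in both directions; keep the first
        # half as the forward direction and rebuild the reverse direction explicitly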
row, col = adj.row[:int(num_edges/2)], adj.col[:int(num_edges/2)]
graph = dgl.graph((np.concatenate((row, col)), np.concatenate((col, row))))
news_labels = data['label'].squeeze()
num_news = len(news_labels)
node_feature = np.load(os.path.join(self.raw_path, f'{self.name}_node_feature.npy'))
edge_feature = np.load(os.path.join(self.raw_path, f'{self.name}_edge_feature.npy'))[:int(num_edges/2)]
graph.ndata['feat'] = th.tensor(node_feature)
graph.edata['feat'] = th.tensor(np.tile(edge_feature, (2, 1)))
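        # an edge is labeled positive iff it is incident to a positively labeled news node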
pos_news = news_labels.nonzero()[0]
edge_labels = th.zeros(num_edges)
edge_labels[graph.in_edges(pos_news, form='eid')] = 1
edge_labels[graph.out_edges(pos_news, form='eid')] = 1
graph.edata['label'] = edge_labels
ntypes = th.ones(graph.num_nodes(), dtype=int)
etypes = th.ones(graph.num_edges(), dtype=int)
ntypes[graph.nodes() < num_news] = 0
etypes[:int(num_edges/2)] = 0
graph.ndata['_TYPE'] = ntypes
graph.edata['_TYPE'] = etypes
hg = dgl.to_heterogeneous(graph, ['v', 'u'], ['forward', 'backward'])
self._random_split(hg, self.seed, self.train_size, self.val_size)
self.graph = hg
def save(self):
"""save the graph list and the labels"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
save_graphs(str(graph_path), self.graph)
def has_cache(self):
""" check whether there are processed data in `self.save_path` """
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
return os.path.exists(graph_path)
def load(self):
"""load processed data from directory `self.save_path`"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
graph, _ = load_graphs(str(graph_path))
self.graph = graph[0]
@property
def num_classes(self):
"""Number of classes for each graph, i.e. number of prediction tasks."""
return 2
def __getitem__(self, idx):
r""" Get graph object
Parameters
----------
idx : int
Item index
Returns
-------
:class:`dgl.DGLGraph`
"""
assert idx == 0, "This dataset has only one graph"
return self.graph
    def __len__(self):
        r"""Number of data examples

        Returns
        -------
        int
        """
        # this dataset contains exactly one graph
        return 1
def _random_split(self, graph, seed=717, train_size=0.7, val_size=0.1):
"""split the dataset into training set, validation set and testing set"""
        assert 0 <= train_size + val_size <= 1, \
            "The sum of training set size and validation set size " \
            "must be between 0 and 1 (inclusive)."
num_edges = graph.num_edges(etype='forward')
index = np.arange(num_edges)
index = np.random.RandomState(seed).permutation(index)
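        # head of the permutation -> train, tail -> validation, middle -> test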
train_idx = index[:int(train_size * num_edges)]
val_idx = index[num_edges - int(val_size * num_edges):]
test_idx = index[int(train_size * num_edges):num_edges - int(val_size * num_edges)]
        train_mask = np.zeros(num_edges, dtype=bool)
        val_mask = np.zeros(num_edges, dtype=bool)
        test_mask = np.zeros(num_edges, dtype=bool)
train_mask[train_idx] = True
val_mask[val_idx] = True
test_mask[test_idx] = True
graph.edges['forward'].data['train_mask'] = th.tensor(train_mask)
graph.edges['forward'].data['val_mask'] = th.tensor(val_mask)
graph.edges['forward'].data['test_mask'] = th.tensor(test_mask)
graph.edges['backward'].data['train_mask'] = th.tensor(train_mask)
graph.edges['backward'].data['val_mask'] = th.tensor(val_mask)
graph.edges['backward'].data['test_mask'] = th.tensor(test_mask)
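# A minimal smoke test (a sketch, not part of the original example): load the
# Politifact variant and report the split sizes. Assumes the dataset archive
# can be fetched from the DGL data server on first use.
if __name__ == '__main__':
    ds = GASDataset('pol')
    g = ds[0]
    print(g)
    print('train edges:', int(g.edges['forward'].data['train_mask'].sum()))
    print('val edges:', int(g.edges['forward'].data['val_mask'].sum()))
    print('test edges:', int(g.edges['forward'].data['test_mask'].sum()))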
import argparse
import torch as th
import torch.optim as optim
import torch.nn.functional as F
from dataloader import GASDataset
from model import GAS
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Load dataset
dataset = GASDataset(args.dataset)
graph = dataset[0]
# check cuda
if args.gpu >= 0 and th.cuda.is_available():
device = 'cuda:{}'.format(args.gpu)
else:
device = 'cpu'
# binary classification
num_classes = dataset.num_classes
# retrieve labels of ground truth
labels = graph.edges['forward'].data['label'].to(device).long()
    # Extract node and edge features
e_feat = graph.edges['forward'].data['feat'].to(device)
u_feat = graph.nodes['u'].data['feat'].to(device)
v_feat = graph.nodes['v'].data['feat'].to(device)
# retrieve masks for train/validation/test
train_mask = graph.edges['forward'].data['train_mask']
val_mask = graph.edges['forward'].data['val_mask']
test_mask = graph.edges['forward'].data['test_mask']
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)
graph = graph.to(device)
# Step 2: Create model =================================================================== #
model = GAS(e_in_dim=e_feat.shape[-1],
u_in_dim=u_feat.shape[-1],
v_in_dim=v_feat.shape[-1],
e_hid_dim=args.e_hid_dim,
u_hid_dim=args.u_hid_dim,
v_hid_dim=args.v_hid_dim,
out_dim=num_classes,
num_layers=args.num_layers,
dropout=args.dropout,
activation=F.relu)
model = model.to(device)
# Step 3: Create training components ===================================================== #
loss_fn = th.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Step 4: training epochs =============================================================== #
for epoch in range(args.max_epoch):
# Training and validation using a full graph
model.train()
logits = model(graph, e_feat, u_feat, v_feat)
# compute loss
tr_loss = loss_fn(logits[train_idx], labels[train_idx])
tr_f1 = f1_score(labels[train_idx].cpu(), logits[train_idx].argmax(dim=1).cpu())
tr_auc = roc_auc_score(labels[train_idx].cpu(), logits[train_idx][:, 1].detach().cpu())
tr_pre, tr_re, _ = precision_recall_curve(labels[train_idx].cpu(), logits[train_idx][:, 1].detach().cpu())
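        # R@P: the best recall among thresholds whose precision exceeds args.precision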
tr_rap = tr_re[tr_pre > args.precision].max()
# validation
valid_loss = loss_fn(logits[val_idx], labels[val_idx])
valid_f1 = f1_score(labels[val_idx].cpu(), logits[val_idx].argmax(dim=1).cpu())
valid_auc = roc_auc_score(labels[val_idx].cpu(), logits[val_idx][:, 1].detach().cpu())
valid_pre, valid_re, _ = precision_recall_curve(labels[val_idx].cpu(), logits[val_idx][:, 1].detach().cpu())
valid_rap = valid_re[valid_pre > args.precision].max()
# backward
optimizer.zero_grad()
tr_loss.backward()
optimizer.step()
# Print out performance
print("In epoch {}, Train R@P: {:.4f} | Train F1: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; "
"Valid R@P: {:.4f} | Valid F1: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}".
format(epoch, tr_rap, tr_f1, tr_auc, tr_loss.item(), valid_rap, valid_f1, valid_auc, valid_loss.item()))
    # Test after all epochs
model.eval()
# forward
logits = model(graph, e_feat, u_feat, v_feat)
# compute loss
test_loss = loss_fn(logits[test_idx], labels[test_idx])
test_f1 = f1_score(labels[test_idx].cpu(), logits[test_idx].argmax(dim=1).cpu())
test_auc = roc_auc_score(labels[test_idx].cpu(), logits[test_idx][:, 1].detach().cpu())
test_pre, test_re, _ = precision_recall_curve(labels[test_idx].cpu(), logits[test_idx][:, 1].detach().cpu())
test_rap = test_re[test_pre > args.precision].max()
print("Test R@P: {:.4f} | Test F1: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}".
format(test_rap, test_f1, test_auc, test_loss.item()))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN-based Anti-Spam Model')
parser.add_argument("--dataset", type=str, default="pol", help="'pol', or 'gos'")
parser.add_argument("--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU.")
parser.add_argument("--e_hid_dim", type=int, default=128, help="Hidden layer dimension for edges")
parser.add_argument("--u_hid_dim", type=int, default=128, help="Hidden layer dimension for source nodes")
parser.add_argument("--v_hid_dim", type=int, default=128, help="Hidden layer dimension for destination nodes")
parser.add_argument("--num_layers", type=int, default=2, help="Number of GCN layers")
parser.add_argument("--max_epoch", type=int, default=100, help="The max number of epochs. Default: 100")
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate. Default: 1e-3")
parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate. Default: 0.0")
parser.add_argument("--weight_decay", type=float, default=5e-4, help="Weight Decay. Default: 0.0005")
parser.add_argument("--precision", type=float, default=0.9, help="The value p in recall@p precision. Default: 0.9")
args = parser.parse_args()
print(args)
main(args)
import dgl
import argparse
import torch as th
import torch.optim as optim
import torch.nn.functional as F
from dataloader import GASDataset
from model_sampling import GAS
from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score
def evaluate(model, loss_fn, dataloader, device='cpu', precision=0.9):
loss = 0
f1 = 0
auc = 0
rap = 0
num_blocks = 0
for input_nodes, edge_subgraph, blocks in dataloader:
blocks = [b.to(device) for b in blocks]
edge_subgraph = edge_subgraph.to(device)
u_feat = blocks[0].srcdata['feat']['u']
v_feat = blocks[0].srcdata['feat']['v']
f_feat = blocks[0].edges['forward'].data['feat']
b_feat = blocks[0].edges['backward'].data['feat']
labels = edge_subgraph.edges['forward'].data['label'].long()
logits = model(edge_subgraph, blocks, f_feat, b_feat, u_feat, v_feat)
loss += loss_fn(logits, labels).item()
f1 += f1_score(labels.cpu(), logits.argmax(dim=1).cpu())
auc += roc_auc_score(labels.cpu(), logits[:, 1].detach().cpu())
pre, re, _ = precision_recall_curve(labels.cpu(), logits[:, 1].detach().cpu())
        rap += re[pre > precision].max()
num_blocks += 1
return rap / num_blocks, f1 / num_blocks, auc / num_blocks, loss / num_blocks
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Load dataset
dataset = GASDataset(args.dataset)
graph = dataset[0]
    # generate mini-batches for forward edges only
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10])
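    # fan-out of 10 sampled neighbors per hop, one entry per GASConv layer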
tr_eid_dict = {}
val_eid_dict = {}
test_eid_dict = {}
tr_eid_dict['forward'] = graph.edges['forward'].data["train_mask"].nonzero().squeeze()
val_eid_dict['forward'] = graph.edges['forward'].data["val_mask"].nonzero().squeeze()
test_eid_dict['forward'] = graph.edges['forward'].data["test_mask"].nonzero().squeeze()
tr_loader = dgl.dataloading.EdgeDataLoader(graph,
tr_eid_dict,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
val_loader = dgl.dataloading.EdgeDataLoader(graph,
val_eid_dict,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
test_loader = dgl.dataloading.EdgeDataLoader(graph,
test_eid_dict,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
# check cuda
if args.gpu >= 0 and th.cuda.is_available():
device = 'cuda:{}'.format(args.gpu)
else:
device = 'cpu'
# binary classification
num_classes = dataset.num_classes
    # Retrieve node and edge feature dimensions
e_feats = graph.edges['forward'].data['feat'].shape[-1]
u_feats = graph.nodes['u'].data['feat'].shape[-1]
v_feats = graph.nodes['v'].data['feat'].shape[-1]
# Step 2: Create model =================================================================== #
model = GAS(e_in_dim=e_feats,
u_in_dim=u_feats,
v_in_dim=v_feats,
e_hid_dim=args.e_hid_dim,
u_hid_dim=args.u_hid_dim,
v_hid_dim=args.v_hid_dim,
out_dim=num_classes,
num_layers=args.num_layers,
dropout=args.dropout,
activation=F.relu)
model = model.to(device)
# Step 3: Create training components ===================================================== #
loss_fn = th.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Step 4: training epochs =============================================================== #
for epoch in range(args.max_epoch):
model.train()
tr_loss = 0
tr_f1 = 0
tr_auc = 0
tr_rap = 0
tr_blocks = 0
for input_nodes, edge_subgraph, blocks in tr_loader:
blocks = [b.to(device) for b in blocks]
edge_subgraph = edge_subgraph.to(device)
u_feat = blocks[0].srcdata['feat']['u']
v_feat = blocks[0].srcdata['feat']['v']
f_feat = blocks[0].edges['forward'].data['feat']
b_feat = blocks[0].edges['backward'].data['feat']
labels = edge_subgraph.edges['forward'].data['label'].long()
logits = model(edge_subgraph, blocks, f_feat, b_feat, u_feat, v_feat)
# compute loss
batch_loss = loss_fn(logits, labels)
tr_loss += batch_loss.item()
tr_f1 += f1_score(labels.cpu(), logits.argmax(dim=1).cpu())
tr_auc += roc_auc_score(labels.cpu(), logits[:, 1].detach().cpu())
tr_pre, tr_re, _ = precision_recall_curve(labels.cpu(), logits[:, 1].detach().cpu())
tr_rap += tr_re[tr_pre > args.precision].max()
tr_blocks += 1
# backward
optimizer.zero_grad()
batch_loss.backward()
optimizer.step()
# validation
model.eval()
        val_rap, val_f1, val_auc, val_loss = evaluate(model, loss_fn, val_loader, device, args.precision)
# Print out performance
print("In epoch {}, Train R@P: {:.4f} | Train F1: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; "
"Valid R@P: {:.4f} | Valid F1: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}".
format(epoch, tr_rap / tr_blocks, tr_f1 / tr_blocks, tr_auc / tr_blocks , tr_loss / tr_blocks,
val_rap, val_f1, val_auc, val_loss))
    # Test with mini-batches after all epochs
model.eval()
    test_rap, test_f1, test_auc, test_loss = evaluate(model, loss_fn, test_loader, device, args.precision)
print("Test R@P: {:.4f} | Test F1: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}".
format(test_rap, test_f1, test_auc, test_loss))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN-based Anti-Spam Model')
parser.add_argument("--dataset", type=str, default="pol", help="'pol', or 'gos'")
parser.add_argument("--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU.")
parser.add_argument("--e_hid_dim", type=int, default=128, help="Hidden layer dimension for edges")
parser.add_argument("--u_hid_dim", type=int, default=128, help="Hidden layer dimension for source nodes")
parser.add_argument("--v_hid_dim", type=int, default=128, help="Hidden layer dimension for destination nodes")
parser.add_argument("--num_layers", type=int, default=2, help="Number of GCN layers")
parser.add_argument("--max_epoch", type=int, default=100, help="The max number of epochs. Default: 100")
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate. Default: 1e-3")
parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate. Default: 0.0")
parser.add_argument("--batch_size", type=int, default=64, help="Size of mini-batches. Default: 64")
parser.add_argument("--num_workers", type=int, default=4, help="Number of node dataloader")
parser.add_argument("--weight_decay", type=float, default=5e-4, help="Weight Decay. Default: 0.0005")
parser.add_argument("--precision", type=float, default=0.9, help="The value p in recall@p precision. Default: 0.9")
args = parser.parse_args()
print(args)
main(args)
import torch.nn as nn
import dgl.function as fn
import torch as th
from dgl.nn.functional import edge_softmax
class MLP(nn.Module):
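    """Edge-level predictor: scores each 'forward' edge from the concatenation
    of its edge, source, and destination representations."""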
def __init__(self, in_dim, out_dim):
super().__init__()
self.W = nn.Linear(in_dim, out_dim)
def apply_edges(self, edges):
h_e = edges.data['h']
h_u = edges.src['h']
h_v = edges.dst['h']
score = self.W(th.cat([h_e, h_u, h_v], -1))
return {'score': score}
def forward(self, g, e_feat, u_feat, v_feat):
with g.local_scope():
g.edges['forward'].data['h'] = e_feat
g.nodes['u'].data['h'] = u_feat
g.nodes['v'].data['h'] = v_feat
g.apply_edges(self.apply_edges, etype="forward")
return g.edges['forward'].data['score']
class GASConv(nn.Module):
"""One layer of GAS."""
def __init__(self,
e_in_dim,
u_in_dim,
v_in_dim,
e_out_dim,
u_out_dim,
v_out_dim,
activation=None,
dropout=0):
super(GASConv, self).__init__()
self.activation = activation
self.dropout = nn.Dropout(dropout)
self.e_linear = nn.Linear(e_in_dim, e_out_dim)
self.u_linear = nn.Linear(u_in_dim, e_out_dim)
self.v_linear = nn.Linear(v_in_dim, e_out_dim)
self.W_ATTN_u = nn.Linear(u_in_dim, v_in_dim + e_in_dim)
self.W_ATTN_v = nn.Linear(v_in_dim, u_in_dim + e_in_dim)
        # the proportions of h_u and h_Nu are both fixed to 1/2 in formula 8
nu_dim = int(u_out_dim / 2)
nv_dim = int(v_out_dim / 2)
self.W_u = nn.Linear(v_in_dim + e_in_dim, nu_dim)
self.W_v = nn.Linear(u_in_dim + e_in_dim, nv_dim)
self.Vu = nn.Linear(u_in_dim, u_out_dim - nu_dim)
self.Vv = nn.Linear(v_in_dim, v_out_dim - nv_dim)
def forward(self, g, e_feat, u_feat, v_feat):
with g.local_scope():
g.nodes['u'].data['h'] = u_feat
g.nodes['v'].data['h'] = v_feat
g.edges['forward'].data['h'] = e_feat
g.edges['backward'].data['h'] = e_feat
# formula 3 and 4 (optimized implementation to save memory)
g.nodes["u"].data.update({'he_u': self.u_linear(u_feat)})
g.nodes["v"].data.update({'he_v': self.v_linear(v_feat)})
g.edges["forward"].data.update({'he_e': self.e_linear(e_feat)})
g.apply_edges(lambda edges: {'he': edges.data['he_e'] + edges.src['he_u'] + edges.dst['he_v']}, etype='forward')
he = g.edges["forward"].data['he']
if self.activation is not None:
he = self.activation(he)
# formula 6
g.apply_edges(lambda edges: {'h_ve': th.cat([edges.src['h'], edges.data['h']], -1)}, etype='backward')
g.apply_edges(lambda edges: {'h_ue': th.cat([edges.src['h'], edges.data['h']], -1)}, etype='forward')
# formula 7, self-attention
g.nodes['u'].data['h_att_u'] = self.W_ATTN_u(u_feat)
g.nodes['v'].data['h_att_v'] = self.W_ATTN_v(v_feat)
# Step 1: dot product
g.apply_edges(fn.e_dot_v('h_ve', 'h_att_u', 'edotv'), etype='backward')
g.apply_edges(fn.e_dot_v('h_ue', 'h_att_v', 'edotv'), etype='forward')
# Step 2. softmax
g.edges['backward'].data['sfm'] = edge_softmax(g['backward'], g.edges['backward'].data['edotv'])
g.edges['forward'].data['sfm'] = edge_softmax(g['forward'], g.edges['forward'].data['edotv'])
# Step 3. Broadcast softmax value to each edge, and then attention is done
g.apply_edges(lambda edges: {'attn': edges.data['h_ve'] * edges.data['sfm']}, etype='backward')
g.apply_edges(lambda edges: {'attn': edges.data['h_ue'] * edges.data['sfm']}, etype='forward')
            # Step 4. Aggregate the attended messages onto the destination nodes, completing formula 7
g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_u'), etype='backward')
g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_v'), etype='forward')
# formula 5
h_nu = self.W_u(g.nodes['u'].data['agg_u'])
h_nv = self.W_v(g.nodes['v'].data['agg_v'])
if self.activation is not None:
h_nu = self.activation(h_nu)
h_nv = self.activation(h_nv)
# Dropout
he = self.dropout(he)
h_nu = self.dropout(h_nu)
h_nv = self.dropout(h_nv)
# formula 8
hu = th.cat([self.Vu(u_feat), h_nu], -1)
hv = th.cat([self.Vv(v_feat), h_nv], -1)
return he, hu, hv
class GAS(nn.Module):
def __init__(self,
e_in_dim,
u_in_dim,
v_in_dim,
e_hid_dim,
u_hid_dim,
v_hid_dim,
out_dim,
num_layers=2,
dropout=0.0,
activation=None):
super(GAS, self).__init__()
self.e_in_dim = e_in_dim
self.u_in_dim = u_in_dim
self.v_in_dim = v_in_dim
self.e_hid_dim = e_hid_dim
self.u_hid_dim = u_hid_dim
self.v_hid_dim = v_hid_dim
self.out_dim = out_dim
self.num_layer = num_layers
self.dropout = dropout
self.activation = activation
self.predictor = MLP(e_hid_dim + u_hid_dim + v_hid_dim, out_dim)
self.layers = nn.ModuleList()
# Input layer
self.layers.append(GASConv(self.e_in_dim,
self.u_in_dim,
self.v_in_dim,
self.e_hid_dim,
self.u_hid_dim,
self.v_hid_dim,
activation=self.activation,
dropout=self.dropout))
        # Hidden layers with n - 1 GASConv layers
for i in range(self.num_layer - 1):
self.layers.append(GASConv(self.e_hid_dim,
self.u_hid_dim,
self.v_hid_dim,
self.e_hid_dim,
self.u_hid_dim,
self.v_hid_dim,
activation=self.activation,
dropout=self.dropout))
def forward(self, graph, e_feat, u_feat, v_feat):
# For full graph training, directly use the graph
# Forward of n layers of GAS
for layer in self.layers:
e_feat, u_feat, v_feat = layer(graph, e_feat, u_feat, v_feat)
# return the result of final prediction layer
return self.predictor(graph, e_feat, u_feat, v_feat)
import torch.nn as nn
import dgl.function as fn
import torch as th
from dgl.nn.functional import edge_softmax
class MLP(nn.Module):
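    """Edge-level predictor: scores each 'forward' edge from the concatenation
    of its edge, source, and destination representations."""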
def __init__(self, in_dim, out_dim):
super().__init__()
self.W = nn.Linear(in_dim, out_dim)
def apply_edges(self, edges):
h_e = edges.data['h']
h_u = edges.src['h']
h_v = edges.dst['h']
score = self.W(th.cat([h_e, h_u, h_v], -1))
return {'score': score}
def forward(self, g, e_feat, u_feat, v_feat):
with g.local_scope():
g.edges['forward'].data['h'] = e_feat
g.nodes['u'].data['h'] = u_feat
g.nodes['v'].data['h'] = v_feat
g.apply_edges(self.apply_edges, etype="forward")
return g.edges['forward'].data['score']
class GASConv(nn.Module):
"""One layer of GAS."""
def __init__(self,
e_in_dim,
u_in_dim,
v_in_dim,
e_out_dim,
u_out_dim,
v_out_dim,
activation=None,
dropout=0):
super(GASConv, self).__init__()
self.activation = activation
self.dropout = nn.Dropout(dropout)
self.e_linear = nn.Linear(e_in_dim, e_out_dim)
self.u_linear = nn.Linear(u_in_dim, e_out_dim)
self.v_linear = nn.Linear(v_in_dim, e_out_dim)
self.W_ATTN_u = nn.Linear(u_in_dim, v_in_dim + e_in_dim)
self.W_ATTN_v = nn.Linear(v_in_dim, u_in_dim + e_in_dim)
        # the proportions of h_u and h_Nu are both fixed to 1/2 in formula 8
nu_dim = int(u_out_dim / 2)
nv_dim = int(v_out_dim / 2)
self.W_u = nn.Linear(v_in_dim + e_in_dim, nu_dim)
self.W_v = nn.Linear(u_in_dim + e_in_dim, nv_dim)
self.Vu = nn.Linear(u_in_dim, u_out_dim - nu_dim)
self.Vv = nn.Linear(v_in_dim, v_out_dim - nv_dim)
def forward(self, g, f_feat, b_feat, u_feat, v_feat):
g.srcnodes['u'].data['h'] = u_feat
g.srcnodes['v'].data['h'] = v_feat
g.dstnodes['u'].data['h'] = u_feat[:g.number_of_dst_nodes(ntype='u')]
g.dstnodes['v'].data['h'] = v_feat[:g.number_of_dst_nodes(ntype='v')]
g.edges['forward'].data['h'] = f_feat
g.edges['backward'].data['h'] = b_feat
# formula 3 and 4 (optimized implementation to save memory)
g.srcnodes["u"].data.update({'he_u': self.u_linear(g.srcnodes['u'].data['h'])})
g.srcnodes["v"].data.update({'he_v': self.v_linear(g.srcnodes['v'].data['h'])})
g.dstnodes["u"].data.update({'he_u': self.u_linear(g.dstnodes['u'].data['h'])})
g.dstnodes["v"].data.update({'he_v': self.v_linear(g.dstnodes['v'].data['h'])})
g.edges["forward"].data.update({'he_e': self.e_linear(f_feat)})
g.edges["backward"].data.update({'he_e': self.e_linear(b_feat)})
g.apply_edges(lambda edges: {'he': edges.data['he_e'] + edges.dst['he_u'] + edges.src['he_v']}, etype='backward')
g.apply_edges(lambda edges: {'he': edges.data['he_e'] + edges.src['he_u'] + edges.dst['he_v']}, etype='forward')
hf = g.edges["forward"].data['he']
hb = g.edges["backward"].data['he']
if self.activation is not None:
hf = self.activation(hf)
hb = self.activation(hb)
# formula 6
g.apply_edges(lambda edges: {'h_ve': th.cat([edges.src['h'], edges.data['h']], -1)}, etype='backward')
g.apply_edges(lambda edges: {'h_ue': th.cat([edges.src['h'], edges.data['h']], -1)}, etype='forward')
# formula 7, self-attention
g.srcnodes['u'].data['h_att_u'] = self.W_ATTN_u(g.srcnodes['u'].data['h'])
g.srcnodes['v'].data['h_att_v'] = self.W_ATTN_v(g.srcnodes['v'].data['h'])
g.dstnodes['u'].data['h_att_u'] = self.W_ATTN_u(g.dstnodes['u'].data['h'])
g.dstnodes['v'].data['h_att_v'] = self.W_ATTN_v(g.dstnodes['v'].data['h'])
# Step 1: dot product
g.apply_edges(fn.e_dot_v('h_ve', 'h_att_u', 'edotv'), etype='backward')
g.apply_edges(fn.e_dot_v('h_ue', 'h_att_v', 'edotv'), etype='forward')
# Step 2. softmax
g.edges['backward'].data['sfm'] = edge_softmax(g['backward'], g.edges['backward'].data['edotv'])
g.edges['forward'].data['sfm'] = edge_softmax(g['forward'], g.edges['forward'].data['edotv'])
# Step 3. Broadcast softmax value to each edge, and then attention is done
g.apply_edges(lambda edges: {'attn': edges.data['h_ve'] * edges.data['sfm']}, etype='backward')
g.apply_edges(lambda edges: {'attn': edges.data['h_ue'] * edges.data['sfm']}, etype='forward')
        # Step 4. Aggregate the attended messages onto the destination nodes, completing formula 7
g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_u'), etype='backward')
g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_v'), etype='forward')
# formula 5
h_nu = self.W_u(g.dstnodes['u'].data['agg_u'])
h_nv = self.W_v(g.dstnodes['v'].data['agg_v'])
if self.activation is not None:
h_nu = self.activation(h_nu)
h_nv = self.activation(h_nv)
# Dropout
hf = self.dropout(hf)
hb = self.dropout(hb)
h_nu = self.dropout(h_nu)
h_nv = self.dropout(h_nv)
# formula 8
hu = th.cat([self.Vu(g.dstnodes['u'].data['h']), h_nu], -1)
hv = th.cat([self.Vv(g.dstnodes['v'].data['h']), h_nv], -1)
return hf, hb, hu, hv
class GAS(nn.Module):
def __init__(self,
e_in_dim,
u_in_dim,
v_in_dim,
e_hid_dim,
u_hid_dim,
v_hid_dim,
out_dim,
num_layers=2,
dropout=0.0,
activation=None):
super(GAS, self).__init__()
self.e_in_dim = e_in_dim
self.u_in_dim = u_in_dim
self.v_in_dim = v_in_dim
self.e_hid_dim = e_hid_dim
self.u_hid_dim = u_hid_dim
self.v_hid_dim = v_hid_dim
self.out_dim = out_dim
self.num_layer = num_layers
self.dropout = dropout
self.activation = activation
self.predictor = MLP(e_hid_dim + u_hid_dim + v_hid_dim, out_dim)
self.layers = nn.ModuleList()
# Input layer
self.layers.append(GASConv(self.e_in_dim,
self.u_in_dim,
self.v_in_dim,
self.e_hid_dim,
self.u_hid_dim,
self.v_hid_dim,
activation=self.activation,
dropout=self.dropout))
        # Hidden layers with n - 1 GASConv layers
for i in range(self.num_layer - 1):
self.layers.append(GASConv(self.e_hid_dim,
self.u_hid_dim,
self.v_hid_dim,
self.e_hid_dim,
self.u_hid_dim,
self.v_hid_dim,
activation=self.activation,
dropout=self.dropout))
def forward(self, subgraph, blocks, f_feat, b_feat, u_feat, v_feat):
# Forward of n layers of GAS
for layer, block in zip(self.layers, blocks):
f_feat, b_feat, u_feat, v_feat = layer(block,
f_feat[:block.num_edges(etype='forward')],
b_feat[:block.num_edges(etype='backward')],
u_feat,
v_feat)
# return the result of final prediction layer
return self.predictor(subgraph, f_feat[:subgraph.num_edges(etype='forward')], u_feat, v_feat)