Unverified Commit a107993f authored by Kay Liu, committed by GitHub

[Model] add model example CARE-GNN (#3187)



* [Model] add model example CARE-GNN

* update README

* improvements based on the review feedback

* fix missing item()
Co-authored-by: zhjwy9343 <6593865@qq.com>
parent c4791fd4
@@ -75,6 +75,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
- <a name="nshe"></a> Zhao J, Wang X, et al. Network Schema Preserving Heterogeneous Information Network Embedding. [Paper link](https://www.ijcai.org/Proceedings/2020/0190.pdf).
- Example code: [OpenHGNN](https://github.com/BUPT-GAMMA/OpenHGNN/tree/main/openhgnn/output/NSHE)
- Tags: Heterogeneous graph, Graph neural network, Graph embedding, Network Schema
- <a name="caregnn"></a> Dou Y, Liu Z, et al. Enhancing Graph Neural Network-based Fraud Detectors against Camouflaged Fraudsters. [Paper link](https://arxiv.org/abs/2008.08692).
- Example code: [PyTorch](../examples/pytorch/caregnn)
- Tags: Multi-relational graph, Graph neural network, Fraud detection, Reinforcement learning, Node classification
## 2019
......
# DGL Implementation of the CARE-GNN Paper
This DGL example implements the CAmouflage-REsistant GNN (CARE-GNN) model proposed in the paper [Enhancing Graph Neural Network-based Fraud Detectors against Camouflaged Fraudsters](https://arxiv.org/abs/2008.08692). The authors' original implementation is available [here](https://github.com/YingtongDou/CARE-GNN).
**NOTE**: The sampling version of this model has been modified to fit the behavior of DGL's NodeDataLoader. For formula 2 in the paper, rather than using the embedding of the last layer, this version uses the embedding of the current layer from the previous epoch to measure the similarity between center nodes and their neighbors.
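For reference, a minimal sketch of how formula 2 is computed in this example: an edge-wise l1 distance between MLP-transformed node embeddings, mirroring `_calc_distance` in model.py and `_l1_dist` in model_sampling.py. The helper name `l1_distance` and the toy tensors below are only for illustration.

```
import torch as th
import torch.nn as nn

def l1_distance(h_src, h_dst, mlp):
    # formula 2: d(u, v) = || tanh(MLP(h_u)) - tanh(MLP(h_v)) ||_1
    return th.norm(th.tanh(mlp(h_src)) - th.tanh(mlp(h_dst)), p=1, dim=1)

# toy usage: 5 edges, 25-dim node features (as in Amazon), 2 classes
mlp = nn.Linear(25, 2)
dist = l1_distance(th.randn(5, 25), th.randn(5, 25), mlp)
```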
Example implementor
----------------------
This example was implemented by [Kay Liu](https://github.com/kayzliu) during his SDE intern work at the AWS Shanghai AI Lab.
Dependencies
----------------------
- Python 3.7.10
- PyTorch 1.8.1
- dgl 0.7.0
- scikit-learn 0.23.2
Dataset
---------------------------------------
The datasets used for node classification are DGL's built-in FraudDataset (a loading sketch is given after the statistics below). The statistics are summarized as follows:
**Amazon**
- Nodes: 11,944
- Edges:
- U-P-U: 351,216
- U-S-U: 7,132,958
- U-V-U: 2,073,474
- Classes:
- Positive (fraudulent): 821
- Negative (benign): 7,818
- Unlabeled: 3,305
- Positive-Negative ratio: 1 : 10.5
- Node feature size: 25
**YelpChi**
- Nodes: 45,954
- Edges:
- R-U-R: 98,630
- R-T-R: 1,147,232
- R-S-R: 6,805,486
- Classes:
- Positive (spam): 6,677
- Negative (legitimate): 39,277
- Positive-Negative ratio: 1 : 5.9
- Node feature size: 32
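A minimal sketch of loading one of these datasets, using the same `FraudDataset` call and 0.4 train split as `main.py`; the printed fields are the ones the training scripts read.

```
import dgl

dataset = dgl.data.FraudDataset('amazon', train_size=0.4)  # or 'yelp'
graph = dataset[0]
print(graph.canonical_etypes)           # the three relations listed above
print(graph.ndata['feature'].shape)     # node feature matrix
print(graph.ndata['label'].shape)       # fraud labels
print(graph.ndata['train_mask'].sum())  # number of training nodes
```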
How to run
--------------------------------
To run the full-graph version, from the `caregnn` folder, run
```
python main.py
```
If you want to use a GPU, run
```
python main.py --gpu 0
```
To train on the Yelp dataset instead of Amazon, run
```
python main.py --dataset yelp
```
To run the sampling version, run
```
python main_sampling.py
```
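The sampling version accepts the same `--dataset` and `--gpu` flags, plus `--batch_size` (default 256). For example, a sketch combining the flags defined in `main_sampling.py`:
```
python main_sampling.py --dataset yelp --gpu 0 --batch_size 256
```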
Performance
-------------------------
The results reported in the paper are the best validation results within 30 epochs, while ours are test results after the max epoch specified in the table. Early stopping with a patience of 100 is applied.
<table>
<tr>
<th colspan="2">Dataset</th>
<th>Amazon</th>
<th>Yelp</th>
</tr >
<tr>
<td>Metric</td>
<td>Max Epoch</td>
<td>30 / 1000</td>
<td>30 / 1000</td>
</tr>
<tr >
<td rowspan="3">AUC</td>
<td>paper reported</td>
<td>89.73 / -</td>
<td>75.70 / -</td>
</tr>
<tr>
<td>DGL full graph</td>
<td>89.50 / 92.35</td>
<td>69.16 / 79.91</td>
</tr>
<tr>
<td>DGL sampling</td>
<td>93.27 / 92.94</td>
<td>79.38 / 80.53</td>
</tr>
<tr >
<td rowspan="3">Recall</td>
<td>paper reported</td>
<td>88.48 / -</td>
<td>71.92 / -</td>
</tr>
<tr>
<td>DGL full graph</td>
<td>85.54 / 84.47</td>
<td>69.91 / 73.47</td>
</tr>
<tr>
<td>DGL sampling</td>
<td>85.83 / 87.46</td>
<td>77.26 / 64.34</td>
</tr>
</table>
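# main.py: full-graph training script for CARE-GNN (examples/pytorch/caregnn)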
import dgl
import argparse
import torch as th
from model import CAREGNN
import torch.optim as optim
from utils import EarlyStopping
from sklearn.metrics import recall_score, roc_auc_score
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Load dataset
dataset = dgl.data.FraudDataset(args.dataset, train_size=0.4)
graph = dataset[0]
num_classes = dataset.num_classes
# check cuda
if args.gpu >= 0 and th.cuda.is_available():
device = 'cuda:{}'.format(args.gpu)
else:
device = 'cpu'
# retrieve labels of ground truth
labels = graph.ndata['label'].to(device).squeeze().long()
# Extract node features
feat = graph.ndata['feature'].to(device).float()
# retrieve masks for train/validation/test
train_mask = graph.ndata['train_mask']
val_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)
# Reinforcement learning module only for positive training nodes
rl_idx = th.nonzero(train_mask.to(device) & labels.bool(), as_tuple=False).squeeze(1)
graph = graph.to(device)
# Step 2: Create model =================================================================== #
model = CAREGNN(in_dim=feat.shape[-1],
num_classes=num_classes,
hid_dim=args.hid_dim,
num_layers=args.num_layers,
activation=th.tanh,
step_size=args.step_size,
edges=graph.canonical_etypes)
model = model.to(device)
# Step 3: Create training components ===================================================== #
_, cnt = th.unique(labels, return_counts=True)
loss_fn = th.nn.CrossEntropyLoss(weight=1 / cnt)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
if args.early_stop:
stopper = EarlyStopping(patience=100)
# Step 4: training epochs =============================================================== #
for epoch in range(args.max_epoch):
# Training and validation using a full graph
model.train()
logits_gnn, logits_sim = model(graph, feat)
# compute loss
tr_loss = loss_fn(logits_gnn[train_idx], labels[train_idx]) + \
args.sim_weight * loss_fn(logits_sim[train_idx], labels[train_idx])
tr_recall = recall_score(labels[train_idx].cpu(), logits_gnn.data[train_idx].argmax(dim=1).cpu())
tr_auc = roc_auc_score(labels[train_idx].cpu(), logits_gnn.data[train_idx][:, 1].cpu())
# validation
val_loss = loss_fn(logits_gnn[val_idx], labels[val_idx]) + \
args.sim_weight * loss_fn(logits_sim[val_idx], labels[val_idx])
val_recall = recall_score(labels[val_idx].cpu(), logits_gnn.data[val_idx].argmax(dim=1).cpu())
val_auc = roc_auc_score(labels[val_idx].cpu(), logits_gnn.data[val_idx][:, 1].cpu())
# backward
optimizer.zero_grad()
tr_loss.backward()
optimizer.step()
# Print out performance
print("Epoch {}, Train: Recall: {:.4f} AUC: {:.4f} Loss: {:.4f} | Val: Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}"
.format(epoch, tr_recall, tr_auc, tr_loss.item(), val_recall, val_auc, val_loss.item()))
# Adjust p value with reinforcement learning module
model.RLModule(graph, epoch, rl_idx)
if args.early_stop:
if stopper.step(val_auc, model):
break
# Test after all epoch
model.eval()
if args.early_stop:
model.load_state_dict(th.load('es_checkpoint.pt'))
# forward
logits_gnn, logits_sim = model.forward(graph, feat)
# compute loss
test_loss = loss_fn(logits_gnn[test_idx], labels[test_idx]) + \
args.sim_weight * loss_fn(logits_sim[test_idx], labels[test_idx])
test_recall = recall_score(labels[test_idx].cpu(), logits_gnn[test_idx].argmax(dim=1).cpu())
test_auc = roc_auc_score(labels[test_idx].cpu(), logits_gnn.data[test_idx][:, 1].cpu())
print("Test Recall: {:.4f} AUC: {:.4f} Loss: {:.4f}".format(test_recall, test_auc, test_loss.item()))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN-based Anti-Spam Model')
parser.add_argument("--dataset", type=str, default="amazon", help="DGL dataset for this model (yelp, or amazon)")
parser.add_argument("--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU.")
parser.add_argument("--hid_dim", type=int, default=64, help="Hidden layer dimension")
parser.add_argument("--num_layers", type=int, default=1, help="Number of layers")
parser.add_argument("--max_epoch", type=int, default=30, help="The max number of epochs. Default: 30")
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate. Default: 0.01")
parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay. Default: 0.001")
parser.add_argument("--step_size", type=float, default=0.02, help="RL action step size (lambda 2). Default: 0.02")
parser.add_argument("--sim_weight", type=float, default=2, help="Similarity loss weight (lambda 1). Default: 2")
parser.add_argument('--early-stop', action='store_true', default=True, help="indicates whether to use early stop")
args = parser.parse_args()
print(args)
th.manual_seed(717)
main(args)
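# main_sampling.py: mini-batch training script for CARE-GNN using DGL's NodeDataLoader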
import dgl
import argparse
import torch as th
import torch.optim as optim
from utils import EarlyStopping
from model_sampling import CAREGNN, CARESampler, _l1_dist
from sklearn.metrics import roc_auc_score, recall_score
def evaluate(model, loss_fn, dataloader, device='cpu'):
loss = 0
auc = 0
recall = 0
num_blocks = 0
for input_nodes, output_nodes, blocks in dataloader:
blocks = [b.to(device) for b in blocks]
feature = blocks[0].srcdata['feature'].float()
label = blocks[-1].dstdata['label'].squeeze().long()
logits_gnn, logits_sim = model(blocks, feature)
# compute loss
loss += loss_fn(logits_gnn, label).item() + args.sim_weight * loss_fn(logits_sim, label).item()
recall += recall_score(label.cpu(), logits_gnn.argmax(dim=1).detach().cpu())
auc += roc_auc_score(label.cpu(), logits_gnn[:, 1].detach().cpu())
num_blocks += 1
return recall / num_blocks, auc / num_blocks, loss / num_blocks
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Load dataset
dataset = dgl.data.FraudDataset(args.dataset, train_size=0.4)
graph = dataset[0]
num_classes = dataset.num_classes
# check cuda
if args.gpu >= 0 and th.cuda.is_available():
device = 'cuda:{}'.format(args.gpu)
args.num_workers = 0
else:
device = 'cpu'
# retrieve labels of ground truth
labels = graph.ndata['label'].to(device).bool()
# Extract node features
feat = graph.ndata['feature'].to(device).float()
layers_feat = feat.expand(args.num_layers, -1, -1)
# retrieve masks for train/validation/test
train_mask = graph.ndata['train_mask']
val_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)
# Reinforcement learning module only for positive training nodes
rl_idx = th.nonzero(train_mask.to(device) & labels, as_tuple=False).squeeze(1)
graph = graph.to(device)
# Step 2: Create model =================================================================== #
model = CAREGNN(in_dim=feat.shape[-1],
num_classes=num_classes,
hid_dim=args.hid_dim,
num_layers=args.num_layers,
activation=th.tanh,
step_size=args.step_size,
edges=graph.canonical_etypes)
model = model.to(device)
# Step 3: Create training components ===================================================== #
_, cnt = th.unique(labels, return_counts=True)
loss_fn = th.nn.CrossEntropyLoss(weight=1 / cnt)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
if args.early_stop:
stopper = EarlyStopping(patience=100)
# Step 4: training epochs =============================================================== #
for epoch in range(args.max_epoch):
# calculate the distance of each edges and sample based on the distance
dists = []
p = []
for i in range(args.num_layers):
dist = {}
graph.ndata['nd'] = th.tanh(model.layers[i].MLP(layers_feat[i]))
for etype in graph.canonical_etypes:
graph.apply_edges(_l1_dist, etype=etype)
dist[etype] = graph.edges[etype].data['ed']
dists.append(dist)
p.append(model.layers[i].p)
sampler = CARESampler(p, dists, args.num_layers)
# train
model.train()
tr_loss = 0
tr_recall = 0
tr_auc = 0
tr_blk = 0
train_dataloader = dgl.dataloading.NodeDataLoader(graph,
train_idx,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers
)
for input_nodes, output_nodes, blocks in train_dataloader:
blocks = [b.to(device) for b in blocks]
train_feature = blocks[0].srcdata['feature'].float()
train_label = blocks[-1].dstdata['label'].squeeze().long()
logits_gnn, logits_sim = model(blocks, train_feature)
# compute loss
blk_loss = loss_fn(logits_gnn, train_label) + args.sim_weight * loss_fn(logits_sim, train_label)
tr_loss += blk_loss.item()
tr_recall += recall_score(train_label.cpu(), logits_gnn.argmax(dim=1).detach().cpu())
tr_auc += roc_auc_score(train_label.cpu(), logits_gnn[:, 1].detach().cpu())
tr_blk += 1
# backward
optimizer.zero_grad()
blk_loss.backward()
optimizer.step()
# Reinforcement learning module
model.RLModule(graph, epoch, rl_idx, dists)
# validation
model.eval()
val_dataloader = dgl.dataloading.NodeDataLoader(graph,
val_idx,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers
)
val_recall, val_auc, val_loss = evaluate(model, loss_fn, val_dataloader, device)
# Print out performance
print("In epoch {}, Train Recall: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; "
"Valid Recall: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}".
format(epoch, tr_recall / tr_blk, tr_auc / tr_blk, tr_loss / tr_blk, val_recall, val_auc, val_loss))
if args.early_stop:
if stopper.step(val_auc, model):
break
# Test with mini batch after all epoch
model.eval()
if args.early_stop:
model.load_state_dict(th.load('es_checkpoint.pt'))
test_dataloader = dgl.dataloading.NodeDataLoader(graph,
test_idx,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers
)
test_recall, test_auc, test_loss = evaluate(model, loss_fn, test_dataloader, device)
print("Test Recall: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}".format(test_recall, test_auc, test_loss))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN-based Anti-Spam Model')
parser.add_argument("--dataset", type=str, default="amazon", help="DGL dataset for this model (yelp, or amazon)")
parser.add_argument("--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU.")
parser.add_argument("--hid_dim", type=int, default=64, help="Hidden layer dimension")
parser.add_argument("--num_layers", type=int, default=1, help="Number of layers")
parser.add_argument("--batch_size", type=int, default=256, help="Size of mini-batch")
parser.add_argument("--max_epoch", type=int, default=30, help="The max number of epochs. Default: 30")
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate. Default: 0.01")
parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay. Default: 0.001")
parser.add_argument("--step_size", type=float, default=0.02, help="RL action step size (lambda 2). Default: 0.02")
parser.add_argument("--sim_weight", type=float, default=2, help="Similarity loss weight (lambda 1). Default: 0.001")
parser.add_argument("--num_workers", type=int, default=4, help="Number of node dataloader")
parser.add_argument('--early-stop', action='store_true', default=True, help="indicates whether to use early stop")
args = parser.parse_args()
print(args)
main(args)
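# model.py: CARE-GNN layer (CAREConv) and model for full-graph training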
import torch as th
import numpy as np
import torch.nn as nn
import dgl.function as fn
class CAREConv(nn.Module):
"""One layer of CARE-GNN."""
def __init__(self, in_dim, out_dim, num_classes, edges, activation=None, step_size=0.02):
super(CAREConv, self).__init__()
self.activation = activation
self.step_size = step_size
self.in_dim = in_dim
self.out_dim = out_dim
self.num_classes = num_classes
self.edges = edges
self.dist = {}
self.linear = nn.Linear(self.in_dim, self.out_dim)
self.MLP = nn.Linear(self.in_dim, self.num_classes)
self.p = {}
self.last_avg_dist = {}
self.f = {}
self.cvg = {}
for etype in edges:
self.p[etype] = 0.5
self.last_avg_dist[etype] = 0
self.f[etype] = []
self.cvg[etype] = False
def _calc_distance(self, edges):
# formula 2
d = th.norm(th.tanh(self.MLP(edges.src['h'])) - th.tanh(self.MLP(edges.dst['h'])), 1, 1)
return {'d': d}
def _top_p_sampling(self, g, p):
        # NOTE: this implementation is inefficient; a faster version needs
        # dgl.sampling.select_top_p, requested in issue #3100
dist = g.edata['d']
neigh_list = []
for node in g.nodes():
edges = g.in_edges(node, form='eid')
num_neigh = int(g.in_degrees(node) * p)
neigh_dist = dist[edges]
neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh]
neigh_list.append(edges[neigh_index])
return th.cat(neigh_list)
def forward(self, g, feat):
with g.local_scope():
g.ndata['h'] = feat
hr = {}
for i, etype in enumerate(g.canonical_etypes):
g.apply_edges(self._calc_distance, etype=etype)
self.dist[etype] = g.edges[etype].data['d']
sampled_edges = self._top_p_sampling(g[etype], self.p[etype])
# formula 8
g.send_and_recv(sampled_edges, fn.copy_u('h', 'm'), fn.mean('m', 'h_%s' % etype[1]), etype=etype)
hr[etype] = g.ndata['h_%s' % etype[1]]
if self.activation is not None:
hr[etype] = self.activation(hr[etype])
# formula 9 using mean as inter-relation aggregator
p_tensor = th.Tensor(list(self.p.values())).view(-1, 1, 1).to(g.device)
h_homo = th.sum(th.stack(list(hr.values())) * p_tensor, dim=0)
h_homo += feat
if self.activation is not None:
h_homo = self.activation(h_homo)
return self.linear(h_homo)
class CAREGNN(nn.Module):
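    """Stack of CAREConv layers. RLModule implements the reinforcement-learning
    update of the per-relation filtering thresholds p (formulas 5-7 in the paper)."""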
def __init__(self,
in_dim,
num_classes,
hid_dim=64,
edges=None,
num_layers=2,
activation=None,
step_size=0.02):
super(CAREGNN, self).__init__()
self.in_dim = in_dim
self.hid_dim = hid_dim
self.num_classes = num_classes
self.edges = edges
self.activation = activation
self.step_size = step_size
self.num_layers = num_layers
self.layers = nn.ModuleList()
if self.num_layers == 1:
# Single layer
self.layers.append(CAREConv(self.in_dim,
self.num_classes,
self.num_classes,
self.edges,
activation=self.activation,
step_size=self.step_size))
else:
# Input layer
self.layers.append(CAREConv(self.in_dim,
self.hid_dim,
self.num_classes,
self.edges,
activation=self.activation,
step_size=self.step_size))
# Hidden layers with n - 2 layers
for i in range(self.num_layers - 2):
self.layers.append(CAREConv(self.hid_dim,
self.hid_dim,
self.num_classes,
self.edges,
activation=self.activation,
step_size=self.step_size))
# Output layer
self.layers.append(CAREConv(self.hid_dim,
self.num_classes,
self.num_classes,
self.edges,
activation=self.activation,
step_size=self.step_size))
def forward(self, graph, feat):
# For full graph training, directly use the graph
# formula 4
sim = th.tanh(self.layers[0].MLP(feat))
# Forward of n layers of CARE-GNN
for layer in self.layers:
feat = layer(graph, feat)
return feat, sim
def RLModule(self, graph, epoch, idx):
for layer in self.layers:
for etype in self.edges:
if not layer.cvg[etype]:
# formula 5
eid = graph.in_edges(idx, form='eid', etype=etype)
avg_dist = th.mean(layer.dist[etype][eid])
# formula 6
if layer.last_avg_dist[etype] < avg_dist:
if layer.p[etype] - self.step_size > 0:
layer.p[etype] -= self.step_size
layer.f[etype].append(-1)
else:
if layer.p[etype] + self.step_size <= 1:
layer.p[etype] += self.step_size
layer.f[etype].append(+1)
layer.last_avg_dist[etype] = avg_dist
# formula 7
if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:
layer.cvg[etype] = True
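# model_sampling.py: CARE-GNN sampler, layer, and model for the mini-batch (sampling) version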
import dgl
import torch as th
import numpy as np
import torch.nn as nn
import dgl.function as fn
def _l1_dist(edges):
# formula 2
ed = th.norm(edges.src['nd'] - edges.dst['nd'], 1, 1)
return {'ed': ed}
class CARESampler(dgl.dataloading.BlockSampler):
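    """Block sampler that keeps, for each relation, the top-p fraction of neighbors
    with the smallest precomputed l1 distances to each seed node."""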
def __init__(self, p, dists, num_layers):
super().__init__(num_layers)
self.p = p
self.dists = dists
def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):
with g.local_scope():
new_edges_masks = {}
for etype in g.canonical_etypes:
edge_mask = th.zeros(g.number_of_edges(etype))
                # iterate over each seed node (the graph has a single node type)
for node in seed_nodes:
edges = g.in_edges(node, form='eid', etype=etype)
num_neigh = int(g.in_degrees(node, etype=etype) * self.p[block_id][etype])
neigh_dist = self.dists[block_id][etype][edges]
neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh]
edge_mask[edges[neigh_index]] = 1
new_edges_masks[etype] = edge_mask.bool()
return dgl.edge_subgraph(g, new_edges_masks, relabel_nodes=False)
def __len__(self):
return self.num_layers
class CAREConv(nn.Module):
"""One layer of CARE-GNN."""
def __init__(self, in_dim, out_dim, num_classes, edges, activation=None, step_size=0.02):
super(CAREConv, self).__init__()
self.activation = activation
self.step_size = step_size
self.in_dim = in_dim
self.out_dim = out_dim
self.num_classes = num_classes
self.edges = edges
self.linear = nn.Linear(self.in_dim, self.out_dim)
self.MLP = nn.Linear(self.in_dim, self.num_classes)
self.p = {}
self.last_avg_dist = {}
self.f = {}
self.cvg = {}
for etype in edges:
self.p[etype] = 0.5
self.last_avg_dist[etype] = 0
self.f[etype] = []
self.cvg[etype] = False
def forward(self, g, feat):
g.srcdata['h'] = feat
# formula 8
hr = {}
for etype in g.canonical_etypes:
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'hr'), etype=etype)
hr[etype] = g.dstdata['hr']
if self.activation is not None:
hr[etype] = self.activation(hr[etype])
# formula 9 using mean as inter-relation aggregator
p_tensor = th.Tensor(list(self.p.values())).view(-1, 1, 1).to(feat.device)
h_homo = th.sum(th.stack(list(hr.values())) * p_tensor, dim=0)
h_homo += feat[:g.number_of_dst_nodes()]
if self.activation is not None:
h_homo = self.activation(h_homo)
return self.linear(h_homo)
class CAREGNN(nn.Module):
def __init__(self,
in_dim,
num_classes,
hid_dim=64,
edges=None,
num_layers=2,
activation=None,
step_size=0.02):
super(CAREGNN, self).__init__()
self.in_dim = in_dim
self.hid_dim = hid_dim
self.num_classes = num_classes
self.edges = edges
self.num_layers = num_layers
self.activation = activation
self.step_size = step_size
self.layers = nn.ModuleList()
if self.num_layers == 1:
# Single layer
self.layers.append(CAREConv(self.in_dim,
self.num_classes,
self.num_classes,
self.edges,
activation=self.activation,
step_size=self.step_size))
else:
# Input layer
self.layers.append(CAREConv(self.in_dim,
self.hid_dim,
self.num_classes,
self.edges,
activation=self.activation,
step_size=self.step_size))
# Hidden layers with n - 2 layers
for i in range(self.num_layers - 2):
self.layers.append(CAREConv(self.hid_dim,
self.hid_dim,
self.num_classes,
self.edges,
activation=self.activation,
step_size=self.step_size))
# Output layer
self.layers.append(CAREConv(self.hid_dim,
self.num_classes,
self.num_classes,
self.edges,
activation=self.activation,
step_size=self.step_size))
def forward(self, blocks, feat):
# formula 4
sim = th.tanh(self.layers[0].MLP(blocks[-1].dstdata['feature'].float()))
# Forward of n layers of CARE-GNN
for block, layer in zip(blocks, self.layers):
feat = layer(block, feat)
return feat, sim
def RLModule(self, graph, epoch, idx, dists):
for i, layer in enumerate(self.layers):
for etype in self.edges:
                if not layer.cvg[etype]:
                    # formula 5
                    eid = graph.in_edges(idx, form='eid', etype=etype)
                    avg_dist = th.mean(dists[i][etype][eid])
                    # formula 6
                    if layer.last_avg_dist[etype] < avg_dist:
                        if layer.p[etype] - self.step_size > 0:
                            layer.p[etype] -= self.step_size
                            layer.f[etype].append(-1)
                    else:
                        if layer.p[etype] + self.step_size <= 1:
                            layer.p[etype] += self.step_size
                            layer.f[etype].append(+1)
                    layer.last_avg_dist[etype] = avg_dist
                    # formula 7 (same convergence check as model.py)
                    if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:
                        layer.cvg[etype] = True
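# utils.py: early stopping helper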
"""
From GAT utils
"""
import torch
class EarlyStopping:
def __init__(self, patience=10):
self.patience = patience
self.counter = 0
self.best_score = None
self.early_stop = False
def step(self, acc, model):
score = acc
if self.best_score is None:
self.best_score = score
self.save_checkpoint(model)
elif score < self.best_score:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
return self.early_stop
def save_checkpoint(self, model):
        '''Saves the model when the validation score improves.'''
torch.save(model.state_dict(), 'es_checkpoint.pt')
@@ -25,6 +25,8 @@ class FakeNewsDataset(DGLBuiltinDataset):
- profile: the 10-dimensional node feature composed of ten Twitter user profile attributes.
- spacy: the 300-dimensional node feature composed of Twitter user historical tweets encoded by the spaCy word2vec encoder.
Reference: <https://github.com/safe-graph/GNN-FakeNews>
Note: this dataset is for academic use only, and commercial use is prohibited.
Statistics:
......