Unverified Commit 7c771d0d authored by Yuchen, committed by GitHub

[BugFix] fix #3429 and update results of caregnn (#3441)



* squeeze node labels in FraudDataset

* fix RLModule

* update results in README.md

* fix KeyError in full graph training
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent ea8b93f9
@@ -12,7 +12,7 @@ Dependencies
 ----------------------
 - Python 3.7.10
 - PyTorch 1.8.1
-- dgl 0.7.0
+- dgl 0.7.1
 - scikit-learn 0.23.2

 Dataset
@@ -48,9 +48,9 @@ The datasets used for node classification are DGL's built-in FraudDataset. The s
 How to run
 --------------------------------
-To run the full graph version, in the care-gnn folder, run
+To run the full graph version and use early stopping, in the care-gnn folder, run
 ```
-python main.py
+python main.py --early-stop
 ```

 If want to use a GPU, run
@@ -70,50 +70,57 @@ python main_sampling.py
 Performance
 -------------------------
-The result reported by the paper is the best validation results within 30 epochs, while ours are testing results after the max epoch specified in the table. Early stopping with patience value of 100 is applied.
+The result reported by the paper is the best validation results within 30 epochs, and the table below reports the val and test results (same setting in the paper except for the random seed, here `seed=717`).
 <table>
-    <tr>
-        <th colspan="2">Dataset</th>
-        <th>Amazon</th>
-        <th>Yelp</th>
-    </tr >
-    <tr>
-        <td>Metric</td>
-        <td>Max Epoch</td>
-        <td>30 / 1000</td>
-        <td>30 / 1000</td>
-    </tr>
-    <tr >
-        <td rowspan="3">AUC</td>
-        <td>paper reported</td>
-        <td>89.73 / -</td>
-        <td>75.70 / -</td>
-    </tr>
-    <tr>
-        <td>DGL full graph</td>
-        <td>89.50 / 92.35</td>
-        <td>69.16 / 79.91</td>
-    </tr>
-    <tr>
-        <td>DGL sampling</td>
-        <td>93.27 / 92.94</td>
-        <td>79.38 / 80.53</td>
-    </tr>
-    <tr >
-        <td rowspan="3">Recall</td>
-        <td>paper reported</td>
-        <td>88.48 / -</td>
-        <td>71.92 / -</td>
-    </tr>
-    <tr>
-        <td>DGL full graph</td>
-        <td>85.54 / 84.47</td>
-        <td>69.91 / 73.47</td>
-    </tr>
-    <tr>
-        <td>DGL sampling</td>
-        <td>85.83 / 87.46</td>
-        <td>77.26 / 64.34</td>
-    </tr>
+    <thead>
+    <tr>
+        <th colspan="2">Dataset</th>
+        <th>Amazon</th>
+        <th>Yelp</th>
+    </tr>
+    </thead>
+    <tbody>
+    <tr>
+        <td>Metric (val / test)</td>
+        <td>Max Epoch</td>
+        <td>30</td>
+        <td>30</td>
+    </tr>
+    <tr>
+        <td rowspan="3">AUC (val / test)</td>
+        <td>paper reported</td>
+        <td>0.8973 / -</td>
+        <td>0.7570 / -</td>
+    </tr>
+    <tr>
+        <td>DGL full graph</td>
+        <td>0.8849 / 0.8922</td>
+        <td>0.6856 / 0.6867</td>
+    </tr>
+    <tr>
+        <td>DGL sampling</td>
+        <td>0.9350 / 0.9331</td>
+        <td>0.7857 / 0.7890</td>
+    </tr>
+    <tr>
+        <td rowspan="3">Recall (val / test)</td>
+        <td>paper reported</td>
+        <td>0.8848 / -</td>
+        <td>0.7192 / -</td>
+    </tr>
+    <tr>
+        <td>DGL full graph</td>
+        <td>0.8615 / 0.8544</td>
+        <td>0.6667 / 0.6619</td>
+    </tr>
+    <tr>
+        <td>DGL sampling</td>
+        <td>0.9130 / 0.9045</td>
+        <td>0.7537 / 0.7540</td>
+    </tr>
+    </tbody>
 </table>
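The updated paragraph pins the reported numbers to `seed=717`; the sampling script sets this with `th.manual_seed(717)` in one of the training-script hunks below. A minimal, standalone illustration of what the fixed seed provides (the printed values are simply whatever the seeded generator produces):

```
import torch as th

th.manual_seed(717)  # same seed as the updated example scripts
print(th.rand(3))    # same sequence on every run with this seed, so the val/test numbers are reproducible
```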
@@ -3,9 +3,10 @@ import argparse
 import torch as th
 from model import CAREGNN
 import torch.optim as optim
-from utils import EarlyStopping
 from sklearn.metrics import recall_score, roc_auc_score
+from utils import EarlyStopping

 def main(args):
     # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
@@ -21,10 +22,10 @@ def main(args):
         device = 'cpu'

     # retrieve labels of ground truth
-    labels = graph.ndata['label'].to(device).squeeze().long()
+    labels = graph.ndata['label'].to(device)

     # Extract node features
-    feat = graph.ndata['feature'].to(device).float()
+    feat = graph.ndata['feature'].to(device)

     # retrieve masks for train/validation/test
     train_mask = graph.ndata['train_mask']
@@ -121,7 +122,7 @@ if __name__ == '__main__':
     parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay. Default: 0.001")
     parser.add_argument("--step_size", type=float, default=0.02, help="RL action step size (lambda 2). Default: 0.02")
     parser.add_argument("--sim_weight", type=float, default=2, help="Similarity loss weight (lambda 1). Default: 2")
-    parser.add_argument('--early-stop', action='store_true', default=True, help="indicates whether to use early stop")
+    parser.add_argument('--early-stop', action='store_true', default=False, help="indicates whether to use early stop")
     args = parser.parse_args()
     print(args)
...
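The `--early-stop` change above is worth spelling out: with `action='store_true'`, passing the flag can only ever set the value to `True`, so a `default=True` option could never be switched off from the command line. Flipping the default to `False` (and passing `--early-stop` explicitly, as the README now shows) restores the expected opt-in behaviour. A minimal, standalone sketch of that argparse behaviour, not the example's full parser:

```
import argparse

parser = argparse.ArgumentParser()
# store_true flags stay at the default unless the flag is given on the command line
parser.add_argument('--early-stop', action='store_true', default=False,
                    help="indicates whether to use early stop")

print(parser.parse_args([]).early_stop)                # False: early stopping is now opt-in
print(parser.parse_args(['--early-stop']).early_stop)  # True: enabled only when requested
```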
@@ -2,9 +2,10 @@ import dgl
 import argparse
 import torch as th
 import torch.optim as optim
+from sklearn.metrics import roc_auc_score, recall_score
 from utils import EarlyStopping
 from model_sampling import CAREGNN, CARESampler, _l1_dist
-from sklearn.metrics import roc_auc_score, recall_score

 def evaluate(model, loss_fn, dataloader, device='cpu'):
@@ -14,8 +15,8 @@ def evaluate(model, loss_fn, dataloader, device='cpu'):
     num_blocks = 0
     for input_nodes, output_nodes, blocks in dataloader:
         blocks = [b.to(device) for b in blocks]
-        feature = blocks[0].srcdata['feature'].float()
-        label = blocks[-1].dstdata['label'].squeeze().long()
+        feature = blocks[0].srcdata['feature']
+        label = blocks[-1].dstdata['label']
         logits_gnn, logits_sim = model(blocks, feature)

         # compute loss
@@ -42,10 +43,10 @@ def main(args):
         device = 'cpu'

     # retrieve labels of ground truth
-    labels = graph.ndata['label'].to(device).bool()
+    labels = graph.ndata['label'].to(device)

     # Extract node features
-    feat = graph.ndata['feature'].to(device).float()
+    feat = graph.ndata['feature'].to(device)
     layers_feat = feat.expand(args.num_layers, -1, -1)

     # retrieve masks for train/validation/test
@@ -58,7 +59,7 @@ def main(args):
     test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)

     # Reinforcement learning module only for positive training nodes
-    rl_idx = th.nonzero(train_mask.to(device) & labels, as_tuple=False).squeeze(1)
+    rl_idx = th.nonzero(train_mask.to(device) & labels.bool(), as_tuple=False).squeeze(1)

     graph = graph.to(device)
@@ -112,8 +113,8 @@ def main(args):
         for input_nodes, output_nodes, blocks in train_dataloader:
             blocks = [b.to(device) for b in blocks]
-            train_feature = blocks[0].srcdata['feature'].float()
-            train_label = blocks[-1].dstdata['label'].squeeze().long()
+            train_feature = blocks[0].srcdata['feature']
+            train_label = blocks[-1].dstdata['label']
             logits_gnn, logits_sim = model(blocks, train_feature)

             # compute loss
@@ -184,8 +185,9 @@ if __name__ == '__main__':
     parser.add_argument("--step_size", type=float, default=0.02, help="RL action step size (lambda 2). Default: 0.02")
     parser.add_argument("--sim_weight", type=float, default=2, help="Similarity loss weight (lambda 1). Default: 0.001")
     parser.add_argument("--num_workers", type=int, default=4, help="Number of node dataloader")
-    parser.add_argument('--early-stop', action='store_true', default=True, help="indicates whether to use early stop")
+    parser.add_argument('--early-stop', action='store_true', default=False, help="indicates whether to use early stop")
     args = parser.parse_args()
+    th.manual_seed(717)
     print(args)
     main(args)
@@ -43,9 +43,12 @@ class CAREConv(nn.Module):
         neigh_list = []
         for node in g.nodes():
             edges = g.in_edges(node, form='eid')
-            num_neigh = int(g.in_degrees(node) * p)
+            num_neigh = th.ceil(g.in_degrees(node) * p).int().item()
             neigh_dist = dist[edges]
-            neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh]
+            if neigh_dist.shape[0] > num_neigh:
+                neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh]
+            else:
+                neigh_index = np.arange(num_neigh)
             neigh_list.append(edges[neigh_index])

         return th.cat(neigh_list)
...
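The guard added above covers the corner cases of the old line: `int()` truncation could request zero neighbours for low-degree nodes, while rounding up with `th.ceil` can now request all of them, and `np.argpartition(a, k)` raises an error when `k` equals `len(a)`; hence the new branch that simply keeps every edge. The same guard is applied in `CARESampler` below. A standalone sketch of the selection logic with toy distances (the helper name and the numbers are illustrative, not part of the example):

```
import numpy as np
import torch as th

def select_neighbors(neigh_dist, p):
    # round up so low-degree nodes keep at least one neighbour
    num_neigh = int(th.ceil(th.tensor(neigh_dist.shape[0] * p)).item())
    if neigh_dist.shape[0] > num_neigh:
        # indices of the num_neigh smallest distances; argpartition needs num_neigh < len
        return np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh]
    # all neighbours requested: keep every edge instead of calling argpartition out of bounds
    return np.arange(num_neigh)

print(select_neighbors(th.tensor([0.9, 0.1, 0.4, 0.7]), p=0.5))  # the two closest neighbours
print(select_neighbors(th.tensor([0.3]), p=0.5))                 # a degree-1 node keeps its only edge
```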
@@ -25,9 +25,12 @@ class CARESampler(dgl.dataloading.BlockSampler):
             # extract each node from dict because of single node type
             for node in seed_nodes:
                 edges = g.in_edges(node, form='eid', etype=etype)
-                num_neigh = int(g.in_degrees(node, etype=etype) * self.p[block_id][etype])
+                num_neigh = th.ceil(g.in_degrees(node, etype=etype) * self.p[block_id][etype]).int().item()
                 neigh_dist = self.dists[block_id][etype][edges]
-                neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh]
+                if neigh_dist.shape[0] > num_neigh:
+                    neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh]
+                else:
+                    neigh_index = np.arange(num_neigh)
                 edge_mask[edges[neigh_index]] = 1
             new_edges_masks[etype] = edge_mask.bool()
@@ -56,6 +59,7 @@ class CAREConv(nn.Module):
         self.p = {}
         self.last_avg_dist = {}
         self.f = {}
+        # indicate whether the RL converges
         self.cvg = {}
         for etype in edges:
             self.p[etype] = 0.5
@@ -151,7 +155,7 @@ class CAREGNN(nn.Module):
     def RLModule(self, graph, epoch, idx, dists):
         for i, layer in enumerate(self.layers):
             for etype in self.edges:
-                if not layer.cvg:
+                if not layer.cvg[etype]:
                     # formula 5
                     eid = graph.in_edges(idx, form='eid', etype=etype)
                     avg_dist = th.mean(dists[i][etype][eid])
@@ -159,11 +163,17 @@ class CAREGNN(nn.Module):
                     # formula 6
                     if layer.last_avg_dist[etype] < avg_dist:
                         layer.p[etype] -= self.step_size
-                        layer.f.append(-1)
+                        layer.f[etype].append(-1)
+                        # avoid overflow, follow the author's implement
+                        if layer.p[etype] < 0:
+                            layer.p[etype] = 0.001
                     else:
                         layer.p[etype] += self.step_size
-                        layer.f.append(+1)
+                        layer.f[etype].append(+1)
+                        if layer.p[etype] > 1:
+                            layer.p[etype] = 0.999
+                    layer.last_avg_dist[etype] = avg_dist

                     # formula 7
-                    if epoch >= 10 and sum(layer.f[-10:]) <= 2:
-                        layer.cvg = True
+                    if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:
+                        layer.cvg[etype] = True
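The reworked `RLModule` keeps `p`, `f`, `last_avg_dist` and `cvg` per edge type, clamps the filtering threshold to (0, 1) as noted in the new comment, and declares convergence once the last ten +1/-1 actions roughly cancel out (formula 7). A standalone sketch of that per-relation update with plain Python state, using made-up distance values rather than the layer's real attributes:

```
def rl_update(p, f, last_avg_dist, avg_dist, epoch, step_size=0.02):
    converged = False
    if last_avg_dist < avg_dist:      # formula 6: average neighbour distance got worse
        p -= step_size
        f.append(-1)
        if p < 0:                     # clamp to (0, 1), following the authors' implementation
            p = 0.001
    else:                             # average neighbour distance improved
        p += step_size
        f.append(+1)
        if p > 1:
            p = 0.999
    last_avg_dist = avg_dist
    # formula 7: converge once the last ten actions roughly cancel out
    if epoch >= 9 and abs(sum(f[-10:])) <= 2:
        converged = True
    return p, f, last_avg_dist, converged

p, f, last = 0.5, [], 1.0
for epoch, dist in enumerate([0.90, 0.95, 0.90, 0.92, 0.90, 0.91,
                              0.90, 0.90, 0.91, 0.90, 0.90, 0.91]):
    p, f, last, converged = rl_update(p, f, last, dist, epoch)
print(round(p, 3), converged)  # the threshold settles near its start value and convergence is flagged
```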
@@ -109,7 +109,8 @@ class FraudDataset(DGLBuiltinDataset):
         data = io.loadmat(file_path)
         node_features = data['features'].todense()
-        node_labels = data['label']
+        # remove additional dimension of length 1 in raw .mat file
+        node_labels = data['label'].squeeze()

         graph_data = {}
         for relation in self.relations[self.name]:
@@ -118,8 +119,8 @@ class FraudDataset(DGLBuiltinDataset):
             graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (row, col)

         g = heterograph(graph_data)
-        g.ndata['feature'] = F.tensor(node_features)
-        g.ndata['label'] = F.tensor(node_labels.T)
+        g.ndata['feature'] = F.tensor(node_features, dtype=F.data_type_dict['float32'])
+        g.ndata['label'] = F.tensor(node_labels, dtype=F.data_type_dict['int64'])
         self.graph = g

         self._random_split(g.ndata['feature'], self.seed, self.train_size, self.val_size)
...
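The dataset-side change above is what lets the training scripts drop their per-script `.squeeze().long()` / `.float()` calls: vectors read back from a `.mat` file carry an extra length-1 dimension, and PyTorch's cross-entropy loss expects 1-D integer class targets. A standalone sketch of the shape/dtype issue with made-up arrays (not the real dataset contents):

```
import numpy as np
import torch as th
import torch.nn.functional as Fn

raw_labels = np.zeros((1, 5), dtype=np.uint8)    # vectors loaded from .mat files come back 2-D
labels = th.tensor(raw_labels.squeeze()).long()  # squeeze to shape (5,), cast to int64
print(labels.shape, labels.dtype)                # torch.Size([5]) torch.int64

logits = th.randn(5, 2)                          # fraud / non-fraud logits for 5 nodes
print(Fn.cross_entropy(logits, labels).item())   # 1-D int64 targets work directly
```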