Unverified commit 7c771d0d authored by Yuchen, committed by GitHub

[BugFix] fix #3429 and update results of caregnn (#3441)



* squeeze node labels in FraudDataset

* fix RLModule

* update results in README.md

* fix KeyError in full graph training
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent ea8b93f9
@@ -12,7 +12,7 @@ Dependencies
----------------------
- Python 3.7.10
- PyTorch 1.8.1
- dgl 0.7.0
- dgl 0.7.1
- scikit-learn 0.23.2
Dataset
@@ -48,9 +48,9 @@ The datasets used for node classification are DGL's built-in FraudDataset. The s
How to run
--------------------------------
To run the full graph version, in the care-gnn folder, run
To run the full graph version and use early stopping, in the care-gnn folder, run
```
python main.py
python main.py --early-stop
```
If you want to use a GPU, run
@@ -70,50 +70,57 @@ python main_sampling.py
Performance
-------------------------
The results reported by the paper are the best validation results within 30 epochs, while ours are test results after the max epoch specified in the table. Early stopping with a patience value of 100 is applied.
The results reported by the paper are the best validation results within 30 epochs; the table below reports our validation and test results (same setting as the paper except for the random seed, here `seed=717`).
<table>
<thead>
<tr>
<th colspan="2">Dataset</th>
<th>Amazon</th>
<th>Yelp</th>
</tr >
</tr>
</thead>
<tbody>
<tr>
<td>Metric</td>
<td>Metric (val / test)</td>
<td>Max Epoch</td>
<td>30 / 1000</td>
<td>30 / 1000</td>
<td>30</td>
<td>30 </td>
</tr>
<tr >
<td rowspan="3">AUC</td>
<tr>
<td rowspan="3">AUC (val/test)</td>
<td>paper reported</td>
<td>89.73 / -</td>
<td>75.70 / -</td>
<td>0.8973 / -</td>
<td>0.7570 / -</td>
</tr>
<tr>
<td>DGL full graph</td>
<td>89.50 / 92.35</td>
<td>69.16 / 79.91</td>
<td>0.8849 / 0.8922</td>
<td>0.6856 / 0.6867</td>
</tr>
<tr>
<td>DGL sampling</td>
<td>93.27 / 92.94</td>
<td>79.38 / 80.53</td>
<td>0.9350 / 0.9331</td>
<td>0.7857 / 0.7890</td>
</tr>
<tr >
<td rowspan="3">Recall</td>
<tr>
<td rowspan="3">Recall (val/test)</td>
<td>paper reported</td>
<td>88.48 / -</td>
<td>71.92 / -</td>
<td>0.8848 / -</td>
<td>0.7192 / -</td>
</tr>
<tr>
<td>DGL full graph</td>
<td>85.54 / 84.47</td>
<td>69.91 / 73.47</td>
<td>0.8615 / 0.8544</td>
<td>0.6667 / 0.6619</td>
</tr>
<tr>
<td>DGL sampling</td>
<td>85.83 / 87.46</td>
<td>77.26 / 64.34</td>
<td>0.9130 / 0.9045</td>
<td>0.7537 / 0.7540</td>
</tr>
</tbody>
</table>
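Note: the AUC and Recall values above come from scikit-learn's `roc_auc_score` and `recall_score`, which the training scripts below import. A minimal sketch of how these metrics are typically computed; the prediction values and the 0.5 threshold here are made up, and the scripts may derive hard predictions differently (e.g. via argmax over logits):
```
import torch as th
from sklearn.metrics import roc_auc_score, recall_score

labels = th.tensor([0, 1, 0, 1, 1])            # ground-truth fraud labels
probs = th.tensor([0.2, 0.8, 0.4, 0.6, 0.3])   # predicted probability of the positive class
preds = (probs > 0.5).long()                   # hard predictions for recall

print(roc_auc_score(labels.numpy(), probs.numpy()))  # AUC uses scores, not hard labels (~0.83 here)
print(recall_score(labels.numpy(), preds.numpy()))   # recall of the positive class (~0.67 here)
```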
@@ -3,9 +3,10 @@ import argparse
import torch as th
from model import CAREGNN
import torch.optim as optim
from utils import EarlyStopping
from sklearn.metrics import recall_score, roc_auc_score
from utils import EarlyStopping
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
@@ -21,10 +22,10 @@ def main(args):
device = 'cpu'
# retrieve labels of ground truth
labels = graph.ndata['label'].to(device).squeeze().long()
labels = graph.ndata['label'].to(device)
# Extract node features
feat = graph.ndata['feature'].to(device).float()
feat = graph.ndata['feature'].to(device)
# retrieve masks for train/validation/test
train_mask = graph.ndata['train_mask']
@@ -121,7 +122,7 @@ if __name__ == '__main__':
parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay. Default: 0.001")
parser.add_argument("--step_size", type=float, default=0.02, help="RL action step size (lambda 2). Default: 0.02")
parser.add_argument("--sim_weight", type=float, default=2, help="Similarity loss weight (lambda 1). Default: 2")
parser.add_argument('--early-stop', action='store_true', default=True, help="indicates whether to use early stop")
parser.add_argument('--early-stop', action='store_true', default=False, help="indicates whether to use early stop")
args = parser.parse_args()
print(args)
......
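For context on the `--early-stop` default change above: with `action='store_true'`, the `default` value is what the flag takes when it is omitted, so the old `default=True` made early stopping impossible to disable from the command line. A minimal standalone sketch of the corrected semantics (a hypothetical parser, not the example's actual one):
```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--early-stop', action='store_true', default=False,
                    help="indicates whether to use early stop")

print(parser.parse_args([]).early_stop)                # False: flag omitted
print(parser.parse_args(['--early-stop']).early_stop)  # True: flag supplied
```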
@@ -2,9 +2,10 @@ import dgl
import argparse
import torch as th
import torch.optim as optim
from sklearn.metrics import roc_auc_score, recall_score
from utils import EarlyStopping
from model_sampling import CAREGNN, CARESampler, _l1_dist
from sklearn.metrics import roc_auc_score, recall_score
def evaluate(model, loss_fn, dataloader, device='cpu'):
@@ -14,8 +15,8 @@ def evaluate(model, loss_fn, dataloader, device='cpu'):
num_blocks = 0
for input_nodes, output_nodes, blocks in dataloader:
blocks = [b.to(device) for b in blocks]
feature = blocks[0].srcdata['feature'].float()
label = blocks[-1].dstdata['label'].squeeze().long()
feature = blocks[0].srcdata['feature']
label = blocks[-1].dstdata['label']
logits_gnn, logits_sim = model(blocks, feature)
# compute loss
@@ -42,10 +43,10 @@ def main(args):
device = 'cpu'
# retrieve labels of ground truth
labels = graph.ndata['label'].to(device).bool()
labels = graph.ndata['label'].to(device)
# Extract node features
feat = graph.ndata['feature'].to(device).float()
feat = graph.ndata['feature'].to(device)
layers_feat = feat.expand(args.num_layers, -1, -1)
# retrieve masks for train/validation/test
@@ -58,7 +59,7 @@
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)
# Reinforcement learning module only for positive training nodes
rl_idx = th.nonzero(train_mask.to(device) & labels, as_tuple=False).squeeze(1)
rl_idx = th.nonzero(train_mask.to(device) & labels.bool(), as_tuple=False).squeeze(1)
graph = graph.to(device)
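Since the dataset now stores labels as int64 instead of bool (see the FraudDataset change at the end of this diff), the label tensor has to be cast back to bool before intersecting it with `train_mask`. A toy illustration with made-up tensors:
```
import torch as th

train_mask = th.tensor([True, True, False, True, False])
labels = th.tensor([1, 0, 0, 1, 1])   # int64 labels after the dataset change
rl_idx = th.nonzero(train_mask & labels.bool(), as_tuple=False).squeeze(1)
print(rl_idx)  # tensor([0, 3]): positive (fraud) nodes that are also training nodes
```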
@@ -112,8 +113,8 @@
for input_nodes, output_nodes, blocks in train_dataloader:
blocks = [b.to(device) for b in blocks]
train_feature = blocks[0].srcdata['feature'].float()
train_label = blocks[-1].dstdata['label'].squeeze().long()
train_feature = blocks[0].srcdata['feature']
train_label = blocks[-1].dstdata['label']
logits_gnn, logits_sim = model(blocks, train_feature)
# compute loss
@@ -184,8 +185,9 @@ if __name__ == '__main__':
parser.add_argument("--step_size", type=float, default=0.02, help="RL action step size (lambda 2). Default: 0.02")
parser.add_argument("--sim_weight", type=float, default=2, help="Similarity loss weight (lambda 1). Default: 0.001")
parser.add_argument("--num_workers", type=int, default=4, help="Number of node dataloader")
parser.add_argument('--early-stop', action='store_true', default=True, help="indicates whether to use early stop")
parser.add_argument('--early-stop', action='store_true', default=False, help="indicates whether to use early stop")
args = parser.parse_args()
th.manual_seed(717)
print(args)
main(args)
@@ -43,9 +43,12 @@ class CAREConv(nn.Module):
neigh_list = []
for node in g.nodes():
edges = g.in_edges(node, form='eid')
num_neigh = int(g.in_degrees(node) * p)
num_neigh = th.ceil(g.in_degrees(node) * p).int().item()
neigh_dist = dist[edges]
if neigh_dist.shape[0] > num_neigh:
neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh]
else:
neigh_index = np.arange(num_neigh)
neigh_list.append(edges[neigh_index])
return th.cat(neigh_list)
......
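The hunk above (and the sampler hunk just below) makes two related changes: `th.ceil` instead of `int` so that low-degree nodes keep at least one neighbor, and a guard so `np.argpartition` is only called when there are more candidate edges than `num_neigh`. A standalone sketch with made-up distances:
```
import math
import numpy as np
import torch as th

neigh_dist = th.tensor([0.9, 0.1, 0.4, 0.7, 0.2])  # distances to 5 in-edges
p = 0.5                                             # keep-ratio chosen by the RL module
num_neigh = math.ceil(len(neigh_dist) * p)          # ceil keeps at least one neighbor

if neigh_dist.shape[0] > num_neigh:
    # indices of the num_neigh smallest distances (argpartition requires num_neigh < len)
    neigh_index = np.argpartition(neigh_dist.numpy(), num_neigh)[:num_neigh]
else:
    neigh_index = np.arange(num_neigh)

print(sorted(neigh_index.tolist()))  # [1, 2, 4] -- the three closest neighbors
```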
@@ -25,9 +25,12 @@ class CARESampler(dgl.dataloading.BlockSampler):
# extract each node from dict because of single node type
for node in seed_nodes:
edges = g.in_edges(node, form='eid', etype=etype)
num_neigh = int(g.in_degrees(node, etype=etype) * self.p[block_id][etype])
num_neigh = th.ceil(g.in_degrees(node, etype=etype) * self.p[block_id][etype]).int().item()
neigh_dist = self.dists[block_id][etype][edges]
if neigh_dist.shape[0] > num_neigh:
neigh_index = np.argpartition(neigh_dist.cpu().detach(), num_neigh)[:num_neigh]
else:
neigh_index = np.arange(num_neigh)
edge_mask[edges[neigh_index]] = 1
new_edges_masks[etype] = edge_mask.bool()
@@ -56,6 +59,7 @@ class CAREConv(nn.Module):
self.p = {}
self.last_avg_dist = {}
self.f = {}
# indicates whether the RL module has converged for each relation
self.cvg = {}
for etype in edges:
self.p[etype] = 0.5
@@ -151,7 +155,7 @@ class CAREGNN(nn.Module):
def RLModule(self, graph, epoch, idx, dists):
for i, layer in enumerate(self.layers):
for etype in self.edges:
if not layer.cvg:
if not layer.cvg[etype]:
# formula 5
eid = graph.in_edges(idx, form='eid', etype=etype)
avg_dist = th.mean(dists[i][etype][eid])
@@ -159,11 +163,17 @@ class CAREGNN(nn.Module):
# formula 6
if layer.last_avg_dist[etype] < avg_dist:
layer.p[etype] -= self.step_size
layer.f.append(-1)
layer.f[etype].append(-1)
# keep the threshold in range, following the authors' implementation
if layer.p[etype] < 0:
layer.p[etype] = 0.001
else:
layer.p[etype] += self.step_size
layer.f.append(+1)
layer.f[etype].append(+1)
if layer.p[etype] > 1:
layer.p[etype] = 0.999
layer.last_avg_dist[etype] = avg_dist
# formula 7
if epoch >= 10 and sum(layer.f[-10:]) <= 2:
layer.cvg = True
if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:
layer.cvg[etype] = True
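The fix above keeps a separate action history `f[etype]` and convergence flag `cvg[etype]` for every relation; previously a single list and flag were shared across relations, which is what triggered the reported errors. A condensed, self-contained restatement of one relation's update (the relation name and distances are illustrative, not from a real run):
```
# Single-relation walkthrough of the RL threshold update (formulas 5-7), with made-up distances
step_size = 0.02
p = {'net_upu': 0.5}              # keep-ratio threshold per relation
last_avg_dist = {'net_upu': 0.40}
f = {'net_upu': []}               # history of +1 / -1 actions per relation
cvg = {'net_upu': False}          # convergence flag per relation

etype = 'net_upu'
for epoch, avg_dist in enumerate([0.45, 0.38, 0.41]):  # formula 5: per-epoch mean distances
    if cvg[etype]:
        continue
    if last_avg_dist[etype] < avg_dist:   # formula 6: distance grew, shrink the threshold
        p[etype] -= step_size
        f[etype].append(-1)
        if p[etype] < 0:                  # clamp, following the authors' implementation
            p[etype] = 0.001
    else:                                 # distance shrank, enlarge the threshold
        p[etype] += step_size
        f[etype].append(+1)
        if p[etype] > 1:
            p[etype] = 0.999
    last_avg_dist[etype] = avg_dist
    # formula 7: converged once the last 10 actions roughly cancel out
    if epoch >= 9 and abs(sum(f[etype][-10:])) <= 2:
        cvg[etype] = True

print(p[etype], f[etype])  # ~0.48, [-1, 1, -1]
```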
@@ -109,7 +109,8 @@ class FraudDataset(DGLBuiltinDataset):
data = io.loadmat(file_path)
node_features = data['features'].todense()
node_labels = data['label']
# squeeze out the extra dimension of length 1 present in the raw .mat file
node_labels = data['label'].squeeze()
graph_data = {}
for relation in self.relations[self.name]:
@@ -118,8 +119,8 @@
graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (row, col)
g = heterograph(graph_data)
g.ndata['feature'] = F.tensor(node_features)
g.ndata['label'] = F.tensor(node_labels.T)
g.ndata['feature'] = F.tensor(node_features, dtype=F.data_type_dict['float32'])
g.ndata['label'] = F.tensor(node_labels, dtype=F.data_type_dict['int64'])
self.graph = g
self._random_split(g.ndata['feature'], self.seed, self.train_size, self.val_size)
......
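The dataset change above squeezes the (1, N) label matrix loaded from the raw .mat file and stores features as float32 and labels as int64, which is why the explicit `.squeeze().long()` / `.float()` casts could be removed from the training scripts. A quick sanity check, assuming the built-in FraudDataset is constructed with the dataset name (as in this example's README):
```
from dgl.data import FraudDataset

dataset = FraudDataset('amazon')
graph = dataset[0]
print(graph.ndata['label'].shape)    # (N,) after the squeeze, not (N, 1)
print(graph.ndata['label'].dtype)    # torch.int64, ready for cross-entropy loss
print(graph.ndata['feature'].dtype)  # torch.float32, no .float() cast needed
```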