"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f0c7efa9a44d137f202e6f14263800d630e641a0"
Unverified commit e58eeebf authored by Tingzhang Zhao, committed by GitHub

[Model] Add RECT example (#2813)



* Update README.md

Add description of RECT

* [Example] Add implementation of RECT

* Update classify.py

Modify the class names and the function names mentioned above

* Update main.py

Modify the function names mentioned above

* Update label_utils.py

Adjust the comments

* Update README.md

Add the GitHub information
Co-authored-by: Tianjun Xiao <xiaotj1990327@gmail.com>
parent bbebde46
@@ -8,6 +8,7 @@ The folder contains example implementations of selected research papers related
| Paper | node classification | link prediction / classification | graph property prediction | sampling | OGB |
| ------------------------------------------------------------ | ------------------- | -------------------------------- | ------------------------- | ------------------ | ------------------ |
| [Network Embedding with Completely-imbalanced Labels](#rect) | :heavy_check_mark: | | | | |
| [Boost then Convolve: Gradient Boosting Meets Graph Neural Networks](#bgnn) | :heavy_check_mark: | | | | |
| [Contrastive Multi-View Representation Learning on Graphs](#mvgrl) | :heavy_check_mark: | | :heavy_check_mark: | | |
| [Graph Random Neural Network for Semi-Supervised Learning on Graphs](#grand) | :heavy_check_mark: | | | | |
@@ -99,6 +100,9 @@ The folder contains example implementations of selected research papers related
## 2020
- <a name="rect"></a> Wang et al. Network Embedding with Completely-imbalanced Labels. [Paper link](https://ieeexplore.ieee.org/document/8979355).
  - Example code: [PyTorch](../examples/pytorch/rect)
  - Tags: node classification, network embedding, completely-imbalanced labels
- <a name="mvgrl"></a> Hassani and Khasahmadi. Contrastive Multi-View Representation Learning on Graphs. [Paper link](https://arxiv.org/abs/2006.05582).
  - Example code: [PyTorch](../examples/pytorch/mvgrl)
  - Tags: graph diffusion, self-supervised learning on graphs.
# **DGL Implementation of RECT (TKDE20)**
This DGL example implements the GNN model **RECT** (or more specifically its supervised part **RECT-L**) proposed in the paper [Network Embedding with Completely-imbalanced Labels](https://ieeexplore.ieee.org/document/8979355). The authors' original implementation can be found [here](https://github.com/zhengwang100/RECT).
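In brief (notation ours, matching the MSE training loss used in `main.py` below): RECT-L learns a GCN f that, for every labeled node of a seen class, regresses the "semantic" class center of that class, i.e., the mean (dimension-reduced) feature vector of the labeled nodes sharing the class:

$$\min_f \; \sum_{i \in \mathcal{L}_{seen}} \big\| f(x_i) - \bar{x}_{y_i} \big\|_2^2, \qquad \bar{x}_c = \operatorname{mean}\{\, x_j : j \in \mathcal{L}_{seen},\ y_j = c \,\}$$

The hidden-layer outputs of f are then used as node embeddings.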
## Example Implementor
This example was implemented by [Tingzhang Zhao](https://github.com/Fizyhsp) when he was an undergraduate at USTB.
## **Dataset and experimental setting**
Two of DGL's built-in datasets (Cora and Citeseer), with their default train/val/test splits, are used in this example. In addition, as this paper considers the zero-shot (i.e., completely-imbalanced) label setting, the "unseen" classes should be removed from the training set, as suggested in the paper. In this example, we simply remove two or three classes (which then act as the unseen classes) from the labeled training set of each dataset. Then, we obtain graph embedding results with the different models. Finally, with the obtained embeddings and the original balanced labels, we train a logistic regression classifier to evaluate model performance. A minimal sketch of this protocol is shown below.
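The snippet is a sketch only, using the helper functions shipped with this example (`label_utils.py` and `classify.py`, shown later) and the NodeFeats baseline for the embedding step; the trained models (RECT-L, GCN) instead use `train_mask_zs` during training, as in `main.py`.

```python
from dgl.data import CoraGraphDataset
from label_utils import remove_unseen_classes_from_training
from classify import evaluate_embeds
from utils import svd_feature

data = CoraGraphDataset()
g = data[0]
features, labels = g.ndata['feat'], g.ndata['label']
train_mask, test_mask = g.ndata['train_mask'], g.ndata['test_mask']
# 1. zero-shot setting: drop the "unseen" classes from the training mask
#    (this mask is what the trained models see; the NodeFeats baseline skips training)
train_mask_zs = remove_unseen_classes_from_training(
    train_mask=train_mask, labels=labels, removed_class=[0, 1, 2])
# 2. obtain embeddings; here just the SVD-reduced raw node features
embeds = svd_feature(features)
# 3. evaluate with the ORIGINAL balanced labels via logistic regression
acc = evaluate_embeds(features=embeds, labels=labels, train_mask=train_mask,
                      test_mask=test_mask, n_classes=data.num_classes, cuda=False)
print(acc)
```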
## **Usage**
`python main.py --dataset cora --gpu 0 --model-opt RECT-L --removed-class 0 1 2` # reproduce RECT-L on the "cora" dataset in the zero-shot label setting using the GPU
`python main.py --dataset cora --gpu 0 --model-opt GCN --removed-class 0 1 2` # reproduce GCN on the "cora" dataset in the zero-shot label setting using the GPU
`python main.py --dataset cora --gpu 0 --model-opt NodeFeats --removed-class 0 1 2` # evaluate the original node features using the GPU
## **Performance**
The performance results are as follows:
| **Datasets/Models** | **NodeFeats** | **GCN** | **RECT-L** |
| :-----------------: | :-----------: | :-----: | :--------: |
| **Cora** | 47.56 | 51.26 | **68.60** |
| **Citeseer** | 42.04 | 37.55 | **56.32** |
<center>Table 1: node classification accuracy (%) with the first three classes as "unseen"</center>

| **Datasets/Models** | **NodeFeats** | **GCN** | **RECT-L** |
| :-----------------: | :-----------: | :-----: | :--------: |
| **Cora** | 47.56 | 56.91 | **69.30** |
| **Citeseer** | 42.04 | 45.69 | **61.85** |
<center>Table 2: node classification accuracy (%) with the last two classes as "unseen"</center>
# classify.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from statistics import mean


class LogisticRegressionClassifier(nn.Module):
    '''Define a logistic regression classifier to evaluate the quality of embedding results.'''
    def __init__(self, nfeat, nclass):
        super(LogisticRegressionClassifier, self).__init__()
        self.lrc = nn.Linear(nfeat, nclass)

    def forward(self, x):
        preds = self.lrc(x)
        return preds


def _evaluate(model, features, labels, test_mask):
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[test_mask]
        labels = labels[test_mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def _train_test_with_lrc(model, features, labels, train_mask, test_mask):
    '''Under the pre-defined balanced train/test label setting, train a logistic regression classifier to evaluate the embedding results.'''
    optimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=5e-06)
    for _ in range(100):
        model.train()
        optimizer.zero_grad()
        output = model(features)
        loss_train = F.cross_entropy(output[train_mask], labels[train_mask])
        loss_train.backward()
        optimizer.step()
    return _evaluate(model=model, features=features, labels=labels, test_mask=test_mask)


def evaluate_embeds(features, labels, train_mask, test_mask, n_classes, cuda, test_times=10):
    '''Train the classifier test_times times and report the mean test accuracy.'''
    print("Training a logistic regression classifier with the pre-defined train/test split setting ...")
    res_list = []
    for _ in range(test_times):
        model = LogisticRegressionClassifier(nfeat=features.shape[1], nclass=n_classes)
        if cuda:
            model.cuda()
        res = _train_test_with_lrc(model=model, features=features, labels=labels, train_mask=train_mask, test_mask=test_mask)
        res_list.append(res)
    return mean(res_list)
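
# --- Hypothetical usage of evaluate_embeds (not part of the original example) ---
# Random embeddings should score near chance (~1/n_classes) on random labels.
if __name__ == '__main__':
    n, d, k = 100, 16, 4
    feats = torch.randn(n, d)
    toy_labels = torch.randint(0, k, (n,))
    train_mask = torch.zeros(n, dtype=torch.bool)
    test_mask = torch.zeros(n, dtype=torch.bool)
    train_mask[:60] = True
    test_mask[60:] = True
    print(evaluate_embeds(features=feats, labels=toy_labels, train_mask=train_mask,
                          test_mask=test_mask, n_classes=k, cuda=False))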
# label_utils.py
import torch
import numpy as np
from collections import defaultdict


def remove_unseen_classes_from_training(train_mask, labels, removed_class):
    '''Remove the unseen classes (the first three classes by default) to get the zero-shot (i.e., completely-imbalanced) label setting.
    Input: train_mask, labels, removed_class
    Output: train_mask_zs: the boolean training mask covering only seen-class nodes
    '''
    train_mask_zs = train_mask.clone()
    for i in range(train_mask_zs.numel()):
        if train_mask_zs[i] == 1 and (labels[i].item() in removed_class):
            train_mask_zs[i] = 0
    return train_mask_zs
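
# Illustration (hypothetical numbers): with labels = tensor([0, 1, 2, 1]),
# train_mask = tensor([True, True, True, False]) and removed_class = [0, 1],
# train_mask_zs becomes tensor([False, False, True, False]): only the node
# of the seen class 2 keeps its training label.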
def get_class_set(labels):
    '''Get the class set.
    Input: labels [l, [c1, c2, ..]]
    Output: the labeled class set dict_keys([k1, k2, ..])
    '''
    mydict = {}
    for y in labels:
        for label in y:
            mydict[int(label)] = 1
    return mydict.keys()


def get_label_attributes(train_mask_zs, nodeids, labellist, features):
    '''Get the class center (semantic knowledge) of each seen class.
    Suppose a node i is labeled as c; then attribute[c] += node_i_attribute, and finally mean(attribute[c]).
    Input: train_mask_zs, nodeids, labellist, features
    Output: label_attribute{}: label -> average of labeled node features (class centers)
    '''
    label_attribute_nodes = defaultdict(list)
    for nodeid, labels in zip(nodeids, labellist):
        for label in labels:
            label_attribute_nodes[int(label)].append(int(nodeid))
    label_attribute = {}
    for label in label_attribute_nodes.keys():
        nodes = label_attribute_nodes[int(label)]
        selected_features = features[nodes, :]
        label_attribute[int(label)] = np.mean(selected_features, axis=0)
    return label_attribute


def get_labeled_nodes_label_attribute(train_mask_zs, labels, features, cuda):
    '''Replace the original labels by their class centers.
    For each label c in the training set, get label_attribute{} through get_label_attributes, then res[i, :] = label_attribute[c].
    Input: train_mask_zs, labels, features
    Output: Y_semantic [l, ft]: tensor
    '''
    X = torch.LongTensor(range(features.shape[0]))
    nodeids = []
    labellist = []
    for i in X[train_mask_zs].numpy().tolist():
        nodeids.append(str(i))
    for i in labels[train_mask_zs].cpu().numpy().tolist():
        labellist.append([str(i)])
    # 1. get the semantic knowledge (class centers) of all seen classes
    label_attribute = get_label_attributes(train_mask_zs=train_mask_zs, nodeids=nodeids, labellist=labellist, features=features.cpu().numpy())
    # 2. replace the original labels by their class centers (semantic knowledge)
    res = np.zeros([len(nodeids), features.shape[1]])
    for i, node_labels in enumerate(labellist):
        # support multiple labels per node
        c = len(node_labels)
        temp = np.zeros([c, features.shape[1]])
        for ii, label in enumerate(node_labels):
            temp[ii, :] = label_attribute[int(label)]
        res[i, :] = np.mean(temp, axis=0)
    if cuda:
        res = torch.FloatTensor(res).cuda()
    else:
        res = torch.FloatTensor(res)
    return res
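
# --- Hypothetical toy check of the class-center replacement (not in the original file) ---
if __name__ == '__main__':
    toy_features = torch.tensor([[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
    toy_labels = torch.tensor([0, 0, 1])
    toy_mask = torch.tensor([True, True, False])
    # both labeled nodes share class 0, so both rows equal its center [0.5, 0.5]
    print(get_labeled_nodes_label_attribute(toy_mask, toy_labels, toy_features, cuda=False))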
# main.py
import torch
import torch.nn as nn
from classify import evaluate_embeds
from label_utils import remove_unseen_classes_from_training, get_labeled_nodes_label_attribute
from utils import load_data, svd_feature, process_classids
from model import GCN, RECT_L


def main(args):
    g, features, labels, train_mask, test_mask, n_classes, cuda = load_data(args)
    # adopt any number of classes as the unseen classes (the first three classes by default)
    removed_class = args.removed_class
    if len(removed_class) > n_classes:
        raise ValueError('the number of unseen classes ({}) is greater than the number of classes'.format(len(removed_class)))
    for i in removed_class:
        if i not in labels:
            raise ValueError('class out of bounds: {}'.format(i))
    # remove these unseen classes from the training set, to construct the zero-shot label setting
    train_mask_zs = remove_unseen_classes_from_training(train_mask=train_mask, labels=labels, removed_class=removed_class)
    print('after removing the unseen classes, seen-class labeled node num:', sum(train_mask_zs).item())
    if args.model_opt == 'RECT-L':
        model = RECT_L(g=g, in_feats=args.n_hidden, n_hidden=args.n_hidden, activation=nn.PReLU())
        if cuda:
            model.cuda()
        features = svd_feature(features=features, d=args.n_hidden)
        attribute_labels = get_labeled_nodes_label_attribute(train_mask_zs=train_mask_zs, labels=labels, features=features, cuda=cuda)
        loss_fcn = nn.MSELoss(reduction='sum')
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        for epoch in range(args.n_epochs):
            model.train()
            optimizer.zero_grad()
            logits = model(features)
            # RECT-L regresses the class-center (semantic) vectors of the seen-class training nodes
            loss_train = loss_fcn(attribute_labels, logits[train_mask_zs])
            print('Epoch {:d} | Train Loss {:.5f}'.format(epoch + 1, loss_train.item()))
            loss_train.backward()
            optimizer.step()
        model.eval()
        embeds = model.embed(features)
    elif args.model_opt == 'GCN':
        model = GCN(g=g, in_feats=features.shape[1],
                    n_hidden=args.n_hidden, n_classes=n_classes - len(removed_class),
                    activation=nn.PReLU(), dropout=args.dropout)
        if cuda:
            model.cuda()
        loss_fcn = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        # reorder the remaining (seen) class ids so they are contiguous from 0
        labels_train = process_classids(labels_temp=labels[train_mask_zs])
        for epoch in range(args.n_epochs):
            model.train()
            logits = model(features)
            loss_train = loss_fcn(logits[train_mask_zs], labels_train)
            optimizer.zero_grad()
            print('Epoch {:d} | Train Loss {:.5f}'.format(epoch + 1, loss_train.item()))
            loss_train.backward()
            optimizer.step()
        model.eval()
        embeds = model.embed(features)
    elif args.model_opt == 'NodeFeats':
        embeds = svd_feature(features)
    # evaluate the embedding results with the original balanced labels, to assess the model performance (as suggested in the paper)
    res = evaluate_embeds(features=embeds, labels=labels, train_mask=train_mask, test_mask=test_mask, n_classes=n_classes, cuda=cuda)
    print("Test Accuracy of {:s}: {:.4f}".format(args.model_opt, res))


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='MODEL')
    parser.add_argument("--model-opt", type=str, default='RECT-L',
                        choices=['RECT-L', 'GCN', 'NodeFeats'],
                        help="model option")
    parser.add_argument("--dataset", type=str, default='cora',
                        choices=['cora', 'citeseer'],
                        help="dataset")
    parser.add_argument("--dropout", type=float, default=0.0,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=0,
                        help="gpu id; use a negative value for CPU")
    parser.add_argument("--removed-class", type=int, nargs='*', default=[0, 1, 2],
                        help="the unseen classes to remove from the training set")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=200,
                        help="number of hidden gcn units")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="weight for L2 loss")
    args = parser.parse_args()
    main(args)
# model.py
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv


class GCN(nn.Module):
    def __init__(self, g, in_feats, n_hidden, n_classes, activation, dropout):
        super(GCN, self).__init__()
        self.g = g
        self.gcn_1 = GraphConv(in_feats, n_hidden, activation=activation)
        self.gcn_2 = GraphConv(n_hidden, n_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, features):
        h = self.gcn_1(self.g, features)
        h = self.dropout(h)
        preds = self.gcn_2(self.g, h)
        return preds

    def embed(self, inputs):
        h_1 = self.gcn_1(self.g, inputs)
        return h_1.detach()


class RECT_L(nn.Module):
    def __init__(self, g, in_feats, n_hidden, activation, dropout=0.0):
        super(RECT_L, self).__init__()
        self.g = g
        self.gcn_1 = GraphConv(in_feats, n_hidden, activation=activation)
        self.fc = nn.Linear(n_hidden, in_feats)
        self.dropout = dropout
        nn.init.xavier_uniform_(self.fc.weight.data)

    def forward(self, inputs):
        h_1 = self.gcn_1(self.g, inputs)
        h_1 = F.dropout(h_1, p=self.dropout, training=self.training)
        preds = self.fc(h_1)
        return preds

    # Detach the return variables
    def embed(self, inputs):
        h_1 = self.gcn_1(self.g, inputs)
        return h_1.detach()
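
# --- Hypothetical smoke test for RECT_L on a tiny graph (not in the original file) ---
if __name__ == '__main__':
    import dgl
    import torch
    g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))  # a 3-node cycle
    feats = torch.randn(3, 8)
    rect = RECT_L(g=g, in_feats=8, n_hidden=8, activation=nn.PReLU())
    print(rect(feats).shape)        # torch.Size([3, 8]): reconstructed feature space
    print(rect.embed(feats).shape)  # torch.Size([3, 8]): hidden embeddings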
# utils.py
import torch
import dgl
from dgl.data import CoraGraphDataset, CiteseerGraphDataset


def load_data(args):
    if args.dataset == 'cora':
        data = CoraGraphDataset()
    elif args.dataset == 'citeseer':
        data = CiteseerGraphDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    g = data[0]
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        g = g.int().to(args.gpu)
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    test_mask = g.ndata['test_mask']
    g = dgl.add_self_loop(g)
    return g, features, labels, train_mask, test_mask, data.num_classes, cuda


def svd_feature(features, d=200):
    '''Reduce node features to d dimensions (200 by default) via truncated SVD, to avoid the curse of dimensionality.'''
    if features.shape[1] <= d:
        return features
    U, S, VT = torch.svd(features)
    res = torch.mm(U[:, 0:d], torch.diag(S[0:d]))
    return res


def process_classids(labels_temp):
    '''Reorder the remaining classes after the unseen classes are removed, so the class ids are contiguous from 0.
    Input: the labels with only the unseen classes removed
    Output: the labels with reordered class ids
    '''
    labeldict = {}
    num = 0
    for i in labels_temp:
        labeldict[int(i)] = 1
    labellist = sorted(labeldict)
    for label in labellist:
        labeldict[int(label)] = num
        num = num + 1
    for i in range(labels_temp.numel()):
        labels_temp[i] = labeldict[int(labels_temp[i])]
    return labels_temp
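
# --- Hypothetical quick check of process_classids (not in the original file) ---
if __name__ == '__main__':
    # classes {2, 4, 5} are remapped to contiguous ids {0, 1, 2}
    print(process_classids(torch.tensor([4, 2, 5, 4])))  # tensor([1, 0, 2, 1])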