main.py 4.71 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.nn as nn
from classify import evaluate_embeds
from label_utils import remove_unseen_classes_from_training, get_labeled_nodes_label_attribute
from utils import load_data, svd_feature, process_classids
from model import GCN, RECT_L

def _train_rect_l(g, features, labels, train_mask_zs, cuda, args):
    """Train RECT-L on the seen-class labeled nodes and return node embeddings.

    The model is fit with a summed MSE loss between its outputs on the
    ``train_mask_zs`` nodes and the per-node label attributes produced by
    ``get_labeled_nodes_label_attribute``.
    """
    model = RECT_L(g=g, in_feats=args.n_hidden, n_hidden=args.n_hidden, activation=nn.PReLU())
    if cuda:
        model.cuda()
    # RECT-L consumes SVD-reduced features of width n_hidden (matches in_feats above).
    features = svd_feature(features=features, d=args.n_hidden)
    attribute_labels = get_labeled_nodes_label_attribute(
        train_mask_zs=train_mask_zs, labels=labels, features=features, cuda=cuda)
    loss_fcn = nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    for epoch in range(args.n_epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(features)
        # MSE is symmetric; (prediction, target) follows the PyTorch argument convention.
        loss_train = loss_fcn(logits[train_mask_zs], attribute_labels)
        print('Epoch {:d} | Train Loss {:.5f}'.format(epoch + 1, loss_train.item()))
        loss_train.backward()
        optimizer.step()

    model.eval()
    # Embeddings are only consumed for evaluation, so no autograd graph is needed.
    with torch.no_grad():
        return model.embed(features)


def _train_gcn(g, features, labels, train_mask_zs, n_seen_classes, cuda, args):
    """Train a GCN classifier over the seen classes and return node embeddings."""
    model = GCN(g=g, in_feats=features.shape[1],
                n_hidden=args.n_hidden, n_classes=n_seen_classes,
                activation=nn.PReLU(), dropout=args.dropout)
    if cuda:
        model.cuda()
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # The training labels do not change between epochs, so convert them once.
    # NOTE(review): process_classids presumably maps the seen class ids onto
    # 0..n_seen_classes-1 for CrossEntropyLoss — confirm in utils.
    labels_train = process_classids(labels_temp=labels[train_mask_zs])

    for epoch in range(args.n_epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(features)
        loss_train = loss_fcn(logits[train_mask_zs], labels_train)
        print('Epoch {:d} | Train Loss {:.5f}'.format(epoch + 1, loss_train.item()))
        loss_train.backward()
        optimizer.step()

    model.eval()
    # Embeddings are only consumed for evaluation, so no autograd graph is needed.
    with torch.no_grad():
        return model.embed(features)


def main(args):
    """Run the zero-shot experiment for ``args.model_opt`` and print test accuracy.

    The classes in ``args.removed_class`` are treated as unseen: their labeled
    nodes are removed from the training mask before the model is trained.
    The resulting embeddings are then evaluated by ``evaluate_embeds`` using
    the original (unfiltered) train/test masks and all ``n_classes`` labels.

    Args:
        args: parsed command-line namespace. The whole namespace is passed to
            ``load_data``. This function also reads ``removed_class``,
            ``model_opt``, ``n_hidden``, ``lr``, ``weight_decay``, ``n_epochs``
            and ``dropout``.

    Raises:
        ValueError: if the unseen classes would leave no seen class, if an
            unseen class id does not occur in ``labels``, or if
            ``args.model_opt`` is not one of 'RECT-L', 'GCN', 'NodeFeats'.
    """
    g, features, labels, train_mask, test_mask, n_classes, cuda = load_data(args)
    # Adopt any number of classes as the unseen classes (the first three by default).
    # Duplicates are dropped (order preserved) so they cannot inflate the unseen
    # count, which is used below to size the GCN output layer.
    removed_class = list(dict.fromkeys(args.removed_class))
    # At least one seen class must remain, otherwise there is nothing to train on.
    if len(removed_class) >= n_classes:
        raise ValueError('unseen number must be smaller than the number of classes ({}): {}'.format(
            n_classes, len(removed_class)))
    for i in removed_class:
        if i not in labels:
            raise ValueError('class out of bounds: {}'.format(i))

    # Remove the unseen classes from the training set to build the zero-shot label setting.
    train_mask_zs = remove_unseen_classes_from_training(train_mask=train_mask, labels=labels, removed_class=removed_class)
    print('after removing the unseen classes, seen class labeled node num:', train_mask_zs.sum().item())

    if args.model_opt == 'RECT-L':
        embeds = _train_rect_l(g, features, labels, train_mask_zs, cuda, args)
    elif args.model_opt == 'GCN':
        embeds = _train_gcn(g, features, labels, train_mask_zs,
                            n_classes - len(removed_class), cuda, args)
    elif args.model_opt == 'NodeFeats':
        embeds = svd_feature(features)
    else:
        raise ValueError('unknown model option: {}'.format(args.model_opt))

    # Evaluate the embeddings with the original balanced labels to assess the model (as suggested in the paper).
    res = evaluate_embeds(features=embeds, labels=labels, train_mask=train_mask, test_mask=test_mask, n_classes=n_classes, cuda=cuda)
    print("Test Accuracy of {:s}: {:.4f}".format(args.model_opt, res))
    
if __name__ == '__main__':
    import argparse

    # Command-line options as (flag, add_argument keyword arguments), registered in order.
    cli_options = [
        ("--model-opt", {"type": str, "default": 'RECT-L',
                         "choices": ['RECT-L', 'GCN', 'NodeFeats'],
                         "help": "model option"}),
        ("--dataset", {"type": str, "default": 'cora',
                       "choices": ['cora', 'citeseer'],
                       "help": "dataset"}),
        ("--dropout", {"type": float, "default": 0.0,
                       "help": "dropout probability"}),
        ("--gpu", {"type": int, "default": 0,
                   "help": "gpu"}),
        ("--removed-class", {"type": int, "nargs": '*', "default": [0, 1, 2],
                             "help": "remove the unseen classes"}),
        ("--lr", {"type": float, "default": 1e-3,
                  "help": "learning rate"}),
        ("--n-epochs", {"type": int, "default": 200,
                        "help": "number of training epochs"}),
        ("--n-hidden", {"type": int, "default": 200,
                        "help": "number of hidden gcn units"}),
        ("--weight-decay", {"type": float, "default": 5e-4,
                            "help": "Weight for L2 loss"}),
    ]

    cli = argparse.ArgumentParser(description='MODEL')
    for flag, options in cli_options:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    main(args)