"""
[Combining Label Propagation and Simple Models Out-performs
Graph Neural Networks](https://arxiv.org/abs/2010.13993)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from dgl.mock_sparse import create_from_coo, diag, identity
from torch.optim import Adam


###############################################################################
# (HIGHLIGHT) Compute Label Propagation with Sparse Matrix API
###############################################################################
@torch.no_grad()
def label_propagation(A_hat, label, num_layers=20, alpha=0.9):
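    """Iteratively propagate `label` over the normalized adjacency matrix:
    Y <- alpha * A_hat @ Y + (1 - alpha) * label, clamping Y to [0, 1]
    after every step.
    """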
    Y = label
    for _ in range(num_layers):
        Y = alpha * A_hat @ Y + (1 - alpha) * label
        Y = Y.clamp_(0.0, 1.0)
    return Y


def correct(A_hat, label, soft_label, mask):
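    """The "Correct" step: smooth the residual errors of the base model on
    the training nodes over the graph, then add them back to the soft
    predictions with a per-node scale (the paper's autoscale strategy).
    """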
    # Compute error.
    error = torch.zeros_like(soft_label)
    error[mask] = label[mask] - soft_label[mask]

    # Smooth error.
    smoothed_error = label_propagation(A_hat, error)

    # Autoscale.
    sigma = error[mask].abs()
    sigma = sigma.sum() / sigma.shape[0]
    scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True)
    scale[scale.isinf() | (scale > 1000)] = 1.0

    # Correct.
    result = soft_label + scale * smoothed_error
    return result


def smooth(A_hat, label, soft_label, mask):
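    """The "Smooth" step: overwrite the soft predictions of the training
    nodes with their ground-truth one-hot labels (in place), then propagate
    the result over the graph.
    """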
    soft_label[mask] = label[mask].float()
    return label_propagation(A_hat, soft_label)


def evaluate(g, pred):
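    """Compute prediction accuracy on the validation and test sets."""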
    label = g.ndata["label"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    # Compute accuracy on validation/test set.
    val_acc = (pred[val_mask] == label[val_mask]).float().mean()
    test_acc = (pred[test_mask] == label[test_mask]).float().mean()
    return val_acc, test_acc


def train(base_model, g, X):
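    """Train the base model with cross-entropy loss on the training nodes,
    reporting accuracy after every epoch. Returns the final logits.
    """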
    label = g.ndata["label"]
    train_mask = g.ndata["train_mask"]

    optimizer = Adam(base_model.parameters(), lr=0.01)

    for epoch in range(10):
        # Forward.
        base_model.train()
        logits = base_model(X)

        # Compute loss with nodes in training set.
        loss = F.cross_entropy(logits[train_mask], label[train_mask])

        # Backward.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute prediction.
        base_model.eval()
        logits = base_model(X)
        pred = logits.argmax(dim=1)

        # Evaluate the prediction.
        val_acc, test_acc = evaluate(g, pred)
        print(
            f"Base model, In epoch {epoch}, loss: {loss:.3f}, "
            f"val acc: {val_acc:.3f}, test acc: {test_acc:.3f}"
        )
    return logits


if __name__ == "__main__":
    # Use the GPU to accelerate training if CUDA is available; otherwise,
    # fall back to the CPU.
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load graph from the existing dataset.
    dataset = CoraGraphDataset()
    g = dataset[0].to(dev)

    # Create the sparse adjacency matrix A.
    src, dst = g.edges()
    N = g.num_nodes()
    A = create_from_coo(dst, src, shape=(N, N))

    # Calculate the symmetrically normalized adjacency matrix.
    I = identity(A.shape, device=dev)
    A_hat = A + I
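    # D_hat holds the inverse square root of the node degrees of A + I, so
    # the product below is deg^(-1/2) @ (A + I) @ deg^(-1/2).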
    D_hat = diag(A_hat.sum(dim=1)) ** -0.5
    A_hat = D_hat @ A_hat @ D_hat

    # Create models.
    X = g.ndata["feat"]
    in_size = X.shape[1]
    out_size = dataset.num_classes
    base_model = nn.Linear(in_size, out_size).to(dev)

    # Stage 1: Train the base model.
    logits = train(base_model, g, X)

    # Stage 2: Correct and Smooth.
    soft_label = F.softmax(logits, dim=1)
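    # One-hot encode the ground-truth labels so they can be combined with
    # the soft predictions during propagation.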
    label = F.one_hot(g.ndata["label"])
    soft_label = correct(A_hat, label, soft_label, g.ndata["train_mask"])
    soft_label = smooth(A_hat, label, soft_label, g.ndata["train_mask"])
    pred = soft_label.argmax(dim=1)
    val_acc, test_acc = evaluate(g, pred)
    print(f"val acc: {val_acc:.3f}, test acc: {test_acc:.3f}")