"""Code adapted from https://github.com/fanyun-sun/InfoGraph"""
import math

4
5
6
import numpy as np
import torch as th
import torch.nn.functional as F
7
8
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, StratifiedKFold
9
10
from sklearn.svm import LinearSVC

11
12
13
14

def linearsvc(embeds, labels, *, n_splits=10, random_state=None):
    """Evaluate embeddings with a cross-validated linear SVM.

    The data is split into stratified, shuffled folds. On every fold a
    ``LinearSVC`` is tuned over ``C`` with an inner 5-fold grid search on
    the training part, and the tuned classifier is scored on the held-out
    part.

    Args:
        embeds: Tensor of embeddings, one row per sample.
        labels: Tensor of class labels, one entry per sample.
        n_splits: Number of outer stratified folds (default 10).
        random_state: Seed forwarded to ``StratifiedKFold`` for the fold
            shuffling; ``None`` (default) gives a different split each call.

    Returns:
        Tuple ``(mean_accuracy, std_accuracy)`` over the outer folds.
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    x = embeds.detach().cpu().numpy()
    y = labels.detach().cpu().numpy()
    param_grid = {"C": [0.001, 0.01, 0.1, 1, 10, 100, 1000]}

    folds = StratifiedKFold(
        n_splits=n_splits, shuffle=True, random_state=random_state
    )
    accuracies = []
    for train_index, test_index in folds.split(x, y):
        # A fresh grid search per fold keeps model selection inside the
        # training split, so the held-out split is never used for tuning.
        classifier = GridSearchCV(
            LinearSVC(), param_grid, cv=5, scoring="accuracy", verbose=0
        )
        classifier.fit(x[train_index], y[train_index])
        predictions = classifier.predict(x[test_index])
        accuracies.append(accuracy_score(y[test_index], predictions))
    return np.mean(accuracies), np.std(accuracies)

28

29
30
31
32
33
34
35
36
def get_positive_expectation(
    p_samples: th.Tensor, average: bool = True
) -> th.Tensor:
    """Compute the positive part of a JS divergence.

    Evaluates ``log(2) - softplus(-p_samples)`` element-wise.

    Args:
        p_samples: Tensor of scores for the positive samples.
        average: If True, return the mean over all elements; otherwise
            return the element-wise values.

    Returns:
        th.Tensor: A scalar tensor when ``average`` is True, else a tensor
        with the same shape as ``p_samples``.
    """
    expectation = math.log(2.0) - F.softplus(-p_samples)
    return expectation.mean() if average else expectation


def get_negative_expectation(
    q_samples: th.Tensor, average: bool = True
) -> th.Tensor:
    """Compute the negative part of a JS divergence.

    Evaluates ``softplus(-q_samples) + q_samples - log(2)`` element-wise.

    Args:
        q_samples: Tensor of scores for the negative samples.
        average: If True, return the mean over all elements; otherwise
            return the element-wise values.

    Returns:
        th.Tensor: A scalar tensor when ``average`` is True, else a tensor
        with the same shape as ``q_samples``.
    """
    expectation = F.softplus(-q_samples) + q_samples - math.log(2.0)
    return expectation.mean() if average else expectation


def local_global_loss_(l_enc, g_enc, graph_id):
    """Compute the local-global JS-divergence loss for node/graph encodings.

    Every (node, graph) pair is scored by the dot product of the two
    encodings. Pairs where the node belongs to the graph are treated as
    positives; all remaining pairs are negatives.

    Args:
        l_enc: Node-level encodings, shape ``(num_nodes, dim)``.
        g_enc: Graph-level encodings, shape ``(num_graphs, dim)``.
        graph_id: For each node, the row index in ``g_enc`` of the graph
            the node belongs to.

    Returns:
        th.Tensor: Scalar loss ``E_neg - E_pos``.
    """
    num_graphs = g_enc.shape[0]
    num_nodes = l_enc.shape[0]
    device = g_enc.device

    # Build both masks with one indexed assignment instead of a per-node
    # Python loop of scalar tensor writes.
    graph_index = th.as_tensor(
        graph_id, dtype=th.long, device=device
    ).reshape(-1)
    node_index = th.arange(graph_index.numel(), device=device)

    pos_mask = th.zeros((num_nodes, num_graphs), device=device)
    pos_mask[node_index, graph_index] = 1.0
    neg_mask = 1.0 - pos_mask

    res = th.mm(l_enc, g_enc.t())

    # Masked-out pairs become a score of 0, at which each JS term evaluates
    # to log(2) - softplus(0), i.e. zero, so summing over the full matrix
    # only counts the selected pairs.
    E_pos = get_positive_expectation(res * pos_mask, average=False).sum()
    E_pos = E_pos / num_nodes

    E_neg = get_negative_expectation(res * neg_mask, average=False).sum()
    # With a single graph there are no negative pairs; clamp the
    # denominator so the (zero) negative sum does not turn into 0/0 = NaN.
    num_neg_pairs = num_nodes * (num_graphs - 1)
    E_neg = E_neg / max(num_neg_pairs, 1)

    return E_neg - E_pos