"""
Code adapted from https://github.com/CRIPAC-DIG/GRACE
Linear evaluation of learned node embeddings.
"""

import functools

8
import numpy as np
9
from sklearn.linear_model import LogisticRegression
10
11
from sklearn.metrics import f1_score
from sklearn.model_selection import GridSearchCV, train_test_split
12
from sklearn.multiclass import OneVsRestClassifier
13
from sklearn.preprocessing import OneHotEncoder, normalize
14
15
16
17
18
19
20
21
22
23
24


def repeat(n_times):
    """Decorator factory: run the wrapped function ``n_times`` and aggregate.

    The wrapped function must return a dict mapping metric names to numbers.
    The returned wrapper calls it ``n_times`` with the same arguments. It then
    computes the mean and (population, ``ddof=0``) standard deviation of each
    metric across the runs, prints a one-line summary via
    ``print_statistics``, and returns ``{metric: {"mean": ..., "std": ...}}``.

    Args:
        n_times: number of repetitions; must be at least 1.

    Raises:
        ValueError: if ``n_times`` < 1, since there would be no run to
            aggregate.
    """
    if n_times < 1:
        raise ValueError(f"n_times must be >= 1, got {n_times}")

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            results = [f(*args, **kwargs) for _ in range(n_times)]
            # Metric names come from the first run; every run is expected to
            # return the same keys.
            statistics = {}
            for key in results[0]:
                values = [r[key] for r in results]
                statistics[key] = {
                    "mean": np.mean(values),
                    "std": np.std(values),
                }
            print_statistics(statistics, f.__name__)
            return statistics

        return wrapper

    return decorator


def prob_to_one_hot(y_pred):
    """Convert a matrix of per-class scores into a boolean one-hot matrix.

    Args:
        y_pred: 2-D array-like of shape (n_samples, n_classes), e.g. class
            probabilities; larger means more likely.

    Returns:
        Boolean ndarray of the same shape with exactly one ``True`` per row.
        The ``True`` sits at that row's arg-max column; ties resolve to the
        first maximum, as with ``np.argmax``.
    """
    y_pred = np.asarray(y_pred)
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
    # supported way to request a boolean dtype.
    ret = np.zeros(y_pred.shape, dtype=bool)
    # Advanced indexing sets one cell per row without a Python-level loop.
    ret[np.arange(y_pred.shape[0]), np.argmax(y_pred, axis=1)] = True
    return ret


def print_statistics(statistics, function_name):
    """Print a one-line summary of aggregated metrics.

    The format is ``(E) | <function_name>: k1=mean+-std, k2=mean+-std``.
    Each number is printed with 4 decimals, and the line ends with a newline.

    Args:
        statistics: mapping from metric name to a dict with ``"mean"`` and
            ``"std"`` entries.
        function_name: label printed before the metrics.
    """
    parts = [
        f"{key}={stat['mean']:.4f}+-{stat['std']:.4f}"
        for key, stat in statistics.items()
    ]
    # Build the whole line first so that a single print always terminates it,
    # even when ``statistics`` is empty.
    print(f"(E) | {function_name}:", ", ".join(parts))


def _as_numpy(value):
    """Return ``value`` as a NumPy array.

    Objects exposing ``.detach()`` (tensors, possibly on a GPU) are detached
    and copied to host memory first, the same way embeddings and labels are
    converted. Everything else goes through ``np.asarray``.
    """
    if hasattr(value, "detach"):
        value = value.detach().cpu().numpy()
    return np.asarray(value)


@repeat(3)
def label_classification(
    embeddings, y, train_mask, test_mask, split="random", ratio=0.1
):
    """Fit a one-vs-rest logistic-regression probe on frozen node embeddings.

    Embeddings are L2-normalised per row and labels are one-hot encoded. The
    regularisation strength ``C`` is chosen by 5-fold grid search over
    ``2**-10 .. 2**9`` on the training portion. Micro- and macro-averaged F1
    are then computed on the test portion. The ``@repeat(3)`` decorator runs
    this three times and reports the mean/std of each score.

    Args:
        embeddings: embedding tensor (anything with
            ``.detach().cpu().numpy()``), one row per node.
        y: label tensor with one label per node.
        train_mask, test_mask: node selectors used only when
            ``split == "public"``; tensors are moved to host before indexing.
        split: ``"random"`` draws a fresh random split on every call;
            ``"public"`` uses the given masks.
        ratio: fraction of nodes used for training when ``split == "random"``.

    Returns:
        dict ``{"F1Mi": micro_f1, "F1Ma": macro_f1}`` for this single call
        (the decorator aggregates these across repetitions).

    Raises:
        ValueError: if ``split`` is neither ``"random"`` nor ``"public"``.
    """
    X = normalize(_as_numpy(embeddings), norm="l2")

    Y = _as_numpy(y).reshape(-1, 1)
    # Produce a dense boolean indicator matrix so that OneVsRestClassifier
    # fits one binary classifier per class. ``np.bool`` was removed in
    # NumPy 1.24, hence the builtin ``bool``.
    Y = OneHotEncoder(categories="auto").fit_transform(Y).toarray().astype(bool)

    if split == "random":
        X_train, X_test, y_train, y_test = train_test_split(
            X, Y, test_size=1 - ratio
        )
    elif split == "public":
        train_idx = _as_numpy(train_mask)
        test_idx = _as_numpy(test_mask)
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = Y[train_idx], Y[test_idx]
    else:
        # Without this branch, an unknown split later failed with an opaque
        # UnboundLocalError on X_train.
        raise ValueError(f"split must be 'random' or 'public', got {split!r}")

    # Tune the inverse regularisation strength C over 2**-10 .. 2**9.
    clf = GridSearchCV(
        estimator=OneVsRestClassifier(LogisticRegression(solver="liblinear")),
        param_grid=dict(estimator__C=2.0 ** np.arange(-10, 10)),
        n_jobs=8,
        cv=5,
        verbose=0,
    )
    clf.fit(X_train, y_train)

    # Keep only the most probable class per node before scoring.
    y_pred = prob_to_one_hot(clf.predict_proba(X_test))

    micro = f1_score(y_test, y_pred, average="micro")
    macro = f1_score(y_test, y_pred, average="macro")
    return {"F1Mi": micro, "F1Ma": macro}