"""
Code adapted from https://github.com/CRIPAC-DIG/GRACE
Linear evaluation on learned node embeddings
"""

import functools

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import normalize, OneHotEncoder


def repeat(n_times):
    """Decorator: run the wrapped evaluation ``n_times`` and aggregate per-metric mean/std."""

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            results = [f(*args, **kwargs) for _ in range(n_times)]
            statistics = {}
            for key in results[0].keys():
                values = [r[key] for r in results]
                statistics[key] = {
                    "mean": np.mean(values),
                    "std": np.std(values),
                }
            print_statistics(statistics, f.__name__)
            return statistics

        return wrapper

    return decorator


def prob_to_one_hot(y_pred):
    """Convert rows of class probabilities into boolean one-hot rows via argmax."""
    ret = np.zeros(y_pred.shape, np.bool_)
    indices = np.argmax(y_pred, axis=1)
    for i in range(y_pred.shape[0]):
        ret[i][indices[i]] = True
    return ret


def print_statistics(statistics, function_name):
    """Print each metric as ``mean+-std`` on one line, prefixed with the caller's name."""
    print(f"(E) | {function_name}:", end=" ")
    for i, key in enumerate(statistics.keys()):
        mean = statistics[key]["mean"]
        std = statistics[key]["std"]
        print(f"{key}={mean:.4f}+-{std:.4f}", end="")
        if i != len(statistics.keys()) - 1:
            print(",", end=" ")
        else:
            print()


@repeat(3)
def label_classification(
    embeddings, y, train_mask, test_mask, split="random", ratio=0.1
):
    """Fit a logistic-regression probe on frozen embeddings and report F1.

    ``split="random"`` draws a random train split of size ``ratio``;
    ``split="public"`` uses the provided boolean masks.
    """
    X = embeddings.detach().cpu().numpy()
    Y = y.detach().cpu().numpy()
    Y = Y.reshape(-1, 1)
    onehot_encoder = OneHotEncoder(categories="auto").fit(Y)
    Y = onehot_encoder.transform(Y).toarray().astype(np.bool_)

    X = normalize(X, norm="l2")

    if split == "random":
        X_train, X_test, y_train, y_test = train_test_split(
            X, Y, test_size=1 - ratio
        )
    elif split == "public":
        X_train = X[train_mask]
        X_test = X[test_mask]
        y_train = Y[train_mask]
        y_test = Y[test_mask]
    else:
        raise ValueError(f"Unknown split: {split}")

    logreg = LogisticRegression(solver="liblinear")
    c = 2.0 ** np.arange(-10, 10)  # grid of inverse regularization strengths C

    clf = GridSearchCV(
        estimator=OneVsRestClassifier(logreg),
        param_grid=dict(estimator__C=c),
        n_jobs=8,
        cv=5,
        verbose=0,
    )
    clf.fit(X_train, y_train)

    y_pred = clf.predict_proba(X_test)
    y_pred = prob_to_one_hot(y_pred)

    micro = f1_score(y_test, y_pred, average="micro")
    macro = f1_score(y_test, y_pred, average="macro")

    return {"F1Mi": micro, "F1Ma": macro}
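

# A minimal smoke-test sketch (an assumption, not part of the original file):
# it exercises the probe on random embeddings and labels. The names below
# (`emb`, `labels`, `n`, `d`, `n_classes`) are hypothetical; real pipelines
# would pass a GRACE-style encoder's output and the dataset's own masks.
if __name__ == "__main__":
    import torch

    n, d, n_classes = 200, 32, 4
    emb = torch.randn(n, d)  # stand-in for learned node embeddings
    labels = torch.randint(0, n_classes, (n,))
    # Boolean numpy masks; torch bool masks also index the numpy arrays above.
    train_mask = np.zeros(n, dtype=bool)
    train_mask[: n // 2] = True
    test_mask = ~train_mask
    # Runs 3 times via @repeat and prints mean+-std of micro/macro F1.
    label_classification(emb, labels, train_mask, test_mask, split="public")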