utils.py 818 Bytes
Newer Older
lt610's avatar
lt610 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
import random
from torch.nn import functional as F
import torch


def evaluate(model, graph, feats, labels, idxs):
    """Evaluate ``model`` on several node subsets using a single forward pass.

    The model is switched to eval mode and run once under ``torch.no_grad()``.
    The resulting logits are then scored separately on each selector in
    ``idxs``.

    Args:
        model: callable invoked as ``model(graph, feats)``; must expose ``eval()``.
        graph: first positional argument forwarded to ``model``.
        feats: second positional argument forwarded to ``model``.
        labels: target class indices, indexed with each entry of ``idxs``.
        idxs: iterable of selectors (index tensor / list, or boolean mask)
            used to index both the logits and ``labels``.

    Returns:
        A flat tuple ``(loss_0, acc_0, loss_1, acc_1, ...)`` with one
        ``(loss, acc)`` pair per selector. ``loss`` is the cross-entropy tensor
        and ``acc`` is a Python float. An empty selection yields ``acc = nan``
        instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    with torch.no_grad():
        logits = model(graph, feats)
        results = ()
        for idx in idxs:
            sub_logits = logits[idx]
            sub_labels = labels[idx]
            loss = F.cross_entropy(sub_logits, sub_labels)
            # Count the rows actually selected: ``len(idx)`` would be the full
            # mask length when ``idx`` is a boolean mask, deflating accuracy.
            num = sub_labels.shape[0]
            correct = torch.sum(sub_logits.argmax(dim=1) == sub_labels).item()
            acc = correct / num if num else float("nan")
            results += (loss, acc)
    return results


def generate_random_seeds(seed, nums):
    """Derive ``nums`` integer seeds in ``[1, 999999999]`` from ``seed``.

    The same ``(seed, nums)`` always produces the same list. Note that this
    reseeds Python's global ``random`` module with ``seed`` as a side effect,
    so the global RNG state afterwards is determined by ``seed`` and ``nums``.
    """
    random.seed(seed)
    draw = random.randint
    seeds = [draw(1, 999999999) for _ in range(nums)]
    return seeds


def set_random_state(seed):
    """Seed the RNGs used during training so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's generator.
    When CUDA is available it also seeds all GPUs and configures cuDNN for
    deterministic algorithm selection and execution.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # ``deterministic`` alone is not enough: with benchmarking enabled,
        # cuDNN may pick a different algorithm per run, so disable it too.
        torch.backends.cudnn.benchmark = False