from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch_sparse import SparseTensor


def index2mask(idx: Tensor, size: int) -> Tensor:
    """Convert a 1-D tensor of indices into a boolean mask of length ``size``.

    Positions listed in ``idx`` are set to ``True``; all others are
    ``False``. The mask is created on the same device as ``idx``.
    """
    out = idx.new_zeros(size, dtype=torch.bool)
    out[idx] = True
    return out


def compute_micro_f1(logits: Tensor, y: Tensor,
                     mask: Optional[Tensor] = None) -> float:
    """Compute the micro-averaged F1 score of ``logits`` against targets ``y``.

    Args:
        logits: Raw model outputs, ``[num_examples, num_classes]``.
        y: Targets. 1-D integer class labels for single-label tasks, or a
            2-D multi-hot float tensor for multilabel tasks.
        mask: Optional boolean mask selecting the examples to evaluate.

    Returns:
        The micro-F1 score in ``[0, 1]``. For single-label targets this
        equals plain accuracy. Returns ``0.`` for an empty selection or
        when precision/recall are undefined (no positive predictions or
        no positive targets).
    """
    if mask is not None:
        logits, y = logits[mask], y[mask]

    # Guard the empty selection: the single-label branch would otherwise
    # divide by y.size(0) == 0, while the multilabel branch already
    # returned 0. — make both branches consistent.
    if y.numel() == 0:
        return 0.

    if y.dim() == 1:
        # Single-label: micro-F1 reduces to accuracy.
        return int(logits.argmax(dim=-1).eq(y).sum()) / y.size(0)
    else:
        # Multilabel: a positive logit counts as a positive prediction.
        y_pred = logits > 0
        y_true = y > 0.5

        tp = int((y_true & y_pred).sum())
        fp = int((~y_true & y_pred).sum())
        fn = int((y_true & ~y_pred).sum())

        try:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            return 2 * (precision * recall) / (precision + recall)
        except ZeroDivisionError:
            # No positive predictions and/or no positive targets.
            return 0.


def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
              num_splits: int = 20) -> Tuple[Tensor, Tensor, Tensor]:
    """Generate ``num_splits`` random train/val/test splits from labels ``y``.

    For every class and every split, ``train_per_class`` examples are drawn
    for training and the next ``val_per_class`` for validation; all remaining
    examples form the test set of that split.

    Args:
        y: 1-D integer class-label tensor of shape ``[num_examples]``.
        train_per_class: Training examples sampled per class and split.
        val_per_class: Validation examples sampled per class and split.
        num_splits: Number of independent random splits to create.

    Returns:
        ``(train_mask, val_mask, test_mask)``, each a boolean tensor of
        shape ``[num_examples, num_splits]``.
    """
    num_examples = y.size(0)
    num_classes = int(y.max()) + 1

    train_mask = torch.zeros(num_examples, num_splits, dtype=torch.bool)
    val_mask = torch.zeros(num_examples, num_splits, dtype=torch.bool)

    for label in range(num_classes):
        class_idx = (y == label).nonzero(as_tuple=False).view(-1)
        # One independent permutation per split, stacked column-wise so that
        # shuffled has shape [class_size, num_splits].
        perms = torch.stack(
            [torch.randperm(class_idx.size(0)) for _ in range(num_splits)],
            dim=1)
        shuffled = class_idx[perms]

        train_mask.scatter_(0, shuffled[:train_per_class], True)
        val_end = train_per_class + val_per_class
        val_mask.scatter_(0, shuffled[train_per_class:val_end], True)

    # Everything not used for training or validation is test data.
    test_mask = ~(train_mask | val_mask)

    return train_mask, val_mask, test_mask


def dropout(adj_t: SparseTensor, p: float, training: bool = True):
    """Apply dropout to the non-zero entries of a sparse adjacency matrix.

    Args:
        adj_t: Sparse (transposed) adjacency matrix.
        p: Dropout probability.
        training: If ``False``, the input is returned unchanged.

    Returns:
        A new :class:`SparseTensor` with dropped entries.
    """
    if not training:
        return adj_t

    value = adj_t.storage.value()
    if value is None:
        # Unweighted graph: drop each edge independently with probability p.
        # NOTE(review): unlike the weighted branch below, surviving edges are
        # not rescaled by 1/(1-p) — presumably intentional, confirm.
        keep = torch.rand(adj_t.nnz(), device=adj_t.storage.row().device) > p
        return adj_t.masked_select_nnz(keep, layout='coo')

    # Weighted graph: standard (scaled) dropout on the edge values.
    return adj_t.set_value(F.dropout(value, p=p), layout='coo')