from typing import Optional, Tuple

import torch
from torch import Tensor


def index2mask(idx: Tensor, size: int) -> Tensor:
    """Build a boolean mask of length ``size`` that is ``True`` at ``idx``.

    The mask is allocated on the same device as ``idx``. Every position not
    selected by ``idx`` is ``False``.
    """
    out = idx.new_zeros(size, dtype=torch.bool)
    out[idx] = True
    return out


def compute_acc(logits: Tensor, y: Tensor,
                mask: Optional[Tensor] = None) -> float:
    """Score predictions against targets.

    If ``y`` is 1-D (class indices), return the accuracy of
    ``logits.argmax(dim=-1)``. Otherwise treat ``y`` as multi-label targets
    and return the micro-averaged F1 score. In that case predictions are
    ``logits > 0`` and ground truth is ``y > 0.5``, with TP/FP/FN summed over
    all elements.

    Args:
        logits: Raw model outputs.
        y: Targets, either 1-D class indices or a multi-label target tensor.
        mask: Optional tensor used to index the rows of ``logits`` and ``y``
            that should be scored.

    Returns:
        The score as a Python float. Returns ``0.0`` when no rows are
        selected. In the F1 case it also returns ``0.0`` when there are no
        true positives, where precision and recall would otherwise divide by
        zero.
    """
    if mask is not None:
        logits, y = logits[mask], y[mask]

    if y.dim() == 1:
        num_rows = y.size(0)
        if num_rows == 0:
            # Nothing to score (e.g. an empty mask); avoid ZeroDivisionError.
            return 0.0
        return int(logits.argmax(dim=-1).eq(y).sum()) / num_rows

    y_pred = logits > 0
    y_true = y > 0.5

    tp = int((y_true & y_pred).sum())
    if tp == 0:
        # Precision and/or recall is 0 or undefined here, so F1 is 0 by
        # convention; the formulas below would otherwise divide by zero.
        return 0.0
    fp = int((~y_true & y_pred).sum())
    fn = int((y_true & ~y_pred).sum())
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * (precision * recall) / (precision + recall)


def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
              num_splits: int = 20) -> Tuple[Tensor, Tensor, Tensor]:
    """Generate ``num_splits`` random per-class train/val/test splits.

    For each class ``c`` in ``0 .. y.max()``, and independently for every
    split, the nodes labelled ``c`` are shuffled. The first
    ``train_per_class`` of them go to the training mask and the next
    ``val_per_class`` go to the validation mask. Every node in neither set
    goes to the test mask. A class with fewer nodes contributes only the
    nodes it has.

    Args:
        y: Integer class labels (assumed 1-D and non-negative).
        train_per_class: Training nodes drawn per class and split.
        val_per_class: Validation nodes drawn per class and split.
        num_splits: Number of independent splits (mask columns).

    Returns:
        ``(train_mask, val_mask, test_mask)``. Each is a boolean tensor of
        shape ``(y.size(0), num_splits)`` allocated on ``y``'s device, and
        column ``i`` holds split ``i``.
    """
    # Allocate on y's device: scatter_ below requires the mask and the
    # (y-derived) indices to live on the same device.
    train_mask = torch.zeros(y.size(0), num_splits, dtype=torch.bool,
                             device=y.device)
    val_mask = torch.zeros_like(train_mask)

    # y.max() raises on an empty tensor; there is nothing to split then.
    num_classes = int(y.max()) + 1 if y.numel() > 0 else 0

    for c in range(num_classes):
        idx = (y == c).nonzero(as_tuple=False).view(-1)
        # One independent permutation per split, stacked column-wise, so
        # idx becomes (num_nodes_in_class, num_splits).
        perm = torch.stack(
            [torch.randperm(idx.size(0)) for _ in range(num_splits)], dim=1)
        idx = idx[perm.to(idx.device)]

        # Disjoint slices of the same permutation keep train/val disjoint.
        train_idx = idx[:train_per_class]
        train_mask.scatter_(0, train_idx, True)
        val_idx = idx[train_per_class:train_per_class + val_per_class]
        val_mask.scatter_(0, val_idx, True)

    test_mask = ~(train_mask | val_mask)

    return train_mask, val_mask, test_mask