from typing import Optional, Tuple

import torch
from torch import Tensor


def index2mask(idx: Tensor, size: int) -> Tensor:
    """Convert a tensor of indices into a boolean mask of length ``size``.

    Args:
        idx: Index tensor selecting the positions to set to ``True``.
        size: Length of the returned mask.

    Returns:
        A 1-D ``torch.bool`` tensor on ``idx.device`` that is ``True`` exactly
        at the positions selected by ``idx`` and ``False`` elsewhere.
    """
    mask = torch.zeros(size, dtype=torch.bool, device=idx.device)
    mask[idx] = True
    return mask


def compute_acc(logits: Tensor, y: Tensor,
                mask: Optional[Tensor] = None) -> float:
    """Compute accuracy (single-label) or micro-F1 (multi-label).

    If ``y`` is 1-D it is treated as integer class labels, and the fraction of
    rows whose ``logits.argmax(dim=-1)`` equals ``y`` is returned. Otherwise
    ``y`` is treated as a multi-label target: predictions are ``logits > 0``,
    targets are ``y > 0.5``, and the micro-averaged F1 score over all entries
    is returned.

    Args:
        logits: Model outputs, indexed along the first dimension like ``y``.
        y: Targets; 1-D class indices or a multi-label target tensor.
        mask: Optional index/boolean mask applied to both ``logits`` and ``y``
            along the first dimension before scoring.

    Returns:
        The score as a Python float. It is 0.0 when the score is undefined:
        no rows remain (single-label), or there are no true positives
        (multi-label; precision/recall would otherwise divide by zero).
    """
    if mask is not None:
        logits, y = logits[mask], y[mask]

    if y.dim() == 1:
        num_rows = y.size(0)
        if num_rows == 0:
            # Accuracy over zero samples is undefined; avoid ZeroDivisionError.
            return 0.0
        return int(logits.argmax(dim=-1).eq(y).sum()) / num_rows

    y_pred = logits > 0
    y_true = y > 0.5
    tp = int((y_true & y_pred).sum())
    fp = int((~y_true & y_pred).sum())
    fn = int((y_true & ~y_pred).sum())

    # With tp == 0, either a denominator below is zero or precision and recall
    # are both 0 (so their sum is). F1 is 0 by convention in that case.
    if tp == 0:
        return 0.0

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * (precision * recall) / (precision + recall)


def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
              num_splits: int = 20) -> Tuple[Tensor, Tensor, Tensor]:
    """Generate ``num_splits`` random per-class train/val/test splits.

    For every class ``c`` in ``range(int(y.max()) + 1)``, the nodes with label
    ``c`` are shuffled independently for each split. The first
    ``train_per_class`` go to the training mask and the next
    ``val_per_class`` go to the validation mask. Every node not chosen for
    train or val (including nodes whose label lies outside that class range)
    ends up in the test mask.

    Args:
        y: 1-D tensor of integer class labels, one per node.
        train_per_class: Training nodes drawn per class and split.
        val_per_class: Validation nodes drawn per class and split.
        num_splits: Number of independent random splits.

    Returns:
        ``(train_mask, val_mask, test_mask)``, each a ``torch.bool`` tensor of
        shape ``(y.size(0), num_splits)`` on ``y.device``. Column ``j`` holds
        split ``j``.
    """
    num_nodes = y.size(0)
    device = y.device
    # Allocate on y.device: the scatter indices below are derived from y, and
    # scatter_ requires index and destination on the same device.
    train_mask = torch.zeros(num_nodes, num_splits, dtype=torch.bool,
                             device=device)
    val_mask = torch.zeros(num_nodes, num_splits, dtype=torch.bool,
                           device=device)

    # y.max() is undefined for empty y, and torch.stack([]) fails for zero
    # splits. In both cases every mask is empty anyway.
    if num_nodes == 0 or num_splits == 0:
        return train_mask, val_mask, ~(train_mask | val_mask)

    num_classes = int(y.max()) + 1
    for c in range(num_classes):
        idx = (y == c).nonzero(as_tuple=False).view(-1)
        # One independent permutation per split, stacked as columns. Drawn on
        # CPU (same RNG stream as before) and then moved to idx's device.
        perm = torch.stack(
            [torch.randperm(idx.size(0)) for _ in range(num_splits)],
            dim=1).to(idx.device)
        idx = idx[perm]  # shape: (num_nodes_in_class, num_splits)

        # scatter_ along dim 0 sets mask[idx[i, j], j] = True for each split j.
        train_idx = idx[:train_per_class]
        train_mask.scatter_(0, train_idx, True)
        val_idx = idx[train_per_class:train_per_class + val_per_class]
        val_mask.scatter_(0, val_idx, True)

    test_mask = ~(train_mask | val_mask)
    return train_mask, val_mask, test_mask