Commit cf66b327 authored by rusty1s's avatar rusty1s
Browse files

compute_f1_score

parent 38f62d97
......@@ -5,7 +5,7 @@ from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric_autoscale.models import GCN
from torch_geometric_autoscale import metis, permute, SubgraphLoader
from torch_geometric_autoscale import get_data, compute_acc
from torch_geometric_autoscale import get_data, compute_micro_f1
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
......@@ -67,9 +67,9 @@ def test(model, data):
# Full-batch inference since the graph is small
out = model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()
train_acc = compute_acc(out, data.y, data.train_mask)
val_acc = compute_acc(out, data.y, data.val_mask)
test_acc = compute_acc(out, data.y, data.test_mask)
train_acc = compute_micro_f1(out, data.y, data.train_mask)
val_acc = compute_micro_f1(out, data.y, data.val_mask)
test_acc = compute_micro_f1(out, data.y, data.test_mask)
return train_acc, val_acc, test_acc
......
......@@ -5,7 +5,7 @@ from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric_autoscale.models import GCN2
from torch_geometric_autoscale import metis, permute, SubgraphLoader
from torch_geometric_autoscale import get_data, compute_acc
from torch_geometric_autoscale import get_data, compute_micro_f1
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
......@@ -69,9 +69,9 @@ def test(model, data):
# Full-batch inference since the graph is small
out = model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()
train_acc = compute_acc(out, data.y, data.train_mask)
val_acc = compute_acc(out, data.y, data.val_mask)
test_acc = compute_acc(out, data.y, data.test_mask)
train_acc = compute_micro_f1(out, data.y, data.train_mask)
val_acc = compute_micro_f1(out, data.y, data.val_mask)
test_acc = compute_micro_f1(out, data.y, data.test_mask)
return train_acc, val_acc, test_acc
......
......@@ -7,7 +7,7 @@ from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric_autoscale import (get_data, metis, permute,
SubgraphLoader, EvalSubgraphLoader,
models, compute_acc)
models, compute_micro_f1)
from torch_geometric_autoscale.data import get_ppi
torch.manual_seed(123)
......@@ -144,17 +144,17 @@ def main(conf):
loss = mini_train(model, train_loader, criterion, optimizer,
params.max_steps, grad_norm)
out = mini_test(model, eval_loader)
train_acc = compute_acc(out, data.y, data.train_mask)
train_acc = compute_micro_f1(out, data.y, data.train_mask)
if conf.dataset.name != 'ppi':
val_acc = compute_acc(out, data.y, data.val_mask)
tmp_test_acc = compute_acc(out, data.y, data.test_mask)
val_acc = compute_micro_f1(out, data.y, data.val_mask)
tmp_test_acc = compute_micro_f1(out, data.y, data.test_mask)
else:
# We need to perform inference on a different graph as PPI is an
# inductive dataset.
val_acc = compute_acc(full_test(model, val_data), val_data.y)
tmp_test_acc = compute_acc(full_test(model, test_data),
test_data.y)
val_acc = compute_micro_f1(full_test(model, val_data), val_data.y)
tmp_test_acc = compute_micro_f1(full_test(model, test_data),
test_data.y)
if val_acc > best_val_acc:
best_val_acc = val_acc
......
......@@ -5,8 +5,8 @@ from omegaconf import OmegaConf
import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric_autoscale import (get_data, metis, permute,
SubgraphLoader, models, compute_acc)
from torch_geometric_autoscale import (get_data, metis, permute, models,
SubgraphLoader, compute_micro_f1)
torch.manual_seed(123)
criterion = torch.nn.CrossEntropyLoss()
......@@ -50,8 +50,8 @@ def test(run, model, data):
test_mask = test_mask[:, run] if test_mask.dim() == 2 else test_mask
out = model(data.x, data.adj_t)
val_acc = compute_acc(out, data.y, val_mask)
test_acc = compute_acc(out, data.y, test_mask)
val_acc = compute_micro_f1(out, data.y, val_mask)
test_acc = compute_micro_f1(out, data.y, test_mask)
return val_acc, test_acc
......
......@@ -13,7 +13,7 @@ from .data import get_data # noqa
from .history import History # noqa
from .pool import AsyncIOPool # noqa
from .metis import metis, permute # noqa
from .utils import compute_acc # noqa
from .utils import compute_micro_f1 # noqa
from .loader import SubgraphLoader, EvalSubgraphLoader # noqa
__all__ = [
......@@ -22,7 +22,7 @@ __all__ = [
'AsyncIOPool',
'metis',
'permute',
'compute_acc',
'compute_micro_f1',
'SubgraphLoader',
'EvalSubgraphLoader',
'__version__',
......
......@@ -10,8 +10,8 @@ def index2mask(idx: Tensor, size: int) -> Tensor:
return mask
def compute_acc(logits: Tensor, y: Tensor,
mask: Optional[Tensor] = None) -> float:
def compute_micro_f1(logits: Tensor, y: Tensor,
mask: Optional[Tensor] = None) -> float:
if mask is not None:
logits, y = logits[mask], y[mask]
......@@ -24,9 +24,12 @@ def compute_acc(logits: Tensor, y: Tensor,
tp = int((y_true & y_pred).sum())
fp = int((~y_true & y_pred).sum())
fn = int((y_true & ~y_pred).sum())
precision = tp / (tp + fp)
recall = tp / (tp + fn)
return 2 * (precision * recall) / (precision + recall)
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.
if precision + recall > 0:
return 2 * (precision * recall) / (precision + recall)
else:
return 0.
def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment