metrics.py 1.9 KB
Newer Older
Ponku's avatar
Ponku committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from typing import Dict, List, Optional, Tuple

from torch import Tensor

# Names accepted by ``compute_metrics``; any other name raises ``ValueError``.
AVAILABLE_METRICS = ["mae", "rmse", "epe", "bad1", "bad2", "1px", "3px", "5px", "fl-all", "relepe"]


def compute_metrics(
    flow_pred: Tensor, flow_gt: Tensor, valid_flow_mask: Optional[Tensor], metrics: List[str]
) -> Tuple[Dict[str, float], int]:
    """Compute the requested stereo-matching error metrics between a prediction and ground truth.

    Args:
        flow_pred: Predicted disparity, same shape as ``flow_gt``. Presumably ``(B, 1, H, W)``
            -- TODO confirm; the mask handling below inserts a channel axis at dim 1.
        flow_gt: Ground-truth disparity.
        valid_flow_mask: Optional boolean mask of the pixels to evaluate. A singleton axis is
            inserted at dim 1 before indexing, so it is expected to lack the channel axis
            (presumably ``(B, H, W)``). ``None`` evaluates every element.
        metrics: Names of the metrics to compute; each must appear in ``AVAILABLE_METRICS``.

    Returns:
        ``(metrics_dict, num_pixels)``. ``metrics_dict`` maps each requested name to a Python
        float, and ``num_pixels`` is the number of evaluated elements.

        - "mae" / "epe": mean absolute error (identical values).
        - "rmse": root mean squared error.
        - "bad1" / "bad2": fraction (0-1) of pixels with error > 1 / > 2.
        - "1px" / "3px" / "5px": fraction (0-1) of pixels with error < 1 / < 3 / < 5.
        - "fl-all": *percentage* (0-100) of outliers, using the KITTI criterion
          error > 3 AND error / |gt| > 0.05.
        - "relepe": mean of error / |gt|.

        If no pixels are evaluated, the means are NaN.

    Raises:
        ValueError: If a requested metric name is not in ``AVAILABLE_METRICS``.
    """
    for m in metrics:
        if m not in AVAILABLE_METRICS:
            raise ValueError(f"Invalid metric: {m}. Valid metrics are: {AVAILABLE_METRICS}")

    metrics_dict = {}

    pixels_diffs = (flow_pred - flow_gt).abs()
    # There is no Y flow in stereo matching, therefore flow.abs() == flow.pow(2).sum(dim=1).sqrt()
    flow_norm = flow_gt.abs()

    if valid_flow_mask is not None:
        # Add the channel axis so the mask lines up with the (presumed) channel-first tensors.
        valid_flow_mask = valid_flow_mask.unsqueeze(1)
        pixels_diffs = pixels_diffs[valid_flow_mask]
        flow_norm = flow_norm[valid_flow_mask]

    num_pixels = pixels_diffs.numel()

    relative_diffs = None
    if "fl-all" in metrics or "relepe" in metrics:
        # Unguarded division: pixels whose ground truth is 0 give inf (or nan when the error is also 0).
        relative_diffs = pixels_diffs / flow_norm

    if "bad1" in metrics:
        metrics_dict["bad1"] = (pixels_diffs > 1).float().mean().item()
    if "bad2" in metrics:
        metrics_dict["bad2"] = (pixels_diffs > 2).float().mean().item()

    if "mae" in metrics:
        metrics_dict["mae"] = pixels_diffs.mean().item()
    if "rmse" in metrics:
        metrics_dict["rmse"] = pixels_diffs.pow(2).mean().sqrt().item()
    if "epe" in metrics:
        metrics_dict["epe"] = pixels_diffs.mean().item()
    if "1px" in metrics:
        metrics_dict["1px"] = (pixels_diffs < 1).float().mean().item()
    if "3px" in metrics:
        metrics_dict["3px"] = (pixels_diffs < 3).float().mean().item()
    if "5px" in metrics:
        metrics_dict["5px"] = (pixels_diffs < 5).float().mean().item()
    if "fl-all" in metrics:
        # KITTI outlier: error exceeds BOTH 3px and 5% of the ground-truth magnitude.
        outliers = (pixels_diffs > 3) & (relative_diffs > 0.05)
        metrics_dict["fl-all"] = outliers.float().mean().item() * 100
    if "relepe" in metrics:
        metrics_dict["relepe"] = relative_diffs.mean().item()

    return metrics_dict, num_pixels