# test_utils.py — helpers for comparing tensors in tests.
import torch


def print_red_warning(message):
    """Print ``WARNING: <message>`` to stdout in red.

    The text is wrapped in the ANSI escape ``ESC[31m`` (red foreground) and
    followed by ``ESC[0m`` so the terminal colour is reset afterwards.
    """
    text = "\033[31mWARNING: {}\033[0m".format(message)
    print(text)


def calc_sim(x, y, name="tensor"):
    """Return the similarity score ``2 * <x, y> / (|x|^2 + |y|^2)``.

    Both inputs are detached (via ``.data``) and cast to float64 before the
    computation. The score is 1 for identical tensors and, since
    ``(x - y)^2 >= 0`` elementwise, never exceeds 1 apart from rounding.
    Products are elementwise, so normal torch broadcasting applies.

    Args:
        x: First tensor.
        y: Second tensor.
        name: Label used in the warning printed when the sum of squares is zero.

    Returns:
        A 0-d float64 tensor holding the score, or the int ``1`` when the
        denominator is zero (both inputs all zero), where the ratio would be 0/0.
    """
    x, y = x.data.double(), y.data.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # 0/0 case: treat two all-zero tensors as identical, but say so.
        print_red_warning(f"{name} all zero")
        return 1
    sim = 2 * (x * y).sum() / denominator
    return sim


def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
    """Check that two tensors match in non-finite entries and are near-identical elsewhere.

    Checks, in order:
      1. ``x`` and ``y`` are non-finite (nan/inf) at the same positions.
      2. The non-finite values themselves match (nan with nan, inf with an
         inf of the same sign).
      3. After zeroing the non-finite entries, ``1 - calc_sim(x, y)`` is
         within ``eps`` of zero.

    Args:
        x: First tensor.
        y: Second tensor, compared elementwise with ``x``.
        eps: Maximum allowed dissimilarity ``1 - sim``.
        name: Label used in printed messages.
        data: Extra text included in the success message.
        raise_assert: If True, raise on the first failing check. If False, only
            print a red warning and continue with the remaining checks.

    Raises:
        AssertionError: If a check fails and ``raise_assert`` is True.
    """
    x_mask = torch.isfinite(x)
    y_mask = torch.isfinite(y)
    if not torch.all(x_mask == y_mask):
        msg = f"{name} Error: isfinite mask mismatch"
        print_red_warning(msg)
        if raise_assert:
            raise AssertionError(msg)
    # Zero the finite entries so only the non-finite ones are compared.
    # With rtol=atol=0 an inf only matches an inf of the same sign, and
    # equal_nan=True lets nan match nan.
    if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
        msg = f"{name} Error: nonfinite value mismatch"
        print_red_warning(msg)
        if raise_assert:
            raise AssertionError(msg)
    # Zero the non-finite entries so they don't poison the similarity sum.
    x = x.masked_fill(~x_mask, 0)
    y = y.masked_fill(~y_mask, 0)
    sim = calc_sim(x, y, name)
    diff = 1.0 - sim
    # sim <= 1 mathematically, but float64 rounding can leave diff a tiny
    # negative number for nearly identical tensors, so compare its magnitude.
    # A nan diff still fails, because every comparison with nan is False.
    if not (abs(diff) <= eps):
        msg = f"{name} Error: {diff}"
        print_red_warning(msg)
        if raise_assert:
            raise AssertionError(msg)
    else:
        print(f"{name} {data} passed")