utils.py 219 Bytes
Newer Older
1
2
3
4
import torch as th

def allclose(a, b, rtol=1e-4, atol=1e-4, equal_nan=False):
    """Return True if tensors ``a`` and ``b`` are element-wise close.

    Thin wrapper around ``torch.allclose`` with looser default tolerances,
    suited to comparing results of float computations in tests. Elements
    count as close when ``|a - b| <= atol + rtol * |b|``.

    Args:
        a: First tensor.
        b: Second tensor; the relative tolerance is scaled by ``|b|``.
        rtol: Relative tolerance (default 1e-4).
        atol: Absolute tolerance (default 1e-4).
        equal_nan: If True, NaNs in the same position compare as equal.

    Returns:
        A Python ``bool``.
    """
    return th.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
5
6
7
8
9
10
11

def check_fail(fn, *args, **kwargs):
    """Return True if calling ``fn(*args, **kwargs)`` raises an exception.

    Meant for tests that assert an operation is rejected. Only ordinary
    exceptions (subclasses of ``Exception``) count as a failure.
    ``KeyboardInterrupt``, ``SystemExit`` and other ``BaseException``-only
    signals propagate, so a test run can still be interrupted or exited.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        True if the call raised, False if it completed normally. The call's
        return value is discarded.
    """
    try:
        fn(*args, **kwargs)
    except Exception:
        # Any ordinary error means the call "failed" as the caller expected.
        return True
    return False