exception.py 1.92 KB
Newer Older
chenzk's avatar
v1.0.8  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import contextlib
import signal
from typing import Optional

from nanotron import distributed as dist


@contextlib.contextmanager
def assert_fail_with(exception_class, error_msg: Optional[str] = None):
    """Assert that the body of the ``with`` statement raises ``exception_class``.

    The expected exception is swallowed when it matches. Otherwise an
    ``AssertionError`` is raised describing what went wrong.

    Args:
        exception_class: Exception type (anything accepted by an ``except``
            clause) that the body must raise.
        error_msg: If not ``None``, the exact text ``str(exc)`` must equal.

    Raises:
        AssertionError: If the body raises nothing, raises a different
            ``Exception`` type, or raises ``exception_class`` with a message
            other than ``error_msg``.
    """
    try:
        yield
    except exception_class as caught:
        # Only inspect the message when the caller asked for a specific one.
        if error_msg is not None and error_msg != str(caught):
            raise AssertionError(f'Expected message to be "{error_msg}", but got "{str(caught)}" instead.')
        # Falling through ends the generator, which suppresses the expected exception.
    except Exception as other:
        raise AssertionError(f"Expected {exception_class} to be raised, but got: {type(other)} instead:\n{other}")
    else:
        raise AssertionError(f"Expected {exception_class} to be raised, but no exception was raised.")


@contextlib.contextmanager
def assert_fail_except_rank_with(
    exception_class, rank_exception: int, pg: dist.ProcessGroup, error_msg: Optional[str] = None
):
    """Assert that every rank of ``pg`` except ``rank_exception`` raises ``exception_class``.

    The rank is looked up with ``dist.get_rank(pg)``. The expectation differs by rank:

    * On every rank other than ``rank_exception``, the body must raise
      ``exception_class``. If ``error_msg`` is given, ``str(exc)`` must also
      equal it. A matching exception is suppressed.
    * On rank ``rank_exception``, the body must not raise at all.

    Args:
        exception_class: Exception type (anything accepted by an ``except``
            clause) that the non-exempt ranks must raise.
        rank_exception: Rank within ``pg`` that must complete without raising.
        pg: Process group used to determine the current rank.
        error_msg: If not ``None``, the exact expected ``str()`` of the exception.

    Raises:
        AssertionError: If the current rank deviates from its expectation.
    """
    try:
        yield
    except exception_class as e:
        if rank_exception == dist.get_rank(pg):
            raise AssertionError(f"Expected rank {rank_exception} to not raise {exception_class}.") from e
        if error_msg is not None and error_msg != str(e):
            raise AssertionError(f'Expected message to be "{error_msg}", but got "{str(e)}" instead.') from e
        # Matching exception on a non-exempt rank: fall through so it is suppressed.
    except Exception as e:
        if rank_exception == dist.get_rank(pg):
            # The exempt rank is expected not to raise anything. Saying
            # "Expected {exception_class} to be raised" here would be misleading.
            raise AssertionError(
                f"Expected rank {rank_exception} to not raise, but got: {type(e)} instead:\n{e}"
            ) from e
        raise AssertionError(f"Expected {exception_class} to be raised, but got: {type(e)} instead:\n{e}") from e
    else:
        # No exception: only acceptable on the exempt rank.
        if dist.get_rank(pg) != rank_exception:
            raise AssertionError(f"Expected {exception_class} to be raised, but no exception was raised.")


@contextlib.contextmanager
def timeout_after(ms=500):
    """Raise ``TimeoutError`` if the ``with`` body runs longer than ``ms`` milliseconds.

    Implemented with a ``SIGALRM`` handler and the ``ITIMER_REAL`` interval
    timer. It is therefore Unix-only and must be entered from the main thread,
    since ``signal.signal`` raises ``ValueError`` elsewhere. On exit, the timer
    is disarmed and the previously installed ``SIGALRM`` handler is restored.

    NOTE: nesting is not supported. Exiting an inner ``timeout_after`` disarms
    ``ITIMER_REAL``, which cancels any outer timer.

    Args:
        ms: Timeout in milliseconds. ``0`` leaves the timer disarmed, so no
            timeout fires.

    Raises:
        TimeoutError: Raised from the signal handler when the timer expires.
    """

    def signal_handler(signum, frame):
        raise TimeoutError(f"Timed out after {ms} ms.")

    # Keep whatever handler was installed before so it can be put back on exit.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.setitimer(signal.ITIMER_REAL, ms / 1000)
        yield
    finally:
        # Disarm with the same API used to arm it. Mixing alarm() with
        # setitimer() is unspecified by POSIX.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal() returns None when the old handler was not installed
        # from Python. None cannot be re-installed, so fall back to the default.
        signal.signal(
            signal.SIGALRM,
            previous_handler if previous_handler is not None else signal.SIG_DFL,
        )