errors.py 3.46 KB
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
import itertools
import warnings
from functools import wraps
from typing import Any, Type, TypeVar, Union

import numpy as np
import torch


def compact_str(
    value: Any,
    depth: int = 3,
    max_items: int = 10,
    max_str_len: int = 50,
) -> str:
    """
    Compact representation of a value as a string.

    Dicts, lists, tuples and dataclass instances are expanded recursively for at
    most ``depth`` levels; anything deeper is shown as ``{...}``, ``[...]``,
    ``(...)`` or ``Name(...)``. Dicts, lists and tuples show at most
    ``max_items`` entries. Dict entries whose key is a string starting with
    ``__`` are shown with their full, untruncated ``repr()``. Strings are quoted
    and truncated to ``max_str_len`` characters. Tensors and numpy arrays are
    summarized by shape/dtype (and device for tensors) instead of their contents.
    Any other value is rendered with its ``repr()``, truncated to ``max_str_len``
    characters.

    Args:
        value: The value to compact
        depth: The maximum depth to compact
        max_items: The maximum number of items to show in a list or dict
        max_str_len: The maximum string length to show

    Returns: The printable string
    """
    if isinstance(value, dict):
        if depth <= 0:
            return "{...}"
        return (
            "{"
            + ", ".join(
                (
                    f"{k}: {v!r}"
                    if isinstance(k, str) and k.startswith("__")
                    else f"{k}: {compact_str(v, depth - 1, max_items, max_str_len)}"
                )
                for k, v in itertools.islice(value.items(), max_items)
            )
            + "}"
        )
    elif isinstance(value, list):
        if depth <= 0:
            return "[...]"
        return (
            "["
            + ", ".join(
                compact_str(v, depth - 1, max_items, max_str_len) for v in value[:max_items]
            )
            + "]"
        )
    elif isinstance(value, tuple):
        if depth <= 0:
            return "(...)"
        return (
            "("
            + ", ".join(
                compact_str(v, depth - 1, max_items, max_str_len) for v in value[:max_items]
            )
            + ")"
        )
    elif isinstance(value, str):
        if len(value) > max_str_len:
            return repr(value[:max_str_len] + "...")
        return repr(value)
    elif isinstance(value, torch.Tensor):
        return f"Tensor(shape={value.shape}, dtype={value.dtype}, device={value.device})"
    elif isinstance(value, np.ndarray):
        return f"np.ndarray(shape={value.shape}, dtype={value.dtype})"
    elif dataclasses.is_dataclass(value) and not isinstance(value, type):
        # is_dataclass() is also true for dataclass *classes*; those fall through to
        # the repr() branch, since getattr() on a class fails for fields without defaults.
        name = value.__class__.__name__
        if depth <= 0:
            # Respect the depth limit like the other containers (also stops cycles).
            return f"{name}(...)"
        rendered = ", ".join(
            f"{field.name}={compact_str(getattr(value, field.name), depth - 1, max_items, max_str_len)}"
            for field in dataclasses.fields(value)
        )
        return f"{name}({rendered})"
    else:
        # Use repr() directly: routing it through the str branch would wrap it in a
        # second pair of quotes (e.g. 1 -> "'1'").
        text = repr(value)
        if len(text) > max_str_len:
            return text[:max_str_len] + "..."
        return text


T = TypeVar("T")


class SampleException(ValueError):
    """Error describing a failure while processing a single sample.

    The alternate constructors build a uniformly formatted message of the form
    ``"Sample <sample> failed"`` with an optional ``": <message>"`` suffix. They
    return an instance of the class they are called on, so subclasses get
    instances of the subclass.
    """

    @staticmethod
    def _failure_message(subject: str, message: str) -> str:
        # Shared formatting so both constructors produce the same message shape.
        if message:
            return f"Sample {subject} failed: {message}"
        return f"Sample {subject} failed"

    @classmethod
    def from_sample_key(cls: Type[T], sample_key: str, message: str = "") -> T:
        """Create (not raise) an exception for the sample identified by ``sample_key``.

        Args:
            sample_key: Identifier of the failed sample, inserted verbatim.
            message: Optional detail appended as ``": <message>"``.

        Returns: A new instance of ``cls``.
        """
        return cls(SampleException._failure_message(sample_key, message))

    @classmethod
    def from_sample(cls: Type[T], sample: Any, message: str = "") -> T:
        """Create (not raise) an exception describing ``sample``.

        Args:
            sample: The failed sample; rendered with ``compact_str`` to keep the
                message short.
            message: Optional detail appended as ``": <message>"``.

        Returns: A new instance of ``cls``.
        """
        return cls(SampleException._failure_message(compact_str(sample), message))


class FatalSampleError(SampleException):
    """Sample failure that is deliberately not handled by the error handler."""


def warn_deprecated(reason: str, stacklevel: int = 2) -> None:
    """Emit ``reason`` as a :class:`FutureWarning`.

    Args:
        reason: The warning text.
        stacklevel: Forwarded to :func:`warnings.warn`. The default of 2
            attributes the warning to the code that called this function.
    """
    warnings.warn(message=reason, category=FutureWarning, stacklevel=stacklevel)


def deprecated(reason):
    """Decorator factory that marks a function as deprecated.

    Every call to the decorated function first emits a warning through
    ``warn_deprecated`` with the text ``"<function name> is deprecated: <reason>"``.
    It then delegates to the original function and returns its result.

    Args:
        reason: Explanation appended to the warning text.

    Returns: A decorator whose wrapper preserves the wrapped function's metadata.
    """

    def _apply(fn):
        @wraps(fn)
        def _deprecated_call(*call_args, **call_kwargs):
            # stacklevel=3 skips warn_deprecated and this wrapper, so the warning
            # points at the code that called the deprecated function.
            warn_deprecated(f"{fn.__name__} is deprecated: {reason}", stacklevel=3)
            return fn(*call_args, **call_kwargs)

        return _deprecated_call

    return _apply


# Tuple of exception types treated as "system" exceptions, plus this module's
# FatalSampleError.
# NOTE(review): presumably error handlers re-raise these instead of treating
# them as ordinary per-sample failures (FatalSampleError is commented as "not
# handled by the error handler"); confirm against the handler implementation.
SYSTEM_EXCEPTIONS = (
    SystemError,
    SyntaxError,
    ImportError,
    StopIteration,
    StopAsyncIteration,
    MemoryError,
    RecursionError,
    ReferenceError,
    NameError,
    UnboundLocalError,  # subclass of NameError, so redundant for isinstance/except checks
    FatalSampleError,
)