check.py 4.26 KB
Newer Older
1
2
3
4
import functools
import os
import warnings

mibaumgartner's avatar
mibaumgartner committed
5
6
7
8
from nndet.io.paths import get_task
from nndet.utils.config import load_dataset_info


9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def env_guard(func):
    """
    Decorator that validates the nnDetection environment variables every
    time the decorated function is called.

    - 'det_data', 'det_models' and 'OMP_NUM_THREADS' are mandatory: if one
      of them is unset a RuntimeError is raised and ``func`` is not called.
    - 'det_num_threads' is optional but a warning is emitted when it is unset.
    - 'det_verbose' is optional; a notice is printed when it is unset.

    Args:
        func: callable to guard

    Returns:
        Callable: wrapped ``func`` (metadata copied via ``functools.wraps``)

    Raises:
        RuntimeError: (at call time) if a mandatory variable is missing
    """
    required = ("det_data", "det_models", "OMP_NUM_THREADS")

    @functools.wraps(func)
    def guarded(*args, **kwargs):
        env = os.environ

        # mandatory variables, checked in a fixed order so the first missing
        # one is reported
        for name in required:
            if name not in env:
                raise RuntimeError(
                    f"'{name}' environment variable not set. "
                    "Please refer to the installation instructions. "
                )

        if "det_num_threads" not in env:
            warnings.warn(
                "Warning: 'det_num_threads' environment variable not set. "
                "Please read installation instructions again. "
                "Training will not work properly.")

        # print instead of logging: logging might not be initialized yet and
        # this message is intended for the user
        if "det_verbose" not in env:
            print("'det_verbose' environment variable not set. "
                  "Continue in verbose mode.")

        return func(*args, **kwargs)

    return guarded


mibaumgartner's avatar
mibaumgartner committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def _check_key_missing(cfg: dict, key: str, ktype=None):
    if key not in cfg:
        raise ValueError(f"Dataset information did not contain "
                        f"'{key}' key, found {list(cfg.keys())}")
    
    if ktype is not None:
        if not isinstance(cfg[key], ktype):
            raise ValueError(f"Found {key} of type {type(cfg[key])} in "
                             f"dataset information but expected type {ktype}")


def _check_index_mapping(cfg: dict, key: str, description: str):
    """
    Check that ``cfg[key]`` maps consecutive integer indices to names.

    Keys and values may be ``str`` or ``int``. After converting the keys to
    ``int`` they must be exactly ``0, 1, ..., n - 1``.

    Args:
        cfg: dataset information
        key: entry of ``cfg`` to check (e.g. "labels" or "modalities")
        description: human readable name used in the ordering error message

    Raises:
        ValueError: a key or value has an unsupported type, a key cannot be
            converted to ``int``, or the indices are not consecutive from 0
    """
    mapping = cfg[key]
    for idx, name in mapping.items():
        if not isinstance(idx, (str, int)):
            raise ValueError("Expected key of type string or int in dataset "
                             f"info {key} but found {type(idx)} : {idx}")
        if not isinstance(name, (str, int)):
            raise ValueError("Expected name of type string or int in dataset "
                             f"info {key} but found {type(name)} : {name}")

    # string keys must still denote integer indices
    try:
        found = sorted(int(idx) for idx in mapping)
    except ValueError as e:
        raise ValueError(f"Expected keys of dataset info {key} to be integer "
                         f"indices but found {list(mapping.keys())}") from e

    expected = list(range(len(found)))
    if found != expected:
        raise ValueError(f"Found wrong order of {description} in dataset info. "
                         f"Found {found} but expected {expected}")


def check_dataset_file(task_name: str):
    """
    Run a sequence of checks to confirm correct format of dataset information

    Verifies that the dataset info of the task contains ``task`` (str),
    ``dim`` (int, either 2 or 3), ``labels`` (dict) and ``modalities`` (dict),
    and that ``labels`` and ``modalities`` each map consecutive indices
    starting at 0 to names.

    Args:
        task_name: task identifier to check info for

    Raises:
        ValueError: if any of the format checks fails
    """
    cfg = load_dataset_info(get_task(task_name))
    _check_key_missing(cfg, "task", ktype=str)
    _check_key_missing(cfg, "dim", ktype=int)
    _check_key_missing(cfg, "labels", ktype=dict)
    _check_key_missing(cfg, "modalities", ktype=dict)

    # the parentheses are required: without them the walrus would bind the
    # boolean result of the membership test instead of the dimension
    if (dim := cfg["dim"]) not in (2, 3):
        raise ValueError(f"Found dim {dim} in dataset info but only support dim=2 or dim=3.")

    _check_index_mapping(cfg, "labels", "label classes")
    _check_index_mapping(cfg, "modalities", "modalities")


def check_data_and_label_splitted():
    """
    Placeholder check: performs no validation and returns None.

    NOTE(review): judging only by the name, this is presumably meant to
    verify that data and label files are split correctly; not implemented
    yet — confirm the intended behavior before relying on it.
    """
    return None