preprocess.py 16.5 KB
Newer Older
mibaumgartner's avatar
mibaumgartner committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse
import shutil
import os
import copy
import sys
import traceback

import numpy as np

from loguru import logger
from itertools import repeat
from typing import Dict, Sequence, Tuple, List
from pathlib import Path
from multiprocessing import Pool
31
from hydra import initialize_config_module
mibaumgartner's avatar
mibaumgartner committed
32
33
34
from omegaconf import OmegaConf

from nndet.utils.config import compose
35
from nndet.utils.check import env_guard
mibaumgartner's avatar
mibaumgartner committed
36
from nndet.planning import DatasetAnalyzer
mibaumgartner's avatar
mibaumgartner committed
37
38
from nndet.planning import PLANNER_REGISTRY
from nndet.planning.experiment.utils import create_labels
mibaumgartner's avatar
mibaumgartner committed
39
40
41
42
from nndet.planning.properties.registry import medical_instance_props
from nndet.io.load import load_pickle, load_npz_looped
from nndet.io.paths import get_paths_raw_to_split, get_paths_from_splitted_dir, subfiles, get_case_id_from_path
from nndet.preprocessing import ImageCropper
mibaumgartner's avatar
mibaumgartner committed
43
from nndet.utils.check import check_dataset_file, check_data_and_label_splitted
mibaumgartner's avatar
mibaumgartner committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135


def run_cropping_and_convert(cropped_output_dir: Path,
                             splitted_4d_output_dir: Path,
                             data_info: dict,
                             overwrite: bool,
                             num_processes: int,
                             ):
    """
    First preparation step of the data:
        - stack data and segmentation into a single sample (segmentation is the last channel)
        - save data as npz (format: case_id.npz)
        - save additional properties as pkl file (format: case_id.pkl)
        - crop data to nonzero region; crop segmentation accordingly
          (NOTE(review): original comment stated segmentation is filled with -1;
          confirm against ImageCropper implementation)

    Args:
        cropped_output_dir (Path): path to directory where cropped images should be saved
        splitted_4d_output_dir (Path): path to splitted data
        data_info: information about data set (here `modalities` is needed)
        overwrite (bool): overwrite existing cropped data
        num_processes (int): number of processes used to crop image data

    Raises:
        RuntimeError: if corrupted files remain even after the
            single-process retry
    """
    num_modalities = len(data_info["modalities"].keys())

    # start from a clean directory when overwriting
    if overwrite and cropped_output_dir.is_dir():
        shutil.rmtree(str(cropped_output_dir))
    if not cropped_output_dir.is_dir():
        cropped_output_dir.mkdir(parents=True)

    case_files = get_paths_from_splitted_dir(num_modalities, splitted_4d_output_dir)

    logger.info(f"Running cropping with overwrite {overwrite}.")
    imgcrop = ImageCropper(num_processes, cropped_output_dir)
    imgcrop.run_cropping(case_files, overwrite_existing=overwrite)

    # verify all cropped npz files are loadable; corrupted ones are removed
    case_ids_failed, result_check = run_check(cropped_output_dir / "imagesTr",
                                              remove=True,
                                              processes=num_processes,
                                              keys=("data",)
                                              )
    if not result_check:
        # retry only the removed cases without multiprocessing
        # (more robust and easier to debug than the pool)
        logger.warning(
            f"Crop check failed: There are corrupted files!!!! {case_ids_failed} "
            f"Try to crop corrupted files again.",
        )
        imgcrop = ImageCropper(0, cropped_output_dir)
        imgcrop.run_cropping(case_files, overwrite_existing=False)
        case_ids_failed, result_check = run_check(cropped_output_dir / "imagesTr",
                                                  remove=False,
                                                  processes=num_processes,
                                                  keys=("data",)
                                                  )
        if not result_check:
            logger.error(f"Found corrupted files: {case_ids_failed}.")
            raise RuntimeError("Corrupted files")
    else:
        logger.info("Crop check successful: Loading check completed")


def run_dataset_analysis(cropped_output_dir: Path,
                         preprocessed_output_dir: Path,
                         data_info: dict,
                         num_processes: int,
                         intensity_properties: bool = True,
                         overwrite: bool = True,
                         ):
    """
    Analyse the dataset and collect its properties.

    Args:
        cropped_output_dir: path to base cropped dir
        preprocessed_output_dir: path to base preprocessed output dir
        data_info: additional information about dataset (`modalities` and `labels` needed)
        num_processes: number of processes to use
        intensity_properties: analyze intensity values of foreground
        overwrite: overwrite existing properties
    """
    # select the set of properties to compute, then run the analyzer
    props = medical_instance_props(intensity_properties=intensity_properties)
    analyzer = DatasetAnalyzer(
        cropped_output_dir,
        preprocessed_output_dir=preprocessed_output_dir,
        data_info=data_info,
        num_processes=num_processes,
        overwrite=overwrite,
        )
    analyzer.analyze_dataset(props)


def run_planning_and_process(
    splitted_4d_output_dir: Path,
    cropped_output_dir: Path,
    preprocessed_output_dir: Path,
    planner_name: str,
    dim: int,
    model_name: str,
    model_cfg: Dict,
    num_processes: int,
    run_preprocessing: bool = True,
    ):
    """
    Run planning and preprocessing

    Args:
        splitted_4d_output_dir: base dir of splitted data
        cropped_output_dir: base dir of cropped data
        preprocessed_output_dir: base dir of preprocessed data
        planner_name: planner name, looked up in PLANNER_REGISTRY
        dim: number of spatial dimensions
            (kept for interface compatibility; not used inside this function)
        model_name: name of model to run planning for
        model_cfg: hyperparameters of model (used during planning to
            instantiate model)
        num_processes: number of processes to use for preprocessing
        run_preprocessing: Preprocess and check data. Defaults to True.

    Raises:
        RuntimeError: if corrupted files remain even after the
            single-process retry
    """
    planner_cls = PLANNER_REGISTRY.get(planner_name)
    planner = planner_cls(
        preprocessed_output_dir=preprocessed_output_dir
    )
    plan_identifiers = planner.plan_experiment(
        model_name=model_name,
        model_cfg=model_cfg,
    )
    if run_preprocessing:
        for plan_id in plan_identifiers:
            plan = load_pickle(preprocessed_output_dir / plan_id)
            planner.run_preprocessing(
                cropped_data_dir=cropped_output_dir / "imagesTr",
                plan=plan,
                num_processes=num_processes,
                )
            # verify all preprocessed files are loadable; corrupted ones are removed
            case_ids_failed, result_check = run_check(
                data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
                remove=True,
                processes=num_processes
            )

            # delete and rerun corrupted cases
            if not result_check:
                logger.warning(f"{plan_id} check failed: There are corrupted files {case_ids_failed}!!!! "
                               f"Running preprocessing of those cases without multiprocessing.")
                planner.run_preprocessing(
                    cropped_data_dir=cropped_output_dir / "imagesTr",
                    plan=plan,
                    num_processes=0,
                )
                case_ids_failed, result_check = run_check(
                    data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
                    remove=False,
                    processes=0
                )
                if not result_check:
                    logger.error(f"Could not fix corrupted files {case_ids_failed}!")
                    raise RuntimeError("Found corrupted files, check logs!")
                else:
                    logger.info("Fixed corrupted files.")
            else:
                logger.info(f"{plan_id} check successful: Loading check completed")

    if run_preprocessing:
        # generate instance labels from the splitted ground truth
        create_labels(
            preprocessed_output_dir=preprocessed_output_dir,
            source_dir=splitted_4d_output_dir,
            num_processes=num_processes,
        )


def run_check(data_dir: Path,
              remove: bool = False,
              processes: int = 8,
              keys: Sequence[str] = ("data", "seg"),
              ) -> Tuple[List[str], bool]:
    """
    Check if files from preprocessed dir are loadable

    Args:
        data_dir (Path): path to preprocessed data
        remove (bool, optional): if loading fails, the npz and pkl
            file are removed automatically. Defaults to False.
        processes (int, optional): number of processes to use. If
            0 processes are specified it uses a normal for loop. Defaults to 8.
        keys: keys to load and check

    Returns:
        List[str]: case ids of cases that failed to load
        bool: True if all cases were loadable, False otherwise
    """
    cases_npz = sorted(data_dir.glob("*.npz"))
    # each npz is accompanied by a pkl with the same stem
    cases_pkl = [case.with_suffix(".pkl") for case in cases_npz]

    if processes == 0:
        # sequential fallback, e.g. for debugging corrupted cases
        result = [check_case(case_npz, case_pkl, remove=remove, keys=keys)
                  for case_npz, case_pkl in zip(cases_npz, cases_pkl)]
    else:
        with Pool(processes=processes) as p:
            result = p.starmap(check_case,
                               zip(cases_npz, cases_pkl, repeat(remove), repeat(keys)))
    failed_cases = [fc[0] for fc in result if not fc[1]]
    logger.info(f"Checked {len(result)} cases in {data_dir}")
    return failed_cases, len(failed_cases) == 0


def check_case(case_npz: Path,
               case_pkl: Path = None,
               remove: bool = False,
               keys: Sequence[str] = ("data", "seg"),
               ) -> Tuple[str, bool]:
    """
    Check if a single case is loadable

    Args:
        case_npz (Path): path to npz file
        case_pkl (Path, optional): path to pkl file. Defaults to None.
        remove (bool, optional): if loading fails, the npz and pkl
            file are removed automatically. Defaults to False.
        keys: keys to load from the npz file; when "seg" is included
            and a pkl file is given, segmentation instances are
            cross-checked against the saved properties

    Returns:
        str: case id
        bool: True if case was loaded correctly, False otherwise
    """
    logger.info(f"Checking {case_npz}")
    case_id = get_case_id_from_path(case_npz, remove_modality=False)
    try:
        case_dict = load_npz_looped(str(case_npz), keys=keys, num_tries=3)
        if "seg" in keys and case_pkl is not None:
            properties = load_pickle(case_pkl)
            seg = case_dict["seg"]
            seg_instances = np.unique(seg)  # automatically sorted
            seg_instances = seg_instances[seg_instances > 0]  # drop background / ignore values

            instances_properties = properties["instances"].keys()
            props_instances = np.sort(np.array(list(map(int, instances_properties))))

            if (len(seg_instances) != len(props_instances)) or any(seg_instances != props_instances):
                logger.warning(f"Inconsistent instances {case_npz} from "
                               f"properties {props_instances} from seg {seg_instances}. "
                               f"Very small instances can get lost in resampling "
                               f"but larger instances should not disappear!")
            for i in seg_instances:
                if str(i) not in instances_properties:
                    # report the specific offending instance (previously the
                    # whole array was interpolated into the message)
                    raise RuntimeError(f"Found instance {i} in segmentation "
                                       f"which is not in properties {instances_properties}. "
                                       f"Delete labels manually and rerun prepare label!")
    except Exception as e:
        logger.error(f"Failed to load {case_npz} with {e}")
        logger.error(f"{traceback.format_exc()}")
        if remove:
            os.remove(case_npz)
            if case_pkl is not None:
                os.remove(case_pkl)
        return case_id, False
    return case_id, True


298
299
300
301
def run(cfg,
        num_processes: int,
        num_processes_preprocessing: int,
        ):
    """
    Python interface for the preprocessing script

    Args:
        cfg: dict with config; needs `host` (data_dir, cropped_output_dir,
            splitted_4d_output_dir, preprocessed_output_dir), `data`,
            `prep` (crop, analyze, plan, process, overwrite), `planner`,
            `module` and `model_cfg` entries
        num_processes: number of processes used for cropping and
            dataset analysis
        num_processes_preprocessing: number of processes used for
            resampling/preprocessing
    """
    # log INFO to stdout and everything (DEBUG) to a log file in the data dir
    logger.remove()
    logger.add(sys.stdout, level="INFO")
    logger.add(Path(cfg["host"]["data_dir"]) / "logging.log", level="DEBUG")
    data_info = cfg["data"]

    if cfg["prep"]["crop"]:
        # crop data to nonzero area
        run_cropping_and_convert(cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
                                 splitted_4d_output_dir=Path(cfg["host"]["splitted_4d_output_dir"]),
                                 data_info=data_info,
                                 overwrite=cfg["prep"]["overwrite"],
                                 num_processes=num_processes,
                                 )

    if cfg["prep"]["analyze"]:
        # compute statistics over data and segmentation(e.g. physical volume of individual classes)
        run_dataset_analysis(cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
                             preprocessed_output_dir=Path(cfg["host"]["preprocessed_output_dir"]),
                             data_info=data_info,
                             num_processes=num_processes,
                             intensity_properties=True,
                             overwrite=cfg["prep"]["overwrite"],
                             )

    if cfg["prep"]["plan"] or cfg["prep"]["process"]:
        # plan future training
        run_planning_and_process(
            splitted_4d_output_dir=Path(cfg["host"]["splitted_4d_output_dir"]),
            cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
            preprocessed_output_dir=Path(cfg["host"]["preprocessed_output_dir"]),
            planner_name=cfg["planner"],
            dim=data_info["dim"],
            model_name=cfg["module"],
            model_cfg=cfg["model_cfg"],
            num_processes=num_processes_preprocessing,
            run_preprocessing=cfg["prep"]["process"],
        )


@env_guard
def main():
    """
    CLI entry point: parse arguments, optionally verify the splitted
    data/labels for each task, then run preprocessing per task.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('tasks', type=str, nargs='+',
                        help="Single or multiple task identifiers to process consecutively",
                        )
    parser.add_argument('-o', '--overwrites', type=str, nargs='+',
                        help="overwrites for config file", default=[],
                        required=False)
    parser.add_argument('--full_check',
                        help="Run a full check of the data.",
                        action='store_true',
                        )
    parser.add_argument('--no_check',
                        help="Skip basic check.",
                        action='store_true',
                        )
    parser.add_argument('-np', '--num_processes',
                        type=int, default=4, required=False,
                        help="Number of processes to use for cropping.",
                        )
    parser.add_argument('-npp', '--num_processes_preprocessing',
                        type=int, default=3, required=False,
                        help="Number of processes to use for resampling.",
                        )
    args = parser.parse_args()
    tasks = args.tasks
    ov = args.overwrites
    full_check = args.full_check
    no_check = args.no_check
    num_processes = args.num_processes
    num_processes_preprocessing = args.num_processes_preprocessing

    initialize_config_module(config_module="nndet.conf")
    # perform preprocessing checks first
    if not no_check:
        for task in tasks:
            # deep-copy overwrites: compose may mutate the override list
            _ov = copy.deepcopy(ov) if ov is not None else []
            cfg = compose(task, "config.yaml", overrides=_ov)
            check_dataset_file(cfg["task"])
            check_data_and_label_splitted(
                cfg["task"],
                test=False,
                labels=True,
                full_check=full_check,
                )
            if cfg["data"]["test_labels"]:
                # test split has labels too -> check those as well
                check_data_and_label_splitted(
                    cfg["task"],
                    test=True,
                    labels=True,
                    full_check=full_check,
                    )

    # start preprocessing
    for task in tasks:
        _ov = copy.deepcopy(ov) if ov is not None else []
        cfg = compose(task, "config.yaml", overrides=_ov)
        run(OmegaConf.to_container(cfg, resolve=True),
            num_processes=num_processes,
            num_processes_preprocessing=num_processes_preprocessing,
            )
mibaumgartner's avatar
mibaumgartner committed
410
411
412
413


# script entry point when executed directly (not imported)
if __name__ == '__main__':
    main()