"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse
import shutil
import os
import copy
import sys
import traceback

import numpy as np

from loguru import logger
from itertools import repeat
from typing import Dict, Sequence, Tuple, List
from pathlib import Path
from multiprocessing import Pool
from hydra.experimental import initialize_config_module
from omegaconf import OmegaConf

from nndet.utils.config import compose
from nndet.utils.check import env_guard
from nndet.planning import DatasetAnalyzer
from nndet.planning import PLANNER_REGISTRY
from nndet.planning.experiment.utils import create_labels
from nndet.planning.properties.registry import medical_instance_props
from nndet.io.load import load_pickle, load_npz_looped
from nndet.io.prepare import maybe_split_4d_nifti, instances_from_segmentation
from nndet.io.paths import get_paths_raw_to_split, get_paths_from_splitted_dir, subfiles, get_case_id_from_path
from nndet.preprocessing import ImageCropper
from nndet.utils.check import check_dataset_file, check_data_and_label_splitted


def run_splitting_4d(data_dir: Path, output_dir: Path, num_processes: int) -> None:
    """
    Due to historical reasons this framework uses 3D niftis instead of 4D niftis.
    This function splits any present 4D niftis into one 3D nifti per channel.

    Args:
        data_dir (Path): top directory where the data is located
        output_dir (Path): output directory for the splitted data
        num_processes (int): number of processes to use to split the data
    """
    if output_dir.is_dir():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)

    source_files, target_folders = get_paths_raw_to_split(data_dir, output_dir)
    with Pool(processes=num_processes) as p:
        p.starmap(maybe_split_4d_nifti, zip(source_files, target_folders))


def prepare_labels(data_dir: Path,
                   output_dir: Path,
                   num_processes: int,
                   rm_classes: Sequence[int],
                   ro_classes: Dict[int, int],
                   subtract_one_from_classes: bool,
                   instances_from_seg: bool = True):
    """
    Copy labels to the splitted dir.
    Optionally runs connected components and removes classes from the
    semantic segmentations.

    Args:
        data_dir: path to task base dir
        output_dir: base dir of prepared labels
        num_processes: number of processes to use
        rm_classes: remove specific classes from semantic segmentation.
            Can only be used with `instances_from_seg`
        ro_classes: reorder classes in semantic segmentation for
            connected components. Can only be used with `instances_from_seg`
        subtract_one_from_classes: class indices for detection start from 0.
            Subtracts 1 from classes extracted from segmentation
        instances_from_seg: Run connected components. Defaults to True.
    """
    for labels_subdir in ("labelsTr", "labelsTs"):
        if not (data_dir / labels_subdir).is_dir():
            continue
        labels_output_dir = output_dir / labels_subdir

        if instances_from_seg:
            if not labels_output_dir.is_dir():
                labels_output_dir.mkdir(parents=True)

            with Pool(processes=num_processes) as p:
                paths = list(map(Path, subfiles(data_dir / labels_subdir,
                                                identifier="*.nii.gz", join=True)))
                paths = [path for path in paths if not path.name.startswith('.')]
                p.starmap(instances_from_segmentation, zip(
                    paths, repeat(labels_output_dir), repeat(rm_classes),
                    repeat(ro_classes), repeat(subtract_one_from_classes)))
        else:
            shutil.copytree(data_dir / labels_subdir, labels_output_dir)


def run_cropping_and_convert(cropped_output_dir: Path,
                             splitted_4d_output_dir: Path,
                             data_info: dict,
                             overwrite: bool,
                             num_processes: int,
                             ):
    """
    First data preparation step:
        - stack data and segmentation into a single sample (segmentation is the last channel)
        - save data as npz (format: case_id.npz)
        - save additional properties as pkl file (format: case_id.pkl)
        - crop data to the nonzero region; crop segmentation accordingly; fill the segmentation with -1 outside the nonzero region

    Args:
        cropped_output_dir (Path): path to directory where cropped images should be saved
        splitted_4d_output_dir (Path): path to splitted data
        data_info: information about data set (here `modalities` is needed)
        overwrite (bool): overwrite existing cropped data
        num_processes (int): number of processes used to crop image data
    """
    num_modalities = len(data_info["modalities"].keys())

    if overwrite and cropped_output_dir.is_dir():
        shutil.rmtree(str(cropped_output_dir))
    if not cropped_output_dir.is_dir():
        cropped_output_dir.mkdir(parents=True)

    case_files = get_paths_from_splitted_dir(num_modalities, splitted_4d_output_dir)

    logger.info(f"Running cropping with overwrite {overwrite}.")
    imgcrop = ImageCropper(num_processes, cropped_output_dir)
    imgcrop.run_cropping(case_files, overwrite_existing=overwrite)

    case_ids_failed, result_check = run_check(cropped_output_dir / "imagesTr",
                                              remove=True,
                                              processes=num_processes,
                                              keys=("data",)
                                              )
    if not result_check:
        logger.warning(
            f"Crop check failed: There are corrupted files!!!! {case_ids_failed}"
            f"Try to crop corrupted files again.",
        )
        imgcrop = ImageCropper(0, cropped_output_dir)
        imgcrop.run_cropping(case_files, overwrite_existing=False)
        case_ids_failed, result_check = run_check(cropped_output_dir / "imagesTr",
                                                  remove=False,
                                                  processes=num_processes,
                                                  keys=("data",)
                                                  )
        if not result_check:
            logger.error(f"Found corrupted files: {case_ids_failed}.")
            raise RuntimeError("Corrupted files")
    else:
        logger.info(f"Crop check successful: Loading check completed")


def run_dataset_analysis(cropped_output_dir: Path,
                         preprocessed_output_dir: Path,
                         data_info: dict,
                         num_processes: int,
                         intensity_properties: bool = True,
                         overwrite: bool = True,
                         ):
    """
    Analyse dataset

    Args:
        cropped_output_dir: path to base cropped dir
        preprocessed_output_dir: path to base preprocessed output dir
        data_info: additional information about dataset (`modalities` and `labels` needed)
        num_processes: number of processes to use
        intensity_properties: analyze intensity values of foreground
        overwrite: overwrite existing properties
    """
    analyzer = DatasetAnalyzer(
        cropped_output_dir,
        preprocessed_output_dir=preprocessed_output_dir,
        data_info=data_info,
        num_processes=num_processes,
        overwrite=overwrite,
        )
    properties = medical_instance_props(intensity_properties=intensity_properties)
    _ = analyzer.analyze_dataset(properties)


def run_planning_and_process(
    splitted_4d_output_dir: Path,
    cropped_output_dir: Path,
    preprocessed_output_dir: Path,
    planner_name: str,
    dim: int,
    model_name: str,
    model_cfg: Dict,
    num_processes: int,
    run_preprocessing: bool = True,
    ):
    """
    Run planning and preprocessing

    Args:
        splitted_4d_output_dir: base dir of splitted data
        cropped_output_dir: base dir of cropped data
        preprocessed_output_dir: base dir of preprocessed data
        planner_name: planner name
        dim: number of spatial dimensions
        model_name: name of model to run planning for
        model_cfg: hyperparameters of model (used during planning to
            instantiate model)
        num_processes: number of processes to use for preprocessing
        run_preprocessing: Preprocess and check data. Defaults to True.
    """
    planner_cls = PLANNER_REGISTRY.get(planner_name)
    planner = planner_cls(
        preprocessed_output_dir=preprocessed_output_dir
    )
    plan_identifiers = planner.plan_experiment(
        model_name=model_name,
        model_cfg=model_cfg,
    )
    if run_preprocessing:
        for plan_id in plan_identifiers:
            plan = load_pickle(preprocessed_output_dir / plan_id)
            planner.run_preprocessing(
                cropped_data_dir=cropped_output_dir / "imagesTr",
                plan=plan,
                num_processes=num_processes,
                )
            case_ids_failed, result_check = run_check(
                data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
                remove=True,
                processes=num_processes
            )

            # delete and rerun corrupted cases
            if not result_check:
                logger.warning(f"{plan_id} check failed: There are corrupted files {case_ids_failed}!!!!"
                                f"Running preprocessing of those cases without multiprocessing.")
                planner.run_preprocessing(
                    cropped_data_dir=cropped_output_dir / "imagesTr",
                    plan=plan,
                    num_processes=0,
                )
                case_ids_failed, result_check = run_check(
                    data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
                    remove=False,
                    processes=0
                )
                if not result_check:
                    logger.error(f"Could not fix corrupted files {case_ids_failed}!")
                    raise RuntimeError("Found corrupted files, check logs!")
                else:
                    logger.info("Fixed corrupted files.")
            else:
                logger.info(f"{plan_id} check successful: Loading check completed")

    if run_preprocessing:
        create_labels(
            preprocessed_output_dir=preprocessed_output_dir,
            source_dir=splitted_4d_output_dir,
            num_processes=num_processes,
        )


def run_check(data_dir: Path,
              remove: bool = False,
              processes: int = 8,
              keys: Sequence[str] = ("data", "seg"),
              ) -> Tuple[List[str], bool]:
    """
    Check if files from the preprocessed dir are loadable

    Args:
        data_dir (Path): path to preprocessed data
        remove (bool, optional): if loading fails, the corresponding npz and pkl
            files are removed automatically. Defaults to False.
        processes (int, optional): number of processes to use. If 0 processes
            are specified, a normal for loop is used instead. Defaults to 8.
        keys: keys to load and check

    Returns:
        List[str]: case ids of the cases that failed the check
        bool: True if all cases were loadable, False otherwise
    """
    cases_npz = list(data_dir.glob("*.npz"))
    cases_npz.sort()
    cases_pkl = [case.parent / f"{(case.name).rsplit('.', 1)[0]}.pkl"
                 for case in cases_npz]

    if processes == 0:
        result = [check_case(case_npz, case_pkl, remove=remove, keys=keys)
                  for case_npz, case_pkl in zip(cases_npz, cases_pkl)]
    else:
        with Pool(processes=processes) as p:
            result = p.starmap(check_case,
                               zip(cases_npz, cases_pkl, repeat(remove), repeat(keys)))
    failed_cases = [fc[0] for fc in result if not fc[1]]
    logger.info(f"Checked {len(result)} cases in {data_dir}")
    return failed_cases, len(failed_cases) == 0


def check_case(case_npz: Path,
               case_pkl: Path = None,
               remove: bool = False,
               keys: Sequence[str] = ("data", "seg"),
               ) -> Tuple[str, bool]:
    """
    Check if a single case is loadable

    Args:
        case_npz (Path): path to npz file
        case_pkl (Path, optional): path to pkl file. Defaults to None.
        remove (bool, optional): if loading fails, the npz and pkl files are
            removed automatically. Defaults to False.
        keys: keys to load and check

    Returns:
        str: case id
        bool: True if the case was loaded correctly, False otherwise
    """
    logger.info(f"Checking {case_npz}")
    case_id = get_case_id_from_path(case_npz, remove_modality=False)
    try:
        case_dict = load_npz_looped(str(case_npz), keys=keys, num_tries=3)
        if "seg" in keys and case_pkl is not None:
            properties = load_pickle(case_pkl)
            seg = case_dict["seg"]
            seg_instances = np.unique(seg)  # automatically sorted
            seg_instances = seg_instances[seg_instances > 0]
            
            instances_properties = properties["instances"].keys()
            props_instances = np.sort(np.array(list(map(int, instances_properties))))
            
            if (len(seg_instances) != len(props_instances)) or any(seg_instances != props_instances):
                logger.warning(f"Inconsistent instances {case_npz} from "
                                f"properties {props_instances} from seg {seg_instances}. "
                                f"Very small instances can get lost in resampling "
                                f"but larger instances should not disappear!")       
            for i in seg_instances:
                if str(i) not in instances_properties:
                    raise RuntimeError(f"Found instance {seg_instances} in segmentation "
                                       f"which is not in properties {instances_properties}."
                                       f"Delete labels manually and rerun prepare label!")
    except Exception as e:
        logger.error(f"Failed to load {case_npz} with {e}")
        logger.error(f"{traceback.format_exc()}")
        if remove:
            os.remove(case_npz)
            if case_pkl is not None:
                os.remove(case_pkl)
        return case_id, False
    return case_id, True
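
# Illustrative usage of the checks above (the path is hypothetical):
#   failed, ok = run_check(Path("/data/Task000_Example/preprocessed/<data_identifier>/imagesTr"),
#                          remove=False, processes=4, keys=("data", "seg"))
#   if not ok:
#       logger.error(f"Corrupted cases: {failed}")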


def run(cfg, instances_from_seg):
    """
    Python interface for script

    Args:
        cfg: dict with config
        instances_from_seg: convert semantic segmentation to instance segmentation
    """
    logger.remove()
    logger.add(sys.stdout, level="INFO")
    logger.add(Path(cfg["host"]["data_dir"]) / "logging.log", level="DEBUG")
    logger.info(f"Running instances_from_seg: {instances_from_seg}")
    data_info = cfg["data"]

    if cfg["prep"]["crop"]:
        # crop data to nonzero area
        run_cropping_and_convert(cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
                                 splitted_4d_output_dir=Path(cfg["host"]["splitted_4d_output_dir"]),
                                 data_info=data_info,
                                 overwrite=cfg["prep"]["overwrite"],
                                 num_processes=cfg["prep"]["num_processes"],
                                 )

    if cfg["prep"]["analyze"]:
        # compute statistics over data and segmentation (e.g. physical volume of individual classes)
        run_dataset_analysis(cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
                             preprocessed_output_dir=Path(cfg["host"]["preprocessed_output_dir"]),
                             data_info=data_info,
                             num_processes=cfg["prep"]["num_processes"],
                             intensity_properties=True,
                             overwrite=cfg["prep"]["overwrite"],
                             )

    if cfg["prep"]["plan"] or cfg["prep"]["process"]:
        # plan future training
        run_planning_and_process(
            splitted_4d_output_dir=Path(cfg["host"]["splitted_4d_output_dir"]),
            cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
            preprocessed_output_dir=Path(cfg["host"]["preprocessed_output_dir"]),
            planner_name=cfg["planner"],
            dim=data_info["dim"],
            model_name=cfg["module"],
            model_cfg=cfg["model_cfg"],
            num_processes=cfg["prep"]["num_processes_processing"],
            run_preprocessing=cfg["prep"]["process"],
        )


@env_guard
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('tasks', type=str, nargs='+',
                        help="Single or multiple task identifiers to process consecutively",
                        )
    parser.add_argument('-o', '--overwrites', type=str, nargs='+',
                        help="overwrites for config file", default=[],
                        required=False)
    parser.add_argument('--full_check',
                        help="Run a full check of the data.",
                        action='store_true',
                        )
    parser.add_argument('--no_check',
                        help="Skip basic check.",
                        action='store_true',
                        )
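    # Example invocation (task identifier and overrides are illustrative):
    #   python preprocess.py Task000_Example -o prep.num_processes=6 --full_check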
    args = parser.parse_args()
    tasks = args.tasks
    ov = args.overwrites
    full_check = args.full_check
    no_check = args.no_check

    initialize_config_module(config_module="nndet.conf")
    # perform preprocessing checks first
    if not no_check:
        for task in tasks:
            _ov = copy.deepcopy(ov) if ov is not None else []
            cfg = compose(task, "config.yaml", overrides=_ov)
            check_dataset_file(cfg["task"])
            check_data_and_label_splitted(
                cfg["task"],
                test=False,
                labels=True,
                full_check=full_check,
                )
            if cfg["data"]["test_labels"]:
                check_data_and_label_splitted(
                    cfg["task"],
                    test=True,
                    labels=True,
                    full_check=full_check,
                    )

    # start preprocessing
    for task in tasks:
        _ov = copy.deepcopy(ov) if ov is not None else []
        cfg = compose(task, "config.yaml", overrides=_ov)
        instances_from_seg = cfg.data.get("instances_from_seg", False)
        run(OmegaConf.to_container(cfg, resolve=True), instances_from_seg)


if __name__ == '__main__':
    main()