"git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "6186146d599d33cfd82c78a948f3b06858e5a7b9"
preprocess.py 18.2 KB
Newer Older
mibaumgartner's avatar
mibaumgartner committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse
import shutil
import os
import copy
import sys
import traceback

import numpy as np

from loguru import logger
from itertools import repeat
from typing import Dict, Sequence, Tuple, List
from pathlib import Path
from multiprocessing import Pool
from hydra.experimental import initialize_config_module
from omegaconf import OmegaConf

from nndet.utils.config import compose
from nndet.utils.info import env_guard
from nndet.planning import DatasetAnalyzer
from nndet.planning.plan_experiment import PLANNER_REGISTRY
from nndet.planning.plan_experiment import create_labels
from nndet.planning.properties.registry import medical_instance_props
from nndet.io.load import load_pickle, load_npz_looped
from nndet.io.prepare import maybe_split_4d_nifti, instances_from_segmentation
from nndet.io.paths import get_paths_raw_to_split, get_paths_from_splitted_dir, subfiles, get_case_id_from_path
from nndet.preprocessing import ImageCropper


def run_splitting_4d(data_dir: Path, output_dir: Path, num_processes: int) -> None:
    """
    Due to historical reasons this framework uses 3D niftis instead of 4D niftis.
    This function splits present 4D niftis into 3D niftis per channel.

    Note: any existing ``output_dir`` is deleted first, so the split always
    starts from a clean state.

    Args:
        data_dir (Path): top directory where data is located
        output_dir (Path): output directory for splitted data
        num_processes (int): number of processes to use to split data
    """
    if output_dir.is_dir():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)

    # pair each raw source file with its target folder and split in parallel
    source_files, target_folders = get_paths_raw_to_split(data_dir, output_dir)
    with Pool(processes=num_processes) as p:
        p.starmap(maybe_split_4d_nifti, zip(source_files, target_folders))


def prepare_labels(data_dir: Path,
                   output_dir: Path,
                   num_processes: int,
                   rm_classes: Sequence[int],
                   ro_classes: Dict[int, int],
                   subtract_one_from_classes: bool,
                   instances_from_seg: bool = True):
    """
    Copy labels to splitted dir.
    Optionally, runs connected components and removes classes from
    semantic segmentations

    Args:
        data_dir: path to task base dir
        output_dir: base dir of prepared labels
        num_processes: number of processes to use
        rm_classes: remove specific classes from semantic segmentation.
            Can only be used with `instances_from_seg`
        ro_classes: reorder classes in semantic segmentation for
            connected components. Can only be used with `instances_from_seg`
        subtract_one_from_classes: class indices for detection start from 0.
            Subtracts 1 from classes extracted from segmentation
        instances_from_seg: Run connected components. Defaults to True.
    """
    for labels_subdir in ("labelsTr", "labelsTs"):
        source_dir = data_dir / labels_subdir
        if not source_dir.is_dir():
            # subdir absent (e.g. no test labels) -> nothing to prepare
            continue
        labels_output_dir = output_dir / labels_subdir

        if not instances_from_seg:
            # plain copy, segmentations are used as-is
            shutil.copytree(source_dir, labels_output_dir)
            continue

        if not labels_output_dir.is_dir():
            labels_output_dir.mkdir(parents=True)

        # collect nifti label files, skipping hidden files (e.g. ._ artifacts)
        nifti_paths = [
            p for p in map(Path, subfiles(source_dir, identifier="*.nii.gz", join=True))
            if not p.name.startswith('.')
        ]
        # run connected components / class filtering per file in parallel
        with Pool(processes=num_processes) as p:
            p.starmap(instances_from_segmentation, zip(
                nifti_paths,
                repeat(labels_output_dir),
                repeat(rm_classes),
                repeat(ro_classes),
                repeat(subtract_one_from_classes),
            ))


def run_cropping_and_convert(cropped_output_dir: Path,
                             splitted_4d_output_dir: Path,
                             data_info: dict,
                             overwrite: bool,
                             num_processes: int,
                             ):
    """
    First preparation step data:
        - stack data and segmentation to a single sample (segmentation is the last channel)
        - save data as npz (format: case_id.npz)
        - save additional properties as pkl file (format: case_id.pkl)
        - crop data to nonzero region; crop segmentation; fill segmentation with -1 where in nonzero regions

    If the multiprocessed crop leaves corrupted files behind, the failed
    cases are removed and re-cropped sequentially before giving up.

    Args:
        cropped_output_dir (Path): path to directory where cropped images should be saved
        splitted_4d_output_dir (Path): path to splitted data
        data_info: information about data set (here `modalities` is needed)
        overwrite (bool): overwrite existing cropped data
        num_processes (int): number of processes used to crop image data

    Raises:
        RuntimeError: if corrupted files remain after the sequential retry
    """
    num_modalities = len(data_info["modalities"])

    if overwrite and cropped_output_dir.is_dir():
        shutil.rmtree(str(cropped_output_dir))
    if not cropped_output_dir.is_dir():
        cropped_output_dir.mkdir(parents=True)

    case_files = get_paths_from_splitted_dir(num_modalities, splitted_4d_output_dir)

    logger.info(f"Running cropping with overwrite {overwrite}.")
    imgcrop = ImageCropper(num_processes, cropped_output_dir)
    imgcrop.run_cropping(case_files, overwrite_existing=overwrite)

    # only "data" is checked here; segmentation consistency is verified later
    case_ids_failed, result_check = run_check(cropped_output_dir / "imagesTr",
                                              remove=True,
                                              processes=num_processes,
                                              keys=("data",),
                                              )
    if not result_check:
        # fixed: adjacent f-strings previously concatenated without a separator
        logger.warning(
            f"Crop check failed: There are corrupted files!!!! {case_ids_failed} "
            f"Try to crop corrupted files again.",
        )
        # num_processes=0 -> sequential retry; overwrite_existing=False so only
        # the removed (failed) cases are re-cropped
        imgcrop = ImageCropper(0, cropped_output_dir)
        imgcrop.run_cropping(case_files, overwrite_existing=False)
        case_ids_failed, result_check = run_check(cropped_output_dir / "imagesTr",
                                                  remove=False,
                                                  processes=num_processes,
                                                  keys=("data",),
                                                  )
        if not result_check:
            logger.error(f"Found corrupted files: {case_ids_failed}.")
            raise RuntimeError("Corrupted files")
    else:
        logger.info("Crop check successful: Loading check completed")


def run_dataset_analysis(cropped_output_dir: Path,
                         preprocessed_output_dir: Path,
                         data_info: dict,
                         num_processes: int,
                         intensity_properties: bool = True,
                         overwrite: bool = True,
                         ):
    """
    Analyse dataset

    Args:
        cropped_output_dir: path to base cropped dir
        preprocessed_output_dir: path to base preprocessed output dir
        data_info: additional information about dataset (`modalities` and `labels` needed)
        num_processes: number of processes to use
        intensity_properties: analyze intensity values of foreground
        overwrite: overwrite existing properties
    """
    # select which properties to compute, then delegate to the analyzer
    props = medical_instance_props(intensity_properties=intensity_properties)
    analyzer = DatasetAnalyzer(
        cropped_output_dir,
        preprocessed_output_dir=preprocessed_output_dir,
        data_info=data_info,
        num_processes=num_processes,
        overwrite=overwrite,
    )
    analyzer.analyze_dataset(props)


def run_planning_and_process(
    splitted_4d_output_dir: Path,
    cropped_output_dir: Path,
    preprocessed_output_dir: Path,
    planners: Dict[str, Sequence[str]],
    dim: int,
    model_name: str,
    model_cfg: Dict,
    num_processes: int,
    run_preprocessing: bool = True,
    ):
    """
    Run planning and preprocessing

    Args:
        splitted_4d_output_dir: base dir of splitted data
        cropped_output_dir: base dir of cropped data
        preprocessed_output_dir: base dir of preprocessed data
        planners: define planners for
            the needed dimension (keyed by "2d"/"3d" style strings)
        dim: number of spatial dimensions
        model_name: name of model to run planning for
        model_cfg: hyperparameters of model (used during planning to
            instantiate model)
        num_processes: number of processes to use for preprocessing
        run_preprocessing: Preprocess and check data. Defaults to True.

    Raises:
        RuntimeError: if preprocessed files are still corrupted after the
            sequential (num_processes=0) retry
    """
    # planners are grouped by dimensionality, e.g. planners["3d"]
    selected_planners = planners[f"{dim}d"]
    for planner_name in selected_planners:
        planner_cls = PLANNER_REGISTRY.get(planner_name)
        planner = planner_cls(
            preprocessed_output_dir=preprocessed_output_dir
        )
        # one planner may emit multiple plans (identifiers of pickled plan files)
        plan_identifiers = planner.plan_experiment(
            model_name=model_name,
            model_cfg=model_cfg,
        )
        if run_preprocessing:
            for plan_id in plan_identifiers:
                plan = load_pickle(preprocessed_output_dir / plan_id)
                planner.run_preprocessing(
                    cropped_data_dir=cropped_output_dir / "imagesTr",
                    plan=plan,
                    num_processes=num_processes,
                    )
                # verify that every preprocessed npz/pkl pair is loadable;
                # remove=True deletes corrupted cases so they can be redone
                case_ids_failed, result_check = run_check(
                    data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
                    remove=True,
                    processes=num_processes
                )

                # delete and rerun corrupted cases
                if not result_check:
                    logger.warning(f"{plan_id} check failed: There are corrupted files {case_ids_failed}!!!!"
                                   f"Running preprocessing of those cases without multiprocessing.")
                    # sequential retry: only the removed cases are preprocessed again
                    planner.run_preprocessing(
                        cropped_data_dir=cropped_output_dir / "imagesTr",
                        plan=plan,
                        num_processes=0,
                    )
                    # re-check without removing so failures can be inspected
                    case_ids_failed, result_check = run_check(
                        data_dir=preprocessed_output_dir / plan["data_identifier"] / "imagesTr",
                        remove=False,
                        processes=0
                    )
                    if not result_check:
                        logger.error(f"Could not fix corrupted files {case_ids_failed}!")
                        raise RuntimeError("Found corrupted files, check logs!")
                    else:
                        logger.info("Fixed corrupted files.")
                else:
                    logger.info(f"{plan_id} check successful: Loading check completed")

    if run_preprocessing:
        # generate label files for the preprocessed dataset from the splitted source
        create_labels(
            preprocessed_output_dir=preprocessed_output_dir,
            source_dir=splitted_4d_output_dir,
            num_processes=num_processes,
        )


def run_check(data_dir: Path,
              remove: bool = False,
              processes: int = 8,
              keys: Sequence[str] = ("data", "seg"),
              ) -> Tuple[List[str], bool]:
    """
    Check if files from preprocessed dir are loadable

    Args:
        data_dir (Path): path to preprocessed data
        remove (bool, optional): if loading fails the file is the npz and pkl
            file are removed automatically. Defaults to False.
        processes (int, optional): number of processes to use. If
            0 processes are specified it uses a normal for loop. Defaults to 8.
        keys: keys to load and check inside each npz file

    Returns:
        List[str]: case ids which failed the check
        bool: True if all cases were loadable, False otherwise
    """
    cases_npz = sorted(data_dir.glob("*.npz"))
    # matching properties file: case_id.npz -> case_id.pkl
    cases_pkl = [case.parent / f"{case.name.rsplit('.', 1)[0]}.pkl"
                 for case in cases_npz]

    if processes == 0:
        # bugfix: forward `keys` in the sequential path too — it was silently
        # dropped before, so callers passing keys=("data",) still checked "seg"
        result = [check_case(case_npz, case_pkl, remove=remove, keys=keys)
                  for case_npz, case_pkl in zip(cases_npz, cases_pkl)]
    else:
        with Pool(processes=processes) as p:
            result = p.starmap(check_case,
                               zip(cases_npz, cases_pkl, repeat(remove), repeat(keys)))
    failed_cases = [fc[0] for fc in result if not fc[1]]
    logger.info(f"Checked {len(result)} cases in {data_dir}")
    return failed_cases, len(failed_cases) == 0


def check_case(case_npz: Path,
               case_pkl: Path = None,
               remove: bool = False,
               keys: Sequence[str] = ("data", "seg"),
               ) -> Tuple[str, bool]:
    """
    Check if a single case is loadable and (optionally) consistent with its
    properties file.

    Any exception during loading or the consistency check — including an
    instance id present in the segmentation but missing from the
    properties — is caught and reported as a failed case.

    Args:
        case_npz (Path): path to npz file
        case_pkl (Path, optional): path to pkl file with case properties.
            Defaults to None.
        remove (bool, optional): if loading fails the file is the npz and pkl
            file are removed automatically. Defaults to False.
        keys: npz keys to load. If "seg" is included and `case_pkl` is given,
            the instance ids in the segmentation are cross-checked against
            the instances stored in the properties.

    Returns:
        str: case id
        bool: true if case was loaded correctly, false otherwise
    """
    logger.info(f"Checking {case_npz}")
    case_id = get_case_id_from_path(case_npz, remove_modality=False)
    try:
        case_dict = load_npz_looped(str(case_npz), keys=keys, num_tries=3)
        if "seg" in keys and case_pkl is not None:
            properties = load_pickle(case_pkl)
            seg = case_dict["seg"]
            seg_instances = np.unique(seg)  # automatically sorted
            seg_instances = seg_instances[seg_instances > 0]  # drop background/ignore

            instances_properties = properties["instances"].keys()
            props_instances = np.sort(np.array(list(map(int, instances_properties))))

            if (len(seg_instances) != len(props_instances)) or any(seg_instances != props_instances):
                logger.warning(f"Inconsistent instances {case_npz} from "
                                f"properties {props_instances} from seg {seg_instances}. "
                                f"Very small instances can get lost in resampling "
                                f"but larger instances should not disappear!")
            for i in seg_instances:
                if str(i) not in instances_properties:
                    # fixed: report the single offending instance `i`
                    # (previously printed the whole `seg_instances` array)
                    # and add the missing space between the sentences
                    raise RuntimeError(f"Found instance {i} in segmentation "
                                       f"which is not in properties {instances_properties}. "
                                       f"Delete labels manually and rerun prepare label!")
    except Exception as e:
        logger.error(f"Failed to load {case_npz} with {e}")
        logger.error(f"{traceback.format_exc()}")
        if remove:
            os.remove(case_npz)
            if case_pkl is not None:
                os.remove(case_pkl)
        return case_id, False
    return case_id, True


def run(cfg, instances_from_seg):
    """
    Python interface for script

    Runs the enabled preprocessing stages (crop -> analyze -> plan/process)
    according to the boolean switches in ``cfg["prep"]``.

    Args:
        cfg: dict with config; reads ``host`` (directory layout), ``prep``
            (stage switches and process counts), ``data`` (dataset info),
            ``planners``, ``module`` and ``model_cfg``
        instances_from_seg: convert semantic segmentation to instance segmentation
    """
    # reset loguru sinks: INFO to stdout, full DEBUG log next to the data
    logger.remove()
    logger.add(sys.stdout, level="INFO")
    logger.add(Path(cfg["host"]["data_dir"]) / "logging.log", level="DEBUG")
    logger.info(f"Running instances_from_seg: {instances_from_seg}")
    data_info = cfg["data"]

    if cfg["prep"]["crop"]:
        # crop data to nonzero area
        run_cropping_and_convert(cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
                                 splitted_4d_output_dir=Path(cfg["host"]["splitted_4d_output_dir"]),
                                 data_info=data_info,
                                 overwrite=cfg["prep"]["overwrite"],
                                 num_processes=cfg["prep"]["num_processes"],
                                 )

    if cfg["prep"]["analyze"]:
        # compute statistics over data and segmentation(e.g. physical volume of individual classes)
        run_dataset_analysis(cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
                             preprocessed_output_dir=Path(cfg["host"]["preprocessed_output_dir"]),
                             data_info=data_info,
                             num_processes=cfg["prep"]["num_processes"],
                             intensity_properties=True,
                             overwrite=cfg["prep"]["overwrite"],
                             )

    if cfg["prep"]["plan"] or cfg["prep"]["process"]:
        # plan future training; preprocessing only runs when "process" is set
        run_planning_and_process(
            splitted_4d_output_dir=Path(cfg["host"]["splitted_4d_output_dir"]),
            cropped_output_dir=Path(cfg["host"]["cropped_output_dir"]),
            preprocessed_output_dir=Path(cfg["host"]["preprocessed_output_dir"]),
            planners=cfg["planners"],
            dim=data_info["dim"],
            model_name=cfg["module"],
            model_cfg=cfg["model_cfg"],
            num_processes=cfg["prep"]["num_processes_processing"],
            run_preprocessing=cfg["prep"]["process"],
        )


@env_guard
def main():
    """CLI entry point: preprocess one or more tasks consecutively."""
    parser = argparse.ArgumentParser()
    parser.add_argument('tasks', type=str, nargs='+',
                        help="Single or multiple task identifiers to process consecutively",
                        )
    parser.add_argument('-o', '--overwrites', type=str, nargs='+',
                        help="overwrites for config file", default=[],
                        required=False)
    parsed = parser.parse_args()
    overwrites = parsed.overwrites

    initialize_config_module(config_module="nndet.conf")
    for task in parsed.tasks:
        # deep-copy per task so hydra cannot mutate the shared override list
        task_overrides = copy.deepcopy(overwrites) if overwrites is not None else []
        cfg = compose(task, "config.yaml", overrides=task_overrides)
        instances_from_seg = cfg.data.get("instances_from_seg", False)
        run(OmegaConf.to_container(cfg, resolve=True), instances_from_seg)


if __name__ == '__main__':
    # script entry point
    main()