prepare_mic.py 4.67 KB
Newer Older
mibaumgartner's avatar
mibaumgartner committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import sys
import os
from itertools import repeat
from multiprocessing.pool import Pool

import pandas as pd
import numpy as np
import numpy.testing as npt
import SimpleITK as sitk
from pathlib import Path
from loguru import logger
from tqdm import tqdm
from pathlib import Path

from nndet.io.load import save_json, load_json
from nndet.io.paths import subfiles
from nndet.utils.check import env_guard


def prepare_case(case_dir: Path, target_dir: Path, df: pd.DataFrame):
    """Convert one raw LIDC case into the nnDetection raw_splitted layout.

    Copies the CT volume to ``imagesTr`` as nii.gz, merges the per-rater
    ROI masks into a single instance segmentation written to ``labelsTr``,
    and stores per-instance malignancy info as json next to the mask.

    Args:
        case_dir: directory of a single case; expected to contain
            ``{case_id}_ct_scan.nrrd`` and per-rater ``*.nii.gz`` ROI masks.
        target_dir: root of the target dataset (must contain ``imagesTr``
            and ``labelsTr`` subdirectories).
        df: characteristics table with at least the columns
            ``PatientID``, ``NoduleID`` and ``Malignancy``.
    """
    target_data_dir = target_dir / "imagesTr"
    target_label_dir = target_dir / "labelsTr"

    case_id = case_dir.name  # directory name is the patient/case id
    logger.info(f"Processing case {case_id}")
    df = df[df.PatientID == case_id]

    # process data: convert the nrrd scan to nii.gz
    img = sitk.ReadImage(str(case_dir / f"{case_id}_ct_scan.nrrd"))
    sitk.WriteImage(img, str(target_data_dir / f"{case_id}.nii.gz"))
    img_arr = sitk.GetArrayFromImage(img)

    # process mask: one instance id per ROI that survives rater majority vote
    final_rois = np.zeros_like(img_arr, dtype=np.uint8)
    mal_labels = {}
    # roi id is the last '_'-separated token of the file stem
    roi_ids = set([ii.split('.')[0].split('_')[-1]
                   for ii in os.listdir(case_dir) if '.nii.gz' in ii])

    rix = 1
    for rid in roi_ids:
        # anchor the match with a leading '_' so e.g. rid '1' does not
        # also collect the files of rid '11'
        roi_id_paths = [ii for ii in os.listdir(case_dir) if '_{}.nii'.format(rid) in ii]
        nodule_ids = [ii.split('_')[2].lstrip("0") for ii in roi_id_paths]
        rater_labels = [df[df.NoduleID == int(ii)].Malignancy.values[0] for ii in nodule_ids]
        # pad to 4 raters; -1 entries are excluded from the mean below
        rater_labels.extend([0] * (4 - len(rater_labels)))
        mal_label = np.mean([ii for ii in rater_labels if ii > -1])

        roi_rater_list = []
        for rp in roi_id_paths:
            roi = sitk.ReadImage(str(case_dir / rp))
            roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8)
            assert roi_arr.shape == img_arr.shape, [
                roi_arr.shape, img_arr.shape, case_id, roi.GetSpacing()]
            # ROI spacing must match the scan spacing in every dimension
            for ix in range(len(img_arr.shape)):
                npt.assert_almost_equal(roi.GetSpacing()[ix], img.GetSpacing()[ix])
            roi_rater_list.append(roi_arr)

        # pad with empty masks so the mean is always over 4 raters
        roi_rater_list.extend([np.zeros_like(roi_rater_list[-1])] * (4 - len(roi_id_paths)))
        roi_raters = np.array(roi_rater_list)
        roi_raters = np.mean(roi_raters, axis=0)
        roi_raters[roi_raters < 0.5] = 0  # majority vote threshold
        if np.sum(roi_raters) > 0:
            mal_labels[rix] = mal_label
            final_rois[roi_raters >= 0.5] = rix
            rix += 1
        else:
            # indicate rois suppressed by majority voting of raters
            logger.warning(f'suppressed roi! {roi_id_paths}')

    mask_itk = sitk.GetImageFromArray(final_rois)
    sitk.WriteImage(mask_itk, str(target_label_dir / f"{case_id}.nii.gz"))
    # fix: iterate .items() — iterating the dict directly yields only keys
    # and the (key, item) unpacking raised a TypeError
    instance_classes = {key: int(item >= 3) for key, item in mal_labels.items()}
    # NOTE(review): no '.json' suffix given — assumes save_json appends it
    # (reformat_labels globs '*json'); confirm against nndet.io.load.save_json
    save_json({"instances": instances_or(instance_classes), "scores": mal_labels},
              target_label_dir / f"{case_id}") if False else save_json(
        {"instances": instance_classes, "scores": mal_labels},
        target_label_dir / f"{case_id}")


def reformat_labels(target: Path):
    """Recompute the binary ``instances`` entry of every label json in *target*.

    For each json file the stored per-instance malignancy ``scores`` are
    re-thresholded (score >= 3 -> class 1, else 0) and the file is rewritten
    with both the new ``instances`` mapping and the original ``scores``.
    """
    for path in subfiles(target, identifier="*json", join=True):
        json_path = Path(path)
        content = load_json(json_path)
        scores = content["scores"]
        instances = {key: int(score >= 3) for key, score in scores.items()}
        save_json({"instances": instances, "scores": scores}, json_path)


def delete_without_label(target: Path):
    """Remove every ``.npz`` file in *target* that has no matching ``.pkl``.

    The companion pickle is expected at the same path with the extension
    swapped; npz files without one are deleted.
    """
    for npz_path in subfiles(target, identifier="*.npz", join=True):
        pkl_path = str(npz_path).rsplit('.', 1)[0] + '.pkl'
        if os.path.isfile(pkl_path):
            continue
        os.remove(npz_path)


def check_data_load(target: Path):
    """Attempt to load every ``.npy`` file in *target* and report failures.

    Files that raise during ``np.load`` are printed together with the
    exception; loadable files produce no output.
    """
    for npy_path in tqdm(subfiles(target, identifier="*.npy", join=True)):
        try:
            np.load(npy_path)
        except Exception as err:
            print(f"Failed to load: {npy_path} with {err}")


@env_guard
def main():
    """Prepare Task012_LIDC: convert raw LIDC data into raw_splitted layout.

    Expects the environment variable ``det_data`` to point at the nnDetection
    data root (env_guard is assumed to validate the environment), with
    ``Task012_LIDC/raw/data_nrrd`` holding one directory per case and
    ``Task012_LIDC/raw/characteristics.csv`` holding the nodule table.

    Raises:
        ValueError: if the expected raw data directory or csv is missing.
    """
    det_data_dir = Path(os.getenv('det_data'))
    task_data_dir = det_data_dir / "Task012_LIDC"
    source_data_dir = task_data_dir / "raw"

    # fail early with readable messages if the raw download is incomplete
    if not (p := source_data_dir / "data_nrrd").is_dir():
        raise ValueError(f"Expected {p} to contain the LIDC data")
    if not (p := source_data_dir / 'characteristics.csv').is_file():
        raise ValueError(f"Expected {p} to exist")

    target_dir = task_data_dir / "raw_splitted"
    target_data_dir = target_dir / "imagesTr"
    target_data_dir.mkdir(exist_ok=True, parents=True)
    target_label_dir = target_dir / "labelsTr"
    target_label_dir.mkdir(exist_ok=True, parents=True)

    # log INFO to stdout and everything to a persistent file
    logger.remove()
    logger.add(sys.stdout, level="INFO")
    logger.add(task_data_dir / "prepare.log", level="DEBUG")

    data_dir = source_data_dir / "data_nrrd"
    case_dirs = [x for x in data_dir.iterdir() if x.is_dir()]
    # characteristics table is semicolon separated
    df = pd.read_csv(source_data_dir / 'characteristics.csv', sep=';')

    for cd in case_dirs:
        prepare_case(cd, target_dir, df)

    # TODO download custom split file

    # TODO download custom split file


# Script entry point: only run the preparation when executed directly.
if __name__ == '__main__':
    main()