split.py 1.32 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
from collections import OrderedDict
from pathlib import Path

import numpy as np
from sklearn.model_selection import GroupKFold

from nndet.utils.check import env_guard
from nndet.io import get_case_ids_from_dir, save_pickle


@env_guard
def main():
    """Generate 5-fold, patient-grouped cross-validation splits for Task019FG_ADAM.

    Case ids are read from ``<det_data>/Task019FG_ADAM/raw_splitted/labelsTr``.
    They are grouped by a derived patient id, so all cases of one patient end
    up in the same fold. The resulting list of splits is pickled to
    ``<det_data>/Task019FG_ADAM/preprocessed/splits_final.pkl``.

    Each split is an ``OrderedDict`` with keys ``'train'`` and ``'val'``. Each
    value is a numpy array of case ids.
    """
    # NOTE(review): os.getenv returns None when 'det_data' is unset, and Path(None)
    # then raises TypeError; presumably env_guard ensures the variable exists.
    det_data_dir = Path(os.getenv('det_data'))
    task_data_dir = det_data_dir / "Task019FG_ADAM"

    target_label_dir = task_data_dir / "raw_splitted" / "labelsTr"
    splits_file_dir = task_data_dir / "preprocessed"
    splits_file_dir.mkdir(parents=True, exist_ok=True)
    splits_file = splits_file_dir / "splits_final.pkl"

    case_ids = sorted(get_case_ids_from_dir(target_label_dir, remove_modality=False))
    # Purely numeric ids are used as the patient id unchanged. Any other id has
    # its last character stripped, so "<id><suffix>" groups with "<id>".
    # NOTE(review): assumes the suffix is exactly one character — confirm
    # against the dataset's naming scheme.
    case_ids_pat = [c if c.isdigit() else c[:-1] for c in case_ids]
    print(f"Found {len(set(case_ids_pat))} unique patient ids.")

    # Build the id array once; each fold only fancy-indexes into it, which
    # returns an independent copy per split.
    case_ids_arr = np.array(case_ids)
    splits = []
    kfold = GroupKFold(n_splits=5)
    for train_idx, test_idx in kfold.split(case_ids, groups=case_ids_pat):
        split = OrderedDict()
        split['train'] = case_ids_arr[train_idx]
        split['val'] = case_ids_arr[test_idx]
        splits.append(split)
        print(f"Generated split: {split}")
    save_pickle(splits, splits_file)
   

# Generate the splits only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()