Commit b37e01f3 authored by mibaumgartner's avatar mibaumgartner
Browse files

checks prototype

parent 42eeb024
from nndet.io.load import load_json, load_pickle, save_json, save_pickle, npy_dataset, save_yaml from nndet.io.load import (
from nndet.io.paths import get_case_id_from_file, get_case_id_from_path, \ load_json,
get_case_ids_from_dir, get_paths_from_splitted_dir, get_paths_raw_to_split, \ load_pickle,
get_task, get_training_dir save_json,
save_pickle,
npy_dataset,
save_yaml,
)
from nndet.io.paths import (
get_case_id_from_file,
get_case_id_from_path,
get_case_ids_from_dir,
get_paths_from_splitted_dir,
get_paths_raw_to_split,
get_task, get_training_dir,
)
from nndet.io.itk import (
load_sitk,
load_sitk_as_array,
)
import functools import functools
import os import os
import warnings import warnings
from pathlib import Path
from typing import List, Sequence, Optional
from nndet.io.paths import get_task import numpy as np
import SimpleITK as sitk
from nndet.io import load_json, load_sitk
from nndet.io.paths import get_task, get_paths_from_splitted_dir
from nndet.utils.config import load_dataset_info from nndet.utils.config import load_dataset_info
...@@ -109,5 +115,128 @@ def check_dataset_file(task_name: str): ...@@ -109,5 +115,128 @@ def check_dataset_file(task_name: str):
f"Found {found_mods} but expected {list(range(len(found_mods)))}") f"Found {found_mods} but expected {list(range(len(found_mods)))}")
def check_data_and_label_splitted(
        task_name: str,
        test: bool = False,
        labels: bool = True,
        full_check: bool = True,
        ):
    """
    Perform checks of data and labels in raw splitted format

    Args:
        task_name: name of task to check
        test: check test data instead of training data
        labels: check label files (mask + json info)
        full_check: per default a full check will be performed which needs
            to load all files. If this is disabled, a computationally
            light check will be performed

    Raises:
        ValueError: if not all raw splitted files were found
        ValueError: missing label info file
        ValueError: instances in label info file need to start at 1
        ValueError: instances in label info file need to be consecutive
    """
    cfg = load_dataset_info(get_task(task_name))

    splitted_paths = get_paths_from_splitted_dir(
        num_modalities=len(cfg["modalities"]),
        splitted_4d_output_dir=Path(os.getenv('det_data')) / task_name / "raw_splitted",
        labels=labels,
        test=test,
    )

    for case_paths in splitted_paths:
        # check all files exist
        for cp in case_paths:
            if not Path(cp).is_file():
                raise ValueError(f"Expected {cp} to be a raw splitted "
                                 "data path but it does not exist.")

        if labels:
            # check label info (json files); the mask is always the last path
            mask_path = case_paths[-1]
            mask_info_path = mask_path.parent / f"{mask_path.stem.split('.')[0]}.json"
            if not Path(mask_info_path).is_file():
                raise ValueError(f"Expected {mask_info_path} to be a raw splitted "
                                 "mask info path but it does not exist.")

            mask_info = load_json(mask_info_path)
            mask_info_instances = list(map(int, mask_info["instances"].keys()))
            # bug fix: the original walrus expression bound the *boolean*
            # result of the comparison, so the message reported `True`
            # instead of the offending minimum instance id
            if mask_info_instances and (min_id := min(mask_info_instances)) != 1:
                raise ValueError(f"Instance IDs need to start at 1, found {min_id} in {mask_info_path}")
            # bug fix: instance ids start at 1, so consecutiveness must be
            # checked over 1..N; the original `range(len(...))` included 0
            # and thus raised for every non-empty instance list
            for i in range(1, len(mask_info_instances) + 1):
                if i not in mask_info_instances:
                    raise ValueError(f"Expected {i} to be an Instance ID in "
                                     f"{mask_info_path} but only found {mask_info_instances}")
        else:
            mask_info_path = None

        if full_check:
            # heavy check: loads every image and verifies itk metadata
            # and instance consistency
            _full_check(case_paths, mask_info_path)
def _full_check(case_paths: List[Path], mask_info_path: Optional[Path] = None) -> None:
    """
    Perform itk and instance checks on the provided paths

    Args:
        case_paths: paths to all itk images to check properties;
            if a label is provided it needs to be at the last position
        mask_info_path: optionally check label properties. If None, no
            check of label properties will be performed.

    Raises:
        ValueError: Inconsistent instances in label info and label image

    See also:
        :func:`_check_itk_params`
    """
    img_itk_seq = [load_sitk(cp) for cp in case_paths]
    _check_itk_params(img_itk_seq, case_paths)

    if mask_info_path is not None:
        mask_itk = img_itk_seq[-1]

        mask_info = load_json(mask_info_path)
        info_instances = list(map(int, mask_info["instances"].keys()))
        # background (0) is not an instance; only foreground ids are compared
        mask_instances = np.unique(sitk.GetArrayViewFromImage(mask_itk))
        mask_instances = mask_instances[mask_instances > 0]

        for mi in mask_instances:
            if mi not in info_instances:
                raise ValueError(f"Found instance ID {mi} in mask which is "
                                 f"not present in info {info_instances}")

        # equal lengths + the containment check above imply the sets match
        if len(info_instances) != len(mask_instances):
            raise ValueError("Found instances in info which are not present in mask: "
                             f"mask: {mask_instances} info {info_instances}")
def _check_itk_params(img_seq: Sequence[sitk.Image], paths: Sequence[Path]) -> None:
"""
Check Dimension, Origin, Direction and Spacing of a Sequence of images
Args:
img_seq: sequence of images to check
paths: correcponding paths of images (for error msg)
Raises:
ValueError: raised if dimensions do not match
ValueError: raised if origin does not match
ValueError: raised if direction does not match
ValueError: raised if spacing does not match
"""
for idx, img in enumerate(img_seq[1:], start=1):
if not (np.asarray(img_seq[0].GetDimension()) == \
np.asarray(img.GetDimension())).all():
raise ValueError(f"Expected {paths[idx]} and {paths[0]} to have same dimensions!")
if not (np.asarray(img_seq[0].GetOrigin()) == \
np.asarray(img.GetOrigin())).all():
raise ValueError(f"Expected {paths[idx]} and {paths[0]} to have same origin!")
if not (np.asarray(img_seq[0].GetDirection()) == \
np.asarray(img.GetDirection())).all():
raise ValueError(f"Expected {paths[idx]} and {paths[0]} to have same direction!")
if not (np.asarray(img_seq[0].GetSpacing()) == \
np.asarray(img.GetSpacing())).all():
raise ValueError(f"Expected {paths[idx]} and {paths[0]} to have same spacing!")
...@@ -15,6 +15,7 @@ limitations under the License. ...@@ -15,6 +15,7 @@ limitations under the License.
""" """
import argparse import argparse
from multiprocessing import Value
import os import os
import sys import sys
from typing import Any, Mapping, Type, TypeVar from typing import Any, Mapping, Type, TypeVar
...@@ -29,6 +30,7 @@ from nndet.io import get_task, get_training_dir ...@@ -29,6 +30,7 @@ from nndet.io import get_task, get_training_dir
from nndet.io.load import load_pickle from nndet.io.load import load_pickle
from nndet.inference.loading import load_all_models from nndet.inference.loading import load_all_models
from nndet.inference.helper import predict_dir from nndet.inference.helper import predict_dir
from nndet.utils.check import check_data_and_label_splitted
def run(cfg: dict, def run(cfg: dict,
...@@ -164,6 +166,10 @@ def main(): ...@@ -164,6 +166,10 @@ def main():
"The 'test' split needs to be located in fold 0 " "The 'test' split needs to be located in fold 0 "
"of a manually created split file."), "of a manually created split file."),
) )
parser.add_argument('--check',
help="Run check of the test data before predicting",
action='store_true',
)
args = parser.parse_args() args = parser.parse_args()
model = args.model model = args.model
...@@ -174,6 +180,7 @@ def main(): ...@@ -174,6 +180,7 @@ def main():
ov = args.overwrites ov = args.overwrites
force_args = args.force_args force_args = args.force_args
test_split = args.test_split test_split = args.test_split
check = args.check
task_name = get_task(task, name=True) task_name = get_task(task, name=True)
task_model_dir = Path(os.getenv("det_models")) task_model_dir = Path(os.getenv("det_models"))
...@@ -196,6 +203,16 @@ def main(): ...@@ -196,6 +203,16 @@ def main():
overwrites.append("host.parent_results=${env:det_models}") overwrites.append("host.parent_results=${env:det_models}")
cfg.merge_with_dotlist(overwrites) cfg.merge_with_dotlist(overwrites)
if check:
if test_split:
raise ValueError("Check is not supported for test split option.")
check_data_and_label_splitted(
task_name=cfg["task"],
test=True,
labels=False,
full_check=True
)
run(OmegaConf.to_container(cfg, resolve=True), run(OmegaConf.to_container(cfg, resolve=True),
training_dir, training_dir,
process=process, process=process,
......
...@@ -41,6 +41,7 @@ from nndet.io.load import load_pickle, load_npz_looped ...@@ -41,6 +41,7 @@ from nndet.io.load import load_pickle, load_npz_looped
from nndet.io.prepare import maybe_split_4d_nifti, instances_from_segmentation from nndet.io.prepare import maybe_split_4d_nifti, instances_from_segmentation
from nndet.io.paths import get_paths_raw_to_split, get_paths_from_splitted_dir, subfiles, get_case_id_from_path from nndet.io.paths import get_paths_raw_to_split, get_paths_from_splitted_dir, subfiles, get_case_id_from_path
from nndet.preprocessing import ImageCropper from nndet.preprocessing import ImageCropper
from nndet.utils.check import check_dataset_file, check_data_and_label_splitted
def run_splitting_4d(data_dir: Path, output_dir: Path, num_processes: int) -> None: def run_splitting_4d(data_dir: Path, output_dir: Path, num_processes: int) -> None:
...@@ -423,11 +424,36 @@ def main(): ...@@ -423,11 +424,36 @@ def main():
parser.add_argument('-o', '--overwrites', type=str, nargs='+', parser.add_argument('-o', '--overwrites', type=str, nargs='+',
help="overwrites for config file", default=[], help="overwrites for config file", default=[],
required=False) required=False)
parser.add_argument('--full_check',
help="Run a full check of the data.",
action='store_true',
)
args = parser.parse_args() args = parser.parse_args()
tasks = args.tasks tasks = args.tasks
ov = args.overwrites ov = args.overwrites
full_check = args.full_check
initialize_config_module(config_module="nndet.conf") initialize_config_module(config_module="nndet.conf")
# perform preprocessing checks first
for task in tasks:
_ov = copy.deepcopy(ov) if ov is not None else []
cfg = compose(task, "config.yaml", overrides=_ov)
check_dataset_file(cfg["task"])
check_data_and_label_splitted(
cfg["task"],
test=False,
labels=True,
full_check=full_check,
)
if cfg["data"]["test_labels"]:
check_data_and_label_splitted(
cfg["task"],
test=True,
labels=True,
full_check=full_check,
)
# start preprocessing
for task in tasks: for task in tasks:
_ov = copy.deepcopy(ov) if ov is not None else [] _ov = copy.deepcopy(ov) if ov is not None else []
cfg = compose(task, "config.yaml", overrides=_ov) cfg = compose(task, "config.yaml", overrides=_ov)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment