Commit b37e01f3 authored by mibaumgartner's avatar mibaumgartner
Browse files

checks prototype

parent 42eeb024
from nndet.io.load import load_json, load_pickle, save_json, save_pickle, npy_dataset, save_yaml
from nndet.io.paths import get_case_id_from_file, get_case_id_from_path, \
get_case_ids_from_dir, get_paths_from_splitted_dir, get_paths_raw_to_split, \
get_task, get_training_dir
from nndet.io.load import (
load_json,
load_pickle,
save_json,
save_pickle,
npy_dataset,
save_yaml,
)
from nndet.io.paths import (
get_case_id_from_file,
get_case_id_from_path,
get_case_ids_from_dir,
get_paths_from_splitted_dir,
get_paths_raw_to_split,
get_task, get_training_dir,
)
from nndet.io.itk import (
load_sitk,
load_sitk_as_array,
)
import functools
import os
import warnings
from pathlib import Path
from typing import List, Sequence, Optional
from nndet.io.paths import get_task
import numpy as np
import SimpleITK as sitk
from nndet.io import load_json, load_sitk
from nndet.io.paths import get_task, get_paths_from_splitted_dir
from nndet.utils.config import load_dataset_info
......@@ -109,5 +115,128 @@ def check_dataset_file(task_name: str):
f"Found {found_mods} but expected {list(range(len(found_mods)))}")
def check_data_and_label_splitted():
    """Placeholder for the raw-splitted data/label checks; not implemented yet."""
    return None
def check_data_and_label_splitted(
    task_name: str,
    test: bool = False,
    labels: bool = True,
    full_check: bool = True,
):
    """
    Perform checks of data and label in raw splitted format

    Args:
        task_name: name of task to check
        test: check test data
        labels: check labels
        full_check: Per default a full check will be performed which needs to
            load all files. If this is disabled, a computationally light check
            will be performed

    Raises:
        ValueError: if not all raw splitted files were found
        ValueError: missing label info file
        ValueError: instances in label info file need to start at 1
        ValueError: instances in label info file need to be consecutive
    """
    cfg = load_dataset_info(get_task(task_name))
    splitted_paths = get_paths_from_splitted_dir(
        num_modalities=len(cfg["modalities"]),
        splitted_4d_output_dir=Path(os.getenv('det_data')) / task_name / "raw_splitted",
        labels=labels,
        test=test,
    )

    for case_paths in splitted_paths:
        # check all files exist
        for cp in case_paths:
            if not Path(cp).is_file():
                raise ValueError(f"Expected {cp} to be a raw splitted "
                                 "data path but it does not exist.")

        if labels:
            # check label info (json files); the label path is the last entry
            mask_path = case_paths[-1]
            mask_info_path = mask_path.parent / f"{mask_path.stem.split('.')[0]}.json"
            if not Path(mask_info_path).is_file():
                raise ValueError(f"Expected {mask_info_path} to be a raw splitted "
                                 "mask info path but it does not exist.")
            mask_info = load_json(mask_info_path)
            mask_info_instances = list(map(int, mask_info["instances"].keys()))
            # a mask without any instances is valid (e.g. a case without lesions);
            # only validate the ID scheme when instances are present
            if mask_info_instances:
                # instance IDs are 1 based (0 is reserved for background)
                if (min_id := min(mask_info_instances)) != 1:
                    raise ValueError(f"Instance IDs need to start at 1, "
                                     f"found {min_id} in {mask_info_path}")
                # instance IDs need to be consecutive: 1 .. N without gaps
                for i in range(1, len(mask_info_instances) + 1):
                    if i not in mask_info_instances:
                        raise ValueError(f"Expected {i} to be an Instance ID in "
                                         f"{mask_info_path} but only found {mask_info_instances}")
        else:
            mask_info_path = None

        if full_check:
            _full_check(case_paths, mask_info_path)
def _full_check(case_paths: List[Path], mask_info_path: Optional[Path] = None) -> None:
    """
    Perform itk and instance checks on provided paths

    Args:
        case_paths: paths to all itk images to check properties
            if label is provided it needs to be at the last position
        mask_info_path: optionally check label properties. If None, no
            check of label properties will be performed.

    Raises:
        ValueError: Inconsistent instances in label info and label image

    See also:
        :func:`_check_itk_params`
    """
    # load all images once and cross check their itk metadata
    img_itk_seq = [load_sitk(cp) for cp in case_paths]
    _check_itk_params(img_itk_seq, case_paths)

    if mask_info_path is not None:
        # by convention the label image is the last entry of case_paths
        mask_itk = img_itk_seq[-1]
        mask_info = load_json(mask_info_path)

        info_instances = list(map(int, mask_info["instances"].keys()))
        # GetArrayViewFromImage avoids copying the whole mask into memory
        mask_instances = np.unique(sitk.GetArrayViewFromImage(mask_itk))
        mask_instances = mask_instances[mask_instances > 0]  # 0 is background

        # every instance found in the mask must be declared in the info file
        for mi in mask_instances:
            if mi not in info_instances:
                raise ValueError(f"Found instance ID {mi} in mask which is "
                                 f"not present in info {info_instances}")
        # and vice versa: the info file must not declare additional instances
        if len(info_instances) != len(mask_instances):
            raise ValueError("Found instances in info which are not present in mask: "
                             f"mask: {mask_instances} info {info_instances}")
def _check_itk_params(img_seq: Sequence[sitk.Image], paths: Sequence[Path]) -> None:
    """
    Check Dimension, Origin, Direction and Spacing of a Sequence of images

    Args:
        img_seq: sequence of images to check
        paths: corresponding paths of images (for error msg)

    Raises:
        ValueError: raised if dimensions do not match
        ValueError: raised if origin does not match
        ValueError: raised if direction does not match
        ValueError: raised if spacing does not match
    """
    reference = img_seq[0]
    # every image is compared against the first one; the properties are
    # checked in a fixed order so the first mismatch determines the error
    checks = (
        ("dimensions", lambda im: im.GetDimension()),
        ("origin", lambda im: im.GetOrigin()),
        ("direction", lambda im: im.GetDirection()),
        ("spacing", lambda im: im.GetSpacing()),
    )
    for pos, candidate in enumerate(img_seq[1:], start=1):
        for prop_name, getter in checks:
            ref_val = np.asarray(getter(reference))
            cand_val = np.asarray(getter(candidate))
            if not (ref_val == cand_val).all():
                raise ValueError(
                    f"Expected {paths[pos]} and {paths[0]} to have same {prop_name}!")
......@@ -15,6 +15,7 @@ limitations under the License.
"""
import argparse
from multiprocessing import Value
import os
import sys
from typing import Any, Mapping, Type, TypeVar
......@@ -29,6 +30,7 @@ from nndet.io import get_task, get_training_dir
from nndet.io.load import load_pickle
from nndet.inference.loading import load_all_models
from nndet.inference.helper import predict_dir
from nndet.utils.check import check_data_and_label_splitted
def run(cfg: dict,
......@@ -164,6 +166,10 @@ def main():
"The 'test' split needs to be located in fold 0 "
"of a manually created split file."),
)
parser.add_argument('--check',
help="Run check of the test data before predicting",
action='store_true',
)
args = parser.parse_args()
model = args.model
......@@ -174,6 +180,7 @@ def main():
ov = args.overwrites
force_args = args.force_args
test_split = args.test_split
check = args.check
task_name = get_task(task, name=True)
task_model_dir = Path(os.getenv("det_models"))
......@@ -196,6 +203,16 @@ def main():
overwrites.append("host.parent_results=${env:det_models}")
cfg.merge_with_dotlist(overwrites)
if check:
if test_split:
raise ValueError("Check is not supported for test split option.")
check_data_and_label_splitted(
task_name=cfg["task"],
test=True,
labels=False,
full_check=True
)
run(OmegaConf.to_container(cfg, resolve=True),
training_dir,
process=process,
......
......@@ -41,6 +41,7 @@ from nndet.io.load import load_pickle, load_npz_looped
from nndet.io.prepare import maybe_split_4d_nifti, instances_from_segmentation
from nndet.io.paths import get_paths_raw_to_split, get_paths_from_splitted_dir, subfiles, get_case_id_from_path
from nndet.preprocessing import ImageCropper
from nndet.utils.check import check_dataset_file, check_data_and_label_splitted
def run_splitting_4d(data_dir: Path, output_dir: Path, num_processes: int) -> None:
......@@ -423,11 +424,36 @@ def main():
parser.add_argument('-o', '--overwrites', type=str, nargs='+',
help="overwrites for config file", default=[],
required=False)
parser.add_argument('--full_check',
help="Run a full check of the data.",
action='store_true',
)
args = parser.parse_args()
tasks = args.tasks
ov = args.overwrites
full_check = args.full_check
initialize_config_module(config_module="nndet.conf")
# perform preprocessing checks first
for task in tasks:
_ov = copy.deepcopy(ov) if ov is not None else []
cfg = compose(task, "config.yaml", overrides=_ov)
check_dataset_file(cfg["task"])
check_data_and_label_splitted(
cfg["task"],
test=False,
labels=True,
full_check=full_check,
)
if cfg["data"]["test_labels"]:
check_data_and_label_splitted(
cfg["task"],
test=True,
labels=True,
full_check=full_check,
)
# start preprocessing
for task in tasks:
_ov = copy.deepcopy(ov) if ov is not None else []
cfg = compose(task, "config.yaml", overrides=_ov)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment