Commit e708c342 authored by Andres Martinez Mora's avatar Andres Martinez Mora
Browse files

Allow processing input instance segmentation labels with float values

parent d26ba9a9
...@@ -34,14 +34,23 @@ from loguru import logger ...@@ -34,14 +34,23 @@ from loguru import logger
from nndet.io.paths import subfiles, Pathlike from nndet.io.paths import subfiles, Pathlike
__all__ = ["load_case_cropped", "load_case_from_list", __all__ = [
"load_properties_of_cropped", "npy_dataset", "load_case_cropped",
"load_pickle", "load_json", "save_json", "save_pickle", "load_case_from_list",
"save_yaml", "load_npz_looped", "load_properties_of_cropped",
] "npy_dataset",
"load_pickle",
"load_json",
def load_case_from_list(data_files, seg_file=None) -> Tuple[np.ndarray, np.ndarray, dict]: "save_json",
"save_pickle",
"save_yaml",
"load_npz_looped",
]
def load_case_from_list(
data_files, seg_file=None
) -> Tuple[np.ndarray, np.ndarray, dict]:
""" """
Load data and label of one case from list of paths Load data and label of one case from list of paths
...@@ -88,7 +97,9 @@ def load_case_from_list(data_files, seg_file=None) -> Tuple[np.ndarray, np.ndarr ...@@ -88,7 +97,9 @@ def load_case_from_list(data_files, seg_file=None) -> Tuple[np.ndarray, np.ndarr
# cast instances to correct type # cast instances to correct type
properties_json["instances"] = { properties_json["instances"] = {
str(key): int(item) for key, item in properties_json["instances"].items()} str(key): int(item)
for key, item in properties_json["instances"].items()
}
properties.update(properties_json) properties.update(properties_json)
else: else:
def load_properties_of_cropped(path: Path) -> dict:
    """
    Load the property file created after cropping was performed
    (files are named after the case id with a .pkl ending)

    Args:
        path (Path): path to file (if the .pkl suffix is missing, it is
            appended automatically)

    Returns:
        dict: loaded properties
    """
    # tolerate paths given without an extension
    if path.suffix != ".pkl":
        path = Path(str(path) + ".pkl")
    with open(path, "rb") as f:
        properties = pickle.load(f)
    return properties
def load_case_cropped(
    folder: Path, case_id: str
) -> Tuple[np.ndarray, np.ndarray, dict]:
    """
    Load a single case after cropping was performed

    Args:
        folder (Path): path to folder where cases are located
        case_id (str): case identifier

    Returns:
        np.ndarray: data
        np.ndarray: segmentation
        dict: additional properties
    """
    # data and segmentation are stacked into one array inside the npz file;
    # load_npz_looped retries because large npz files can sporadically fail
    stack = load_npz_looped(
        os.path.join(folder, case_id) + ".npz",
        keys=["data"],
        num_tries=3,
    )["data"]
    data = stack[:-1]  # all channels except the last are image data
    seg = stack[-1]  # last channel is the (instance) segmentation

    with open(os.path.join(folder, case_id) + ".pkl", "rb") as f:
        props = pickle.load(f)

    assert data.shape[1:] == seg.shape, (
        f"Data and segmentation need to have same dim (except first). "
        f"Found data {data.shape} and "
        f"mask {seg.shape} for case {case_id}"
    )
    # round before casting: the segmentation may be stored as float (shared
    # dtype with the image stack); a plain astype(np.int32) would truncate,
    # turning e.g. 0.999... into 0 — np.rint keeps float-valued labels intact
    return data.astype(np.float32), np.rint(seg).astype(np.int32), props
@contextmanager @contextmanager
def npy_dataset(folder: str, processes: int, def npy_dataset(
unpack: bool = True, delete_npy: bool = True, folder: str,
delete_npz: bool = False): processes: int,
unpack: bool = True,
delete_npy: bool = True,
delete_npz: bool = False,
):
""" """
Automatically unpacks the npz dataset and deletes npy data after completion Automatically unpacks the npz dataset and deletes npy data after completion
...@@ -165,9 +186,7 @@ def npy_dataset(folder: str, processes: int, ...@@ -165,9 +186,7 @@ def npy_dataset(folder: str, processes: int,
del_npy(Path(folder)) del_npy(Path(folder))
def unpack_dataset(folder: Pathlike, def unpack_dataset(folder: Pathlike, processes: int, delete_npz: bool = False):
processes: int,
delete_npz: bool = False):
""" """
unpacks all npz files in a folder to npy unpacks all npz files in a folder to npy
(whatever you want to have unpacked must be saved under key) (whatever you want to have unpacked must be saved under key)
...@@ -181,7 +200,7 @@ def unpack_dataset(folder: Pathlike, ...@@ -181,7 +200,7 @@ def unpack_dataset(folder: Pathlike,
logger.info("Unpacking dataset") logger.info("Unpacking dataset")
npz_files = subfiles(Path(folder), identifier="*.npz", join=True) npz_files = subfiles(Path(folder), identifier="*.npz", join=True)
if not npz_files: if not npz_files:
logger.warning(f'No paths found in {Path(folder)} matching *.npz') logger.warning(f"No paths found in {Path(folder)} matching *.npz")
return return
with Pool(processes) as p: with Pool(processes) as p:
p.starmap(npz2npy, zip(npz_files, repeat(delete_npz))) p.starmap(npz2npy, zip(npz_files, repeat(delete_npz)))
...@@ -255,7 +274,7 @@ def load_json(path: Path, **kwargs) -> Any: ...@@ -255,7 +274,7 @@ def load_json(path: Path, **kwargs) -> Any:
""" """
if isinstance(path, str): if isinstance(path, str):
path = Path(path) path = Path(path)
if not(".json" == path.suffix): if not (".json" == path.suffix):
path = str(path) + ".json" path = str(path) + ".json"
with open(path, "r") as f: with open(path, "r") as f:
...@@ -275,7 +294,7 @@ def save_json(data: Any, path: Pathlike, indent: int = 4, **kwargs): ...@@ -275,7 +294,7 @@ def save_json(data: Any, path: Pathlike, indent: int = 4, **kwargs):
""" """
if isinstance(path, str): if isinstance(path, str):
path = Path(path) path = Path(path)
if not(".json" == path.suffix): if not (".json" == path.suffix):
path = Path(str(path) + ".json") path = Path(str(path) + ".json")
with open(path, "w") as f: with open(path, "w") as f:
...@@ -333,7 +352,7 @@ def save_yaml(data: Any, path: Path, **kwargs): ...@@ -333,7 +352,7 @@ def save_yaml(data: Any, path: Path, **kwargs):
""" """
if isinstance(path, str): if isinstance(path, str):
path = Path(path) path = Path(path)
if not(".yaml" == path.suffix): if not (".yaml" == path.suffix):
path = str(path) + ".yaml" path = str(path) + ".yaml"
with open(path, "w") as f: with open(path, "w") as f:
...@@ -351,7 +370,7 @@ def save_txt(data: str, path: Path, **kwargs): ...@@ -351,7 +370,7 @@ def save_txt(data: str, path: Path, **kwargs):
""" """
if isinstance(path, str): if isinstance(path, str):
path = Path(path) path = Path(path)
if not(".txt" == path.suffix): if not (".txt" == path.suffix):
path = str(path) + ".txt" path = str(path) + ".txt"
with open(path, "a") as f: with open(path, "a") as f:
...@@ -359,12 +378,12 @@ def save_txt(data: str, path: Path, **kwargs): ...@@ -359,12 +378,12 @@ def save_txt(data: str, path: Path, **kwargs):
def load_npz_looped( def load_npz_looped(
p: Pathlike, p: Pathlike,
keys: Sequence[str], keys: Sequence[str],
*args, *args,
num_tries: int = 3, num_tries: int = 3,
**kwargs, **kwargs,
) -> Union[np.ndarray, dict]: ) -> Union[np.ndarray, dict]:
""" """
Try | Except loop to load numpy files Try | Except loop to load numpy files
(especially large numpy files can fail with BadZipFile Errors) (especially large numpy files can fail with BadZipFile Errors)
...@@ -380,7 +399,9 @@ def load_npz_looped( ...@@ -380,7 +399,9 @@ def load_npz_looped(
dict: loaded data dict: loaded data
""" """
if num_tries <= 0: if num_tries <= 0:
raise ValueError(f"Num tires needs to be larger than 0, found {num_tries} tries.") raise ValueError(
f"Num tires needs to be larger than 0, found {num_tries} tries."
)
for i in range(num_tries): # try reading the file 3 times for i in range(num_tries): # try reading the file 3 times
try: try:
...@@ -391,5 +412,5 @@ def load_npz_looped( ...@@ -391,5 +412,5 @@ def load_npz_looped(
if i == num_tries - 1: if i == num_tries - 1:
logger.error(f"Could not unpack {p}") logger.error(f"Could not unpack {p}")
return None return None
time.sleep(5.) time.sleep(5.0)
return data return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment