"vscode:/vscode.git/clone" did not exist on "89f45180e02cf04b2b044ffc19f5ca8f599fb438"
Commit 5d61a79b authored by mibaumgartner's avatar mibaumgartner
Browse files

utils

parent 44da6e3e
from nndet.utils.tensor import (
make_onehot_batch, to_dtype, to_device, to_numpy, to_tensor, cat,
)
from nndet.utils.info import (
maybe_verbose_iterable, find_name, log_experiment, log_plan, log_git,
get_cls_name, log_error, file_logger,
)
from nndet.utils.timer import Timer
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
This is prototype code ... Use at your own risk
This was initially part of a notebook but I needed to move it into
this scriptish functions to run it in my default pipeline
"""
import pickle
from itertools import product
from pathlib import Path
from typing import Sequence, Optional, Tuple
from collections import defaultdict
from loguru import logger
import numpy as np
import pandas as pd
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
import matplotlib.pyplot as plt
plt.style.use('seaborn-deep')
from sklearn.metrics import confusion_matrix
from torch import Tensor
import SimpleITK as sitk
from nndet.detection.boxes import box_iou_np, box_size_np
from nndet.io.load import load_pickle, save_json
from nndet.utils.info import maybe_verbose_iterable
def collect_overview(prediction_dir: Path, gt_dir: Path,
                     iou: float, score: float,
                     max_num_fp_per_image: int = 5,
                     top_n: int = 10,
                     ):
    """
    Build a per-case overview of TP/FP/FN statistics for all predictions.

    Args:
        prediction_dir: directory containing `*_boxes.pkl` predictions
        gt_dir: directory containing `{case_id}_boxes_gt.npz` ground truth
        iou: IoU threshold used to match predictions to ground truth
        score: minimum prediction score; lower scoring boxes are discarded
        max_num_fp_per_image: maximum number of false positives kept per case
        top_n: number of cases to report in `analysis_ids`

    Returns:
        pd.DataFrame: per-case statistics (one row per case id)
        dict: case ids of the worst offenders
            `top_scoring_fp`, `top_num_fp`, `top_num_fn`
    """
    results = defaultdict(dict)
    for f in prediction_dir.glob("*_boxes.pkl"):
        case_id = f.stem.rsplit('_', 1)[0]
        gt_data = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)
        gt_boxes = gt_data["boxes"]
        gt_classes = gt_data["classes"]

        case_result = load_pickle(f)
        pred_boxes = case_result["pred_boxes"]
        pred_scores = case_result["pred_scores"]
        pred_labels = case_result["pred_labels"]

        # discard low scoring predictions before matching
        keep = pred_scores > score
        pred_boxes, pred_scores, pred_labels = pred_boxes[keep], pred_scores[keep], pred_labels[keep]

        results[case_id]["num_gt"] = len(gt_classes)

        if gt_boxes.size == 0:
            # no ground truth: every prediction is a false positive.
            # bugfix: use max_num_fp_per_image (was hard-coded 5) and slice
            # fp_true_label/fp_type with idx so all fp_* entries have the
            # same length (previously they used the full prediction count)
            idx = np.argsort(pred_scores)[::-1][:max_num_fp_per_image]
            results[case_id]["fp_score"] = pred_scores[idx]
            results[case_id]["fp_label"] = pred_labels[idx]
            results[case_id]["fp_true_label"] = np.full(len(idx), -1.)
            results[case_id]["fp_type"] = ["fp_iou"] * len(idx)
            results[case_id]["num_fn"] = 0
        elif pred_boxes.size == 0:
            # no predictions: every ground truth box is missed
            results[case_id]["num_fn"] = len(gt_classes)
            results[case_id]["fn_boxes"] = gt_boxes
        else:
            # match each prediction to the ground truth box with highest IoU
            match_quality_matrix = box_iou_np(gt_boxes, pred_boxes)
            matched_idxs = np.argmax(match_quality_matrix, axis=0)
            matched_vals = np.max(match_quality_matrix, axis=0)
            matched_idxs[matched_vals < iou] = -1  # below threshold -> unmatched
            target_labels = gt_classes[matched_idxs.clip(min=0)]
            target_labels[matched_idxs == -1] = -1

            # True positive analysis: split matched predictions into
            # high (>0.5) and low (<0.5) scoring groups
            tp_keep = target_labels == pred_labels
            tp_boxes, tp_scores, tp_labels = pred_boxes[tp_keep], pred_scores[tp_keep], pred_labels[tp_keep]
            keep_high = tp_scores > 0.5
            tp_high_boxes, tp_high_scores, tp_high_labels = \
                tp_boxes[keep_high], tp_scores[keep_high], tp_labels[keep_high]
            keep_low = tp_scores < 0.5
            tp_low_boxes, tp_low_scores, tp_low_labels = \
                tp_boxes[keep_low], tp_scores[keep_low], tp_labels[keep_low]
            high_idx = np.argsort(tp_high_scores)[::-1][:3]
            low_idx = np.argsort(tp_low_scores)[:3]

            results[case_id]["iou_tp"] = int(tp_keep.sum())
            results[case_id]["tp_high_boxes"] = tp_high_boxes[high_idx]
            results[case_id]["tp_high_score"] = tp_high_scores[high_idx]
            results[case_id]["tp_high_label"] = tp_high_labels[high_idx]
            results[case_id]["tp_iou"] = matched_vals[tp_keep]
            if tp_low_boxes.size > 0:
                results[case_id]["tp_low_boxes"] = tp_low_boxes[low_idx]
                results[case_id]["tp_low_score"] = tp_low_scores[low_idx]
                results[case_id]["tp_low_label"] = tp_low_labels[low_idx]

            # False Positive Analysis
            fp_keep = (pred_labels != target_labels) * (pred_labels != -1)
            fp_boxes, fp_scores, fp_labels, fp_target_labels = \
                pred_boxes[fp_keep], pred_scores[fp_keep], pred_labels[fp_keep], target_labels[fp_keep]
            idx = np.argsort(fp_scores)[::-1][:max_num_fp_per_image]
            results[case_id]["fp_score"] = fp_scores[idx]
            results[case_id]["fp_label"] = fp_labels[idx]
            results[case_id]["fp_true_label"] = fp_target_labels[idx]
            # bugfix: slice by idx so fp_type matches the other fp_* entries
            # (fp_iou = no GT box above IoU threshold, fp_cls = wrong class)
            results[case_id]["fp_type"] = ["fp_iou" if tl == -1 else "fp_cls"
                                           for tl in fp_target_labels[idx]]

            # Misc: ground truth boxes without any matching prediction
            unmatched_gt = (match_quality_matrix.max(axis=1) < iou)
            false_negatives = unmatched_gt.sum()
            results[case_id]["fn_boxes"] = gt_boxes[unmatched_gt]
            results[case_id]["num_fn"] = false_negatives

    df = pd.DataFrame.from_dict(results, orient='index')
    df = df.sort_index()

    analysis_ids = {}
    if "fp_score" in list(df.columns):
        tmp = df["fp_score"].apply(lambda x: np.max(x) if np.any(x) else 0).nlargest(top_n)
        analysis_ids["top_scoring_fp"] = tmp.index.values.tolist()
        tmp = df["fp_score"].apply(
            lambda x: len(x) if isinstance(x, Sequence) or isinstance(x, np.ndarray) else 0).nlargest(top_n)
        analysis_ids["top_num_fp"] = tmp.index.values.tolist()
    # bugfix: this guard previously re-checked "fp_score" instead of "num_fn"
    if "num_fn" in list(df.columns):
        tmp = df["num_fn"].nlargest(top_n)
        analysis_ids["top_num_fn"] = tmp.index.values.tolist()
    return df, analysis_ids
def collect_score_iou(prediction_dir: Path, gt_dir: Path, iou: float, score: float):
    """
    Collect predicted and target labels together with matched IoU and
    predicted score for every box over all cases.

    Args:
        prediction_dir: directory containing `*_boxes.pkl` predictions
        gt_dir: directory containing `{case_id}_boxes_gt.npz` ground truth
        iou: IoU threshold used to match predictions to ground truth
        score: minimum prediction score; lower scoring boxes are discarded

    Returns:
        list: predicted labels per box (-1 encodes missed ground truth)
        list: target labels per box (-1 encodes unmatched predictions)
        list: IoU with the best matching ground truth box per prediction
        list: predicted score per prediction
    """
    all_pred = []
    all_target = []
    all_pred_ious = []
    all_pred_scores = []
    for f in prediction_dir.glob("*_boxes.pkl"):
        case_id = f.stem.rsplit('_', 1)[0]
        gt_data = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)
        gt_boxes = gt_data["boxes"]
        gt_classes = gt_data["classes"]
        # NOTE(review): gt_ignore is computed but never used in this function
        gt_ignore = [np.zeros(gt_boxes_img.shape[0]).reshape(-1, 1) for gt_boxes_img in [gt_boxes]]
        case_result = load_pickle(f)
        pred_boxes = case_result["pred_boxes"]
        pred_scores = case_result["pred_scores"]
        pred_labels = case_result["pred_labels"]
        # discard low scoring predictions before matching
        keep = pred_scores > score
        pred_boxes, pred_scores, pred_labels = pred_boxes[keep], pred_scores[keep], pred_labels[keep]
        # computation starts here
        if gt_boxes.size == 0:
            # no ground truth: every prediction is a false positive (target -1)
            all_pred.append(pred_labels)
            all_target.append(np.ones(len(pred_labels)) * -1)
            all_pred_ious.append(np.zeros(len(pred_labels)))
            all_pred_scores.append(pred_scores)
        elif pred_boxes.size == 0:
            # no predictions: every ground truth box is missed
            # NOTE(review): nothing is appended to all_pred_ious /
            # all_pred_scores here, so their flattened lengths can diverge
            # from all_pred / all_target — verify downstream consumers
            all_pred.append(np.ones(len(gt_classes)) * -1)
            all_target.append(gt_classes)
        else:
            # match each prediction to the ground truth box with highest IoU
            match_quality_matrix = box_iou_np(gt_boxes, pred_boxes)
            matched_idxs = np.argmax(match_quality_matrix, axis=0)
            matched_vals = np.max(match_quality_matrix, axis=0)
            matched_idxs[matched_vals < iou] = -1  # below threshold -> unmatched
            matched_gt_boxes_per_image = gt_boxes[matched_idxs.clip(min=0)]
            target_labels = gt_classes[matched_idxs.clip(min=0)]
            target_labels[matched_idxs == -1] = -1
            all_pred.append(pred_labels)
            all_target.append(target_labels)
            all_pred_ious.append(matched_vals)
            all_pred_scores.append(pred_scores)
            # ground truth boxes without any prediction above the threshold
            false_negatives = (match_quality_matrix.max(axis=1) < iou).sum()
            if false_negatives > 0:  # false negatives
                all_pred.append(np.ones(false_negatives) * -1)
                # NOTE(review): missed ground truth is recorded with class 0
                # instead of its true class — fine for single-class data,
                # confirm for multi-class setups
                all_target.append(np.zeros(false_negatives))
    return all_pred, all_target, all_pred_ious, all_pred_scores
def plot_confusion_matrix(all_pred, all_target, iou: float, score: float):
    """
    Draw a confusion matrix of target vs predicted labels as a heatmap.

    Args:
        all_pred: per-case arrays of predicted labels
        all_target: per-case arrays of target labels
        iou: IoU threshold (only used in the title)
        score: score threshold (only used in the title)

    Returns:
        the heatmap axes, or None if there is nothing to plot
    """
    # guard clause: nothing to plot without predictions or targets
    if not (len(all_pred) and len(all_target)):
        return None
    cm = confusion_matrix(np.concatenate(all_target), np.concatenate(all_pred))
    plt.figure()
    ax = sns.heatmap(cm, annot=True, cbar=False)
    ax.set_xlabel("Prediction")
    ax.set_ylabel("Ground Truth")
    ax.set_title(f"Confusion Matrix IoU {iou} and Score Threshold {score}")
    return ax
def plot_joint_iou_score(all_pred_ious, all_pred_scores):
    """
    Joint regression plot of predicted score over matched IoU, with a green
    identity line and red vertical markers at IoU 0.1 and 0.5.

    Args:
        all_pred_ious: IoU values (sequence of arrays or a single array)
        all_pred_scores: predicted scores (sequence of arrays or a single array)

    Returns:
        the seaborn JointGrid, or None if either input is an empty sequence
    """
    if isinstance(all_pred_ious, Sequence):
        if not all_pred_ious:
            return None
        all_pred_ious = np.concatenate(all_pred_ious)
    if isinstance(all_pred_scores, Sequence):
        if not all_pred_scores:
            return None
        all_pred_scores = np.concatenate(all_pred_scores)

    plt.figure()
    grid = sns.jointplot(x=all_pred_ious, y=all_pred_scores,
                         xlim=(-0.01, 1.01), ylim=(-0.01, 1.01), marginal_kws={"bins": 10},
                         kind='reg', scatter=True,
                         )
    plt.plot([0, 1], [0, 1], 'g')  # identity: score == IoU
    grid.set_axis_labels("IoU", "Predicted Score")
    grid.ax_joint.axvline(x=0.1, c='r')
    grid.ax_joint.axvline(x=0.5, c='r')
    grid.fig.subplots_adjust(top=0.9)
    grid.fig.suptitle('Class independent predicted score over IoU plot', fontsize=16)
    return grid
def collect_boxes(prediction_dir: Path, gt_dir: Path, iou: float, score: float):
    """
    Collect predicted/target labels and the corresponding boxes over all cases.

    Args:
        prediction_dir: directory containing `*_boxes.pkl` predictions
        gt_dir: directory containing `{case_id}_boxes_gt.npz` ground truth
        iou: IoU threshold used to match predictions to ground truth
        score: minimum prediction score; lower scoring boxes are discarded

    Returns:
        list: predicted labels per box (-1 encodes missed ground truth)
        list: target labels per box (-1 encodes unmatched predictions)
        list: box coordinates (prediction boxes for matched/fp entries,
            ground truth boxes for missed entries)
    """
    # cleanup: removed unused locals `i`, `gt_ignore` and
    # `matched_gt_boxes_per_image` from the original implementation
    all_pred = []
    all_target = []
    all_boxes = []
    for f in prediction_dir.glob("*_boxes.pkl"):
        case_id = f.stem.rsplit('_', 1)[0]
        gt_data = np.load(str(gt_dir / f"{case_id}_boxes_gt.npz"), allow_pickle=True)
        gt_boxes = gt_data["boxes"]
        gt_classes = gt_data["classes"]

        case_result = load_pickle(f)
        pred_boxes = case_result["pred_boxes"]
        pred_scores = case_result["pred_scores"]
        pred_labels = case_result["pred_labels"]

        # discard low scoring predictions before matching
        keep = pred_scores > score
        pred_boxes, pred_scores, pred_labels = pred_boxes[keep], pred_scores[keep], pred_labels[keep]

        if gt_boxes.size == 0:
            # no ground truth: every prediction is a false positive
            all_pred.append(pred_labels)
            all_target.append(np.ones(len(pred_labels)) * -1)
            all_boxes.append(pred_boxes)
        elif pred_boxes.size == 0:
            # no predictions: every ground truth box is missed
            all_pred.append(np.ones(len(gt_classes)) * -1)
            all_target.append(gt_classes)
            all_boxes.append(gt_boxes)
        else:
            # match each prediction to the ground truth box with highest IoU
            match_quality_matrix = box_iou_np(gt_boxes, pred_boxes)
            matched_idxs = np.argmax(match_quality_matrix, axis=0)
            matched_vals = np.max(match_quality_matrix, axis=0)
            matched_idxs[matched_vals < iou] = -1  # below threshold -> unmatched
            target_labels = gt_classes[matched_idxs.clip(min=0)]
            target_labels[matched_idxs == -1] = -1

            all_pred.append(pred_labels)
            all_target.append(target_labels)
            all_boxes.append(pred_boxes)

            # ground truth boxes without any prediction above the threshold
            unmatched_gt = (match_quality_matrix.max(axis=1) < iou)
            false_negatives = unmatched_gt.sum()
            if false_negatives > 0:  # false negatives
                all_pred.append(np.ones(false_negatives) * -1)
                # NOTE(review): missed ground truth is recorded with class 0
                # instead of its true class — fine for single-class data,
                # confirm for multi-class setups
                all_target.append(np.zeros(false_negatives))
                all_boxes.append(gt_boxes[np.nonzero(unmatched_gt)[0]])
    return all_pred, all_target, all_boxes
def plot_sizes(all_pred, all_target, all_boxes, iou, score):
    """
    3D scatter plot of box sizes, colored by tp (green), fp (red), fn (blue).

    Args:
        all_pred: per-entry arrays of predicted labels (-1 = missed GT)
        all_target: per-entry arrays of target labels (-1 = unmatched pred)
        all_boxes: per-entry arrays of box coordinates
        iou: IoU threshold (only used in the title)
        score: score threshold (only used in the title)

    Returns:
        figure and axes, or (None, None) if there is nothing to plot
    """
    if not (len(all_pred) and len(all_target)):
        return None, None
    preds = np.concatenate(all_pred)
    targets = np.concatenate(all_target)
    boxes = np.concatenate([entry for entry in all_boxes if entry.size > 0])
    # NOTE(review): assumes 3D boxes (three size components per box)
    sizes = box_size_np(boxes)

    masks = {
        "tp": preds == targets,
        "fp": (preds != targets) * (preds != -1),
        "fn": (preds != targets) * (preds == -1),
    }
    styles = {"tp": ("g", "o"), "fp": ("r", "x"), "fn": ("b", "^")}

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for key, mask in masks.items():
        color, marker = styles[key]
        ax.scatter(sizes[mask, 0], sizes[mask, 1], sizes[mask, 2],
                   c=color, marker=marker, label=key)
    ax.set_title(
        f"IoU {iou} and Score Threshold {score}: "
        f"tp {sum(masks['tp'])} fp {sum(masks['fp'])} fn {sum(masks['fn'])}")
    ax.set_xlabel('bounding box size axis 0')
    ax.set_ylabel('bounding box size axis 1')
    ax.set_zlabel('bounding box size axis 2')
    ax.legend()
    return fig, ax
def plot_sizes_bar(all_pred, all_target, all_boxes, iou, score,
                   max_bin: Optional[int] = None,
                   ):
    """
    Histogram of summed box extents, split into tp/fp/fn groups.

    Args:
        all_pred: per-entry arrays of predicted labels (-1 = missed GT)
        all_target: per-entry arrays of target labels (-1 = unmatched pred)
        all_boxes: per-entry arrays of box coordinates
        iou: IoU threshold (only used in the title)
        score: score threshold (only used in the title)
        max_bin: upper limit of the histogram range; None uses the data range

    Returns:
        figure and axes, or (None, None) if there is nothing to plot
    """
    # cleanup: removed dead commented-out plt.hist variant
    if len(all_pred) == 0 or len(all_target) == 0:
        return None, None
    _all_pred = np.concatenate(all_pred)
    _all_target = np.concatenate(all_target)
    _all_boxes = np.concatenate([ab for ab in all_boxes if ab.size > 0])
    # NOTE(review): assumes 3D boxes (three size components per box)
    dists = box_size_np(_all_boxes)

    tp_mask = _all_pred == _all_target
    fp_mask = (_all_pred != _all_target) * (_all_pred != -1)
    fn_mask = (_all_pred != _all_target) * (_all_pred == -1)

    fig = plt.figure()
    # sum of extents per box serves as a single scalar size measure
    data = {
        "tp": dists[tp_mask, 0] + dists[tp_mask, 1] + dists[tp_mask, 2],
        "fp": dists[fp_mask, 0] + dists[fp_mask, 1] + dists[fp_mask, 2],
        "fn": dists[fn_mask, 0] + dists[fn_mask, 1] + dists[fn_mask, 2],
    }
    kwargs = {}
    if max_bin is not None:
        kwargs["binrange"] = [0, max_bin]
    ax = sns.histplot(data=data, bins=100, element="step",
                      palette={"tp": "g", "fp": "r", "fn": "b"},
                      legend=True, fill=False, **kwargs
                      )
    ax.set_title(
        f"IoU {iou} and Score Threshold {score}: tp {sum(tp_mask)} fp {sum(fp_mask)} fn {sum(fn_mask)}")
    ax.set_xlabel("box width + height ( + depth)")
    ax.set_ylabel("Count")
    return fig, ax
def run_analysis_suite(prediction_dir: Path, gt_dir: Path, save_dir: Path):
    """
    Run the full analysis (overview tables, confusion matrix, IoU/score plot,
    size plots) for every combination of IoU and score thresholds and save
    results to `save_dir / iou_{iou}_score_{score}`.

    Args:
        prediction_dir: directory containing `*_boxes.pkl` predictions
        gt_dir: directory containing `{case_id}_boxes_gt.npz` ground truth
        save_dir: base directory for all outputs
    """
    # bugfix/perf: the prediction glob is loop-invariant; previously it was
    # recomputed and logged once per threshold combination
    found_predictions = list(prediction_dir.glob("*_boxes.pkl"))
    logger.info(f"Found {len(found_predictions)} predictions for analysis")

    for iou, score in maybe_verbose_iterable(list(product([0.1, 0.5], [0.1, 0.5]))):
        _save_dir = save_dir / f"iou_{iou}_score_{score}"
        _save_dir.mkdir(parents=True, exist_ok=True)

        df, analysis_ids = collect_overview(prediction_dir, gt_dir,
                                            iou=iou, score=score,
                                            max_num_fp_per_image=5,
                                            top_n=10,
                                            )
        df.to_json(_save_dir / "analysis.json", indent=4, orient='index')
        df.to_csv(_save_dir / "analysis.csv")
        save_json(analysis_ids, _save_dir / "analysis_ids.json")

        all_pred, all_target, all_pred_ious, all_pred_scores = collect_score_iou(
            prediction_dir, gt_dir, iou=iou, score=score)
        plot_confusion_matrix(all_pred, all_target, iou=iou, score=score)
        plt.savefig(_save_dir / "confusion_matrix.png")
        plt.close()
        plot_joint_iou_score(all_pred_ious, all_pred_scores)
        plt.savefig(_save_dir / "joint_iou_score.png")
        plt.close()

        all_pred, all_target, all_boxes = collect_boxes(
            prediction_dir, gt_dir, iou=iou, score=score)
        sizes_fig, _ = plot_sizes(all_pred, all_target, all_boxes, iou=iou, score=score)
        plt.savefig(_save_dir / "sizes.png")
        # pickle the figure so it can be reopened interactively later
        with open(str(_save_dir / 'sizes.pkl'), "wb") as fp:
            pickle.dump(sizes_fig, fp, protocol=4)
        plt.close()

        sizes_fig, _ = plot_sizes_bar(all_pred, all_target, all_boxes, iou=iou, score=score)
        plt.savefig(_save_dir / "sizes_bar.png")
        with open(str(_save_dir / 'sizes_bar.pkl'), "wb") as fp:
            pickle.dump(sizes_fig, fp, protocol=4)
        plt.close()

        sizes_fig, _ = plot_sizes_bar(all_pred, all_target, all_boxes,
                                      iou=iou, score=score, max_bin=100)
        plt.savefig(_save_dir / "sizes_bar_100.png")
        with open(str(_save_dir / 'sizes_bar_100.pkl'), "wb") as fp:
            pickle.dump(sizes_fig, fp, protocol=4)
        plt.close()
def convert_box_to_nii_meta(pred_boxes: Tensor,
                            pred_scores: Tensor,
                            pred_labels: Tensor,
                            props: dict,
                            ) -> Tuple[sitk.Image, dict]:
    """
    Rasterize predicted bounding boxes into an ITK instance mask and collect
    score/label meta data per instance.

    Args:
        pred_boxes: predicted boxes encoded as (x1, y1, x2, y2[, z1, z2])
        pred_scores: predicted score per box
        pred_labels: predicted label per box
        props: case properties; requires `original_size_of_raw_data`,
            `itk_origin`, `itk_direction` and `itk_spacing`

    Returns:
        sitk.Image: instance mask; instance ids start at 1
        dict: instance id -> {"score": float, "label": int}
    """
    # uint8 limits the mask to 255 distinct instance ids
    instance_mask = np.zeros(props["original_size_of_raw_data"], dtype=np.uint8)
    for instance_id, pbox in enumerate(pred_boxes, start=1):
        # axis 0 spans box coords (0, 2); axis 1 spans (1, 3)
        mask_slicing = [slice(int(pbox[0]), int(pbox[2])),
                        slice(int(pbox[1]), int(pbox[3])),
                        ]
        if instance_mask.ndim == 3:
            mask_slicing.append(slice(int(pbox[4]), int(pbox[5])))
        # NOTE(review): overlapping boxes overwrite earlier instance ids
        instance_mask[tuple(mask_slicing)] = instance_id
    logger.info(f"Created instance mask with {instance_mask.max()} instances.")

    instance_mask_itk = sitk.GetImageFromArray(instance_mask)
    instance_mask_itk.SetOrigin(props["itk_origin"])
    instance_mask_itk.SetDirection(props["itk_direction"])
    instance_mask_itk.SetSpacing(props["itk_spacing"])

    # meta entries use the same 1-based instance ids as the mask
    prediction_meta = {idx: {"score": float(score), "label": int(label)}
                       for idx, (score, label) in enumerate(zip(pred_scores, pred_labels), start=1)}
    return instance_mask_itk, prediction_meta
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from scipy.ndimage import label
from typing import Dict, Sequence, Union, Tuple, Optional
from nndet.io.transforms.instances import get_bbox_np
def seg2instances(seg: np.ndarray,
                  exclude_background: bool = True,
                  min_num_voxel: int = 0,
                  ) -> Tuple[np.ndarray, Dict[int, int]]:
    """
    Use connected components with ones matrix to created instance from segmentation

    Args:
        seg: semantic segmentation [spatial dims]
        exclude_background: skips background class for the mapping
            from instances to classes
        min_num_voxel: minimum number of voxels of an instance

    Returns:
        np.ndarray: instance segmentation
        Dict[int, int]: mapping from instances to classes
    """
    # full connectivity (including diagonals) via an all-ones structure
    connectivity = np.ones([3] * seg.ndim)
    components, _ = label(seg, structure=connectivity)
    component_ids = np.unique(components)
    if exclude_background:
        component_ids = component_ids[component_ids > 0]

    instances = np.zeros_like(components)
    instance_classes = {}
    next_id = 1  # instance ids stay consecutive even when components are dropped
    for component_id in component_ids:
        component_mask = components == component_id
        if min_num_voxel > 0 and component_mask.sum() < min_num_voxel:
            continue  # drop instances smaller than the requested size
        instances[component_mask] = next_id
        # the semantic class is read from an arbitrary voxel of the component
        seed = np.argwhere(component_mask)[0]
        instance_classes[int(next_id)] = int(seg[tuple(seed)])
        next_id += 1
    return instances, instance_classes
def remove_classes(seg: np.ndarray, rm_classes: Sequence[int], classes: Dict[int, int] = None,
                   background: int = 0) -> Union[np.ndarray, Tuple[np.ndarray, Dict[int, int]]]:
    """
    Remove classes from segmentation (also works on instances
    but instance ids may not be consecutive anymore). `seg` is
    modified in place.

    Args:
        seg: segmentation [spatial dims]
        rm_classes: classes which should be removed
        classes: optional mapping from instances from segmentation to classes
        background: background value written where classes are removed

    Returns:
        np.ndarray: segmentation where classes are removed
        Dict[int, int]: updated mapping from instances to classes
            (only returned when `classes` is passed)
    """
    for rc in rm_classes:
        seg[seg == rc] = background
        if classes is not None:
            # robustness fix: pop with default so an id that has no entry in
            # `classes` no longer raises KeyError
            classes.pop(rc, None)
    if classes is None:
        return seg
    return seg, classes
def reorder_classes(seg: np.ndarray, class_mapping: Dict[int, int]) -> np.ndarray:
    """
    Reorders classes in segmentation (in place)

    Args:
        seg: segmentation
        class_mapping: mapping from source id to new id

    Returns:
        np.ndarray: remapped segmentation (the same array object as `seg`)

    Note:
        mappings are applied sequentially in dict order, so overlapping
        entries chain: {1: 2, 2: 3} remaps original 1s twice (1 -> 2 -> 3);
        use disjoint source/target ids to avoid this
    """
    for source_id, target_id in class_mapping.items():
        seg[seg == source_id] = target_id
    return seg
def compute_score_from_seg(instances: np.ndarray,
                           instance_classes: Dict[int, int],
                           probs: np.ndarray,
                           aggregation: str = "max",
                           ) -> np.ndarray:
    """
    Combine scores for each instance given an instance mask and instance logits

    Args:
        instances: instance segmentation [dims]; dims can be arbitrary dimensions
        instance_classes: assign each instance id to a class (id -> class)
        probs: predicted probabilities for each class [C, dims];
            C = number of classes, dims need to have the same dimensions as
            instances
        aggregation: defines the aggregation method for the probabilities.
            One of 'max', 'mean', 'median', 'percentile95'

    Returns:
        np.ndarray: probability for each instance (ordered like the keys of
            `instance_classes`)

    Raises:
        ValueError: if an unknown aggregation is requested
    """
    # dispatch table replaces the if/elif chain; also fixes the previously
    # garbled error message for unknown aggregations
    aggregation_fns = {
        "max": np.max,
        "mean": np.mean,
        "median": np.median,
        "percentile95": lambda x: np.percentile(x, 95),
    }
    if aggregation not in aggregation_fns:
        raise ValueError(f"Unknown aggregation '{aggregation}', "
                         f"expected one of {sorted(aggregation_fns)}")
    aggregate = aggregation_fns[aggregation]

    # keys may arrive as strings (e.g. from json) -> normalize to int
    instance_classes = {int(key): int(item) for key, item in instance_classes.items()}
    instance_scores = []
    for iid, ic in instance_classes.items():
        # probabilities of the instance's class inside the instance mask
        instance_probs = probs[ic][instances == iid]
        instance_scores.append(aggregate(instance_probs))
    return np.asarray(instance_scores)
def instance_results_from_seg(probs: np.ndarray,
                              aggregation: str,
                              stuff: Optional[Sequence[int]] = None,
                              min_num_voxel: int = 0,
                              min_threshold: Optional[float] = None,
                              ) -> dict:
    """
    Compute instance segmentation results from a semantic segmentation
    argmax -> remove stuff classes -> connected components ->
    aggregate score inside each instance

    Args:
        probs: Predicted probabilities for each class [C, dims];
            C = number of classes (channel 0 is background),
            dims can be arbitrary dimensions
        aggregation: defines the aggregation method for the probabilities.
            One of 'max', 'mean'
        stuff: stuff classes to be ignored during conversion.
        min_num_voxel: minimum number of voxels of an instance
        min_threshold: if None argmax is used. If a threshold is provided
            it is used as a probability threshold for the foreground class.
            if multiple foreground classes exceed the threshold, the
            foreground class with the largest probability is selected.

    Returns:
        dict: predictions
            `pred_instances`: instance segmentation [dims]
            `pred_boxes`: predicted bounding boxes [2 * spatial dims]
            `pred_labels`: predicted class for each instance/box
            `pred_scores`: predicted score for each instance/box
    """
    if min_threshold is not None:
        if probs.shape[0] > 2:
            # bugfix: restrict the argmax to foreground channels (and shift
            # back by one to recover the class index). Previously the argmax
            # ran over ALL channels, so a voxel whose background probability
            # dominated stayed background even though a foreground class
            # exceeded the threshold — contradicting the documented behaviour
            # and the 2-class branch below.
            fg_argmax = np.argmax(probs[1:], axis=0) + 1
            fg_mask = np.max(probs[1:], axis=0) > min_threshold
            seg = np.zeros_like(probs[0])
            seg[fg_mask] = fg_argmax[fg_mask]
        else:
            # binary case: threshold the single foreground channel
            seg = probs[1] > min_threshold
    else:
        seg = np.argmax(probs, axis=0)

    if stuff is not None:
        # stuff classes are mapped to background before component analysis
        for s in stuff:
            seg[seg == s] = 0

    instances, instance_classes = seg2instances(seg,
                                                exclude_background=True,
                                                min_num_voxel=min_num_voxel,
                                                )
    instance_scores = compute_score_from_seg(instances, instance_classes, probs,
                                             aggregation=aggregation)
    # shift to 0-based foreground labels for the detection interface
    instance_classes = {int(key): int(item) - 1 for key, item in instance_classes.items()}
    tmp = get_bbox_np(instances[None], instance_classes)
    instance_boxes = tmp["boxes"]
    instance_classes_seq = tmp["classes"]
    return {
        "pred_instances": instances,
        "pred_boxes": instance_boxes,
        "pred_labels": instance_classes_seq,
        "pred_scores": instance_scores,
    }
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from pathlib import Path
import yaml
from hydra.experimental import compose as hydra_compose
from nndet.io.paths import Pathlike, get_task
def load_dataset_info(task_dir: Pathlike) -> dict:
    """
    Load dataset information from a given task directory

    Args:
        task_dir: path to directory of specific task e.g. ../Task12_LIDC

    Returns:
        dict: loaded dataset info. Typically includes:
            `name` (str): name of dataset
            `target_class` (str)

    Raises:
        RuntimeError: if neither dataset.yaml nor dataset.json exists
    """
    task_dir = Path(task_dir)
    yaml_file = task_dir / "dataset.yaml"
    json_file = task_dir / "dataset.json"
    # yaml takes precedence over json when both files exist
    if yaml_file.is_file():
        with open(yaml_file, 'r') as f:
            return yaml.full_load(f)
    if json_file.is_file():
        with open(json_file, "r") as f:
            return json.load(f)
    raise RuntimeError(f"Did not find dataset.json or dataset.yaml in {task_dir}")
def compose(task, *args, models: bool = False, **kwargs) -> dict:
    """
    Compose a hydra config and attach task name and dataset info to it.

    Args:
        task: task name or identifier; resolved via `get_task`
        *args: positional arguments forwarded to `hydra_compose`
        models: forwarded to `get_task`
            (presumably selects the models directory — confirm with get_task)
        **kwargs: keyword arguments forwarded to `hydra_compose`

    Returns:
        dict: composed config with additional `task` and `data` entries
    """
    from omegaconf import OmegaConf
    cfg = hydra_compose(*args, **kwargs)
    # disable struct mode so new keys can be added below
    OmegaConf.set_struct(cfg, False)
    task_name = get_task(task, name=True, models=models)
    cfg["task"] = task_name
    cfg["data"] = load_dataset_info(get_task(task_name))
    return cfg
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import copy
import pathlib
import warnings
import functools
from collections.abc import MutableMapping
from subprocess import PIPE, run
from omegaconf.omegaconf import OmegaConf
from tqdm import tqdm
from typing import Mapping, Sequence, Union, Callable, Any, Iterable
from loguru import logger
from contextlib import contextmanager
from typing import Union, Optional
from pathlib import Path
from git import Repo, InvalidGitRepositoryError
def env_guard(func):
    """
    Contextmanager to check nnDetection environment variables

    Raises RuntimeError at call time when `det_data`, `det_models` or
    `OMP_NUM_THREADS` is missing; warns about missing `det_num_threads`
    and `det_verbose`.
    """
    # required variables share the same error message template
    _required = ("det_data", "det_models", "OMP_NUM_THREADS")

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # we use print/warnings here because logging might not be
        # initialized yet and this is intended as a user warning.
        for var in _required:
            if os.environ.get(var, None) is None:
                raise RuntimeError(
                    f"'{var}' environment variable not set. "
                    "Please refer to the installation instructions. "
                )
        if os.environ.get("det_num_threads", None) is None:
            warnings.warn(
                "Warning: 'det_num_threads' environment variable not set. "
                "Please read installation instructions again. "
                "Training will not work properly.")
        if os.environ.get("det_verbose", None) is None:
            print("'det_verbose' environment variable not set. "
                  "Continue in verbose mode.")
        return func(*args, **kwargs)
    return wrapper
def get_requirements():
    """
    Get all installed packages from currently active environment

    Returns:
        str: list with all requirements

    Raises:
        RuntimeError: if `pip list` wrote anything to stderr
    """
    command = ['pip', 'list']
    result = run(command, stdout=PIPE, stderr=PIPE, universal_newlines=True)
    # robustness fix: raise instead of assert — asserts are stripped when
    # running under `python -O`
    if result.stderr:
        raise RuntimeError(f"'pip list' reported errors: {result.stderr}")
    return result.stdout
def write_requirements_to_file(path: Union[str, Path]) -> None:
    """
    Write all installed packages from currently active environment to file

    Args:
        path: path to file (including file name and extension)
    """
    requirements = get_requirements()
    with open(path, "w+") as f:
        f.write(requirements)
def get_repo_info(path: Union[str, Path]):
    """
    Parse repository information from path

    Args:
        path: path to repo. If path is not a repository it
            searches parent folders for a repository

    Returns:
        dict: contains the current hash, gitdir and active branch

    Raises:
        InvalidGitRepositoryError: if neither `path` nor any of its parents
            is a git repository
    """
    def _search_upwards(start):
        # walk from `start` up to the filesystem root, returning the first
        # directory that is a valid git repository
        current = Path(start).absolute()
        for candidate in [current, *current.parents]:
            try:
                return Repo(candidate)
            except InvalidGitRepositoryError:
                continue
        raise InvalidGitRepositoryError

    repo = _search_upwards(path)
    return {
        "hash": repo.head.commit.hexsha,
        "gitdir": repo.git_dir,
        "active_branch": repo.active_branch.name,
    }
def maybe_verbose_iterable(data: Iterable, **kwargs) -> Iterable:
    """
    If verbose flag of nndet is enabled, uses tqdm to create a
    progress bar

    Args:
        data: iterable to wrap
        **kwargs: keyword arguments passed to tqdm

    Returns:
        Iterable: maybe iterable with progress bar attached to it
    """
    # det_verbose defaults to enabled when the variable is unset
    verbose = bool(int(os.getenv("det_verbose", 1)))
    return tqdm(data, **kwargs) if verbose else data
def find_name(tdir: Union[str, Path], name: str,
              postfix: Optional[str] = None) -> Path:
    """
    Generates non existing names for files and dirs by adding a counter to
    the end

    Args:
        tdir: target directory where name should be determined for
            (created if it does not exist)
        name: base name for string
        postfix: postfix for name+counter. Defaults to None.

    Raises:
        RuntimeError: this function only works up to the counter of 1000

    Returns:
        Path: path to generated item
    """
    tdir = Path(tdir)
    if not tdir.is_dir():
        tdir.mkdir(parents=True)
    suffix = "" if postfix is None else postfix

    counter = 0
    while True:
        candidate = tdir / f"{name}{counter:03d}{suffix}"
        if not candidate.exists():
            return candidate
        if counter > 1000:
            raise RuntimeError(f"Was not able to find name for tdir {tdir} and {name}")
        counter += 1
def log_experiment(cfg, plans: dict, trainer_class: Callable, stage: int,
                   repo_path: Union[pathlib.Path, str]):
    """
    Use python logging module to log important settings

    Args:
        cfg:
            config file
        plans (dict):
            current plan for training
        trainer_class (callable):
            class of current trainer
        stage (int):
            stage
        repo_path (Union[pathlib.Path, str]):
            path to repo/or file inside a repository to parse additional information
    """
    logger.info("##################### Experiment Logging Start ##########################")
    logger.info(f"Running in {cfg.mode.mode} mode.")
    logger.info(f"I am running the following nnUNet: {cfg.exp.network}")
    logger.info(f"My trainer class is: {trainer_class}")
    logger.info("For that I will be using the following configuration:")
    # detailed plan settings are delegated to log_plan
    log_plan(plans)
    logger.info(f"I am using stage {stage} from these plans")
    logger.info(f"\nI am using data from this folder: {os.path.join(cfg.exp.dataset_dir, plans['data_identifier'])}")
    logger.info("##################### Experiment Logging End ##########################")
    # git info is collected after the main banner; log_git handles failures
    log_git(repo_path)
def log_plan(plans):
    """
    Use python logging module to log current settings

    Args:
        plans (dict):
            current plan for training
    """
    # f-strings produce the same output as the original str.format calls
    logger.info(f"num_classes: {plans['num_classes']}")
    logger.info(f"modalities: {plans['modalities']}")
    logger.info(f"use_mask_for_norm {plans['use_mask_for_norm']}")
    logger.info(f"keep_only_largest_region {plans['keep_only_largest_region']}")
    logger.info(f"min_region_size_per_class {plans['min_region_size_per_class']}")
    logger.info(f"min_size_per_class {plans['min_size_per_class']}")
    logger.info(f"normalization_schemes {plans['normalization_schemes']}")
    logger.info("stages...\n")
def log_git(repo_path: Union[pathlib.Path, str], repo_name: str = None):
    """
    Collect git information for logging purposes

    Args:
        repo_path (Union[pathlib.Path, str]): path to repo or file inside
            repository (repository is recursively searched)
        repo_name: kept for backward compatibility; currently unused

    Returns:
        dict: git information (hash, gitdir, active branch) or an empty dict
            if the repository could not be read
    """
    # cleanup: removed dead commented-out logging block; the function now
    # only returns the parsed info
    try:
        return get_repo_info(repo_path)
    except Exception:
        logger.error("Was not able to read git information, trying to continue without.")
        return {}
def get_cls_name(obj: Any, package_name: bool = True) -> str:
    """
    Get name of class from object

    Args:
        obj (Any): any object
        package_name (bool): append package origin at the beginning

    Returns:
        str: name of class
    """
    # str(type(obj)) looks like "<class 'package.module.ClassName'>";
    # grab the dotted path between the quotes
    qualified = str(obj.__class__).split('\'')[1]
    parts = qualified.split('.')
    if len(parts) == 1:
        # builtins have no module prefix
        return parts[0]
    if package_name:
        return f"{parts[0]}.{parts[-1]}"
    return parts[-1]
def log_error(fn: Callable) -> Any:
    """
    Log error messages in hydra log when they occur

    Args:
        fn: function to wrap

    Returns:
        Any: wrapped function (metadata preserved via functools.wraps)
    """
    # bugfix: functools.wraps preserves __name__/__doc__ of the wrapped
    # function (previously clobbered by the decorator)
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            logger.error(str(e))
            raise  # bare raise keeps the original traceback intact
    return wrapper
@contextmanager
def file_logger(path: Union[str, Path], level: str = "DEBUG", overwrite: bool = True):
    """
    Context manager which adds a file sink to the logger and removes it
    again on exit (even if the body raised).

    Args:
        path: path to output file
        level: logging level. Defaults to "DEBUG".
        overwrite: remove an existing file at ``path`` before logging.
            Defaults to True.

    Yields:
        None
    """
    path = Path(path)
    if overwrite and path.is_file():
        path.unlink()  # pathlib equivalent of os.remove
    logger_id = logger.add(path, level=level)
    try:
        yield None
    finally:
        # always detach the sink so later logging does not keep writing here
        logger.remove(logger_id)
def create_debug_plan(plan: dict) -> dict:
    """
    Build a stringified copy of a plan for debug output.

    Note: the return annotation was corrected from ``str`` to ``dict``;
    ``stringify_nested_dict`` returns a nested dict for dict input.

    Args:
        plan: plan to stringify (not modified; a deep copy is taken)

    Returns:
        dict: nested structure with all keys and leaves converted to strings
    """
    _plan = copy.deepcopy(plan)
    # drop bulky dataset-level entries which would clutter the debug output
    for key in ("dataset_properties", "original_spacings", "original_sizes"):
        _plan.pop(key, None)
    return stringify_nested_dict(_plan)
def stringify_nested_dict(data: dict):
    """
    Recursively convert all keys and leaf values of a nested structure
    to strings.

    Note: tuples are returned as lists.

    Args:
        data: nested dict / list / tuple structure (or a plain leaf)

    Returns:
        structure of the same layout with string keys and string leaves
    """
    if isinstance(data, (list, tuple)):
        return [stringify_nested_dict(entry) for entry in data]
    if isinstance(data, dict):
        return {str(key): stringify_nested_dict(value) for key, value in data.items()}
    return str(data)
def flatten_mapping(
    nested_mapping: Mapping,
    sep: str = ".",
) -> Mapping[str, Any]:
    """
    Flatten a nested mapping into a single-level dict with joined keys.

    Args:
        nested_mapping: mapping to flatten
        sep: separator inserted between nested key parts

    Returns:
        Mapping[str, Any]: flat mapping with stringified, joined keys
    """
    flat = {}
    for key, value in nested_mapping.items():
        if not isinstance(value, MutableMapping):
            flat[str(key)] = value
            continue
        # recurse into sub-mappings and prefix their (already flat) keys
        for sub_key, sub_value in flatten_mapping(value, sep=sep).items():
            flat[f"{key}{sep}{sub_key}"] = sub_value
    return flat
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from pathlib import Path
import numpy as np
import SimpleITK as sitk
from itertools import product
from typing import Sequence, Union, Tuple
def create_circle_mask_itk(image_itk: sitk.Image,
                           world_centers: Sequence[Sequence[float]],
                           world_rads: Sequence[float],
                           ndim: int = 3,
                           ) -> sitk.Image:
    """
    Creates an itk image with circles defined by center points and radii

    Args:
        image_itk: original image (used for the coordinate frame)
        world_centers: Sequence of center points in world coordiantes (x, y, z)
        world_rads: Sequence of radii to use
        ndim: number of spatial dimensions

    Returns:
        sitk.Image: mask with circles; each circle is filled with its
        1-based index into ``world_centers``
    """
    image_np = sitk.GetArrayFromImage(image_itk)
    min_spacing = min(image_itk.GetSpacing())
    # drop the leading axis so the mask matches the spatial shape
    # (assumes the extra axis is a channel/time axis coming first -- TODO confirm)
    if image_np.ndim > ndim:
        image_np = image_np[0]
    mask_np = np.zeros_like(image_np).astype(np.uint8)
    # circles are labeled 1..len(world_centers)
    for _id, (world_center, world_rad) in enumerate(zip(world_centers, world_rads), start=1):
        # radius in voxel units (worst case w.r.t. smallest spacing)
        check_rad = (world_rad / min_spacing) * 1.5  # add some buffer to it
        bounds = []
        # physical (x, y, z) -> continuous index, reversed to numpy (z, y, x) order
        center = image_itk.TransformPhysicalPointToContinuousIndex(world_center)[::-1]
        for ax, c in enumerate(center):
            # clip the per-axis search window to the image extent
            bounds.append((
                max(0, int(c - check_rad)),
                min(mask_np.shape[ax], int(c + check_rad)),
            ))
        coord_box = product(*[list(range(b[0], b[1])) for b in bounds])
        # loop over every pixel position inside the bounding box and keep
        # those within the physical radius of the center
        for coord in coord_box:
            world_coord = image_itk.TransformIndexToPhysicalPoint(tuple(reversed(coord)))  # reverse order to x, y, z for sitk
            dist = np.linalg.norm(np.array(world_coord) - np.array(world_center))
            if dist <= world_rad:
                mask_np[tuple(coord)] = _id
        # sanity check: every circle must cover at least one voxel
        assert mask_np.max() == _id
    mask_itk = sitk.GetImageFromArray(mask_np)
    return copy_meta_data_itk(image_itk, mask_itk)
def copy_meta_data_itk(source: sitk.Image, target: sitk.Image) -> sitk.Image:
    """
    Copy the spatial meta data (origin, direction, spacing) from one
    image to another.

    Args:
        source: image to read the meta data from
        target: image to write the meta data to (modified in place)

    Returns:
        sitk.Image: target image with copied meta data
    """
    # NOTE: generic meta data dictionary keys are deliberately not copied
    # for i in source.GetMetaDataKeys():
    #     target.SetMetaData(i, source.GetMetaData(i))
    for setter, getter in (
        (target.SetOrigin, source.GetOrigin),
        (target.SetDirection, source.GetDirection),
        (target.SetSpacing, source.GetSpacing),
    ):
        setter(getter())
    return target
def load_sitk(path: Union[Path, str], **kwargs) -> sitk.Image:
    """
    Read an image from disk via SimpleITK.

    Args:
        path: path to file to load
        **kwargs: forwarded to ``sitk.ReadImage``

    Returns:
        sitk.Image: loaded sitk image
    """
    # sitk expects a plain string path
    file_path = str(path)
    return sitk.ReadImage(file_path, **kwargs)
def load_sitk_as_array(path: Union[Path, str], **kwargs) -> Tuple[np.ndarray, dict]:
    """
    Load an image via SimpleITK and return voxel data plus meta data.

    Args:
        path: path to file to load
        **kwargs: forwarded to ``load_sitk``

    Returns:
        np.ndarray: loaded image data
        dict: loaded meta data
    """
    image = load_sitk(path, **kwargs)
    meta = {}
    # collect the generic meta data dictionary of the image
    for key in image.GetMetaDataKeys():
        meta[key] = image.GetMetaData(key)
    return sitk.GetArrayFromImage(image), meta
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import shutil
import json
from itertools import repeat
from multiprocessing import Pool
import SimpleITK as sitk
import numpy as np
from nndet.utils.itk import load_sitk_as_array, load_sitk
from loguru import logger
from pathlib import Path
from typing import Sequence, Union
from nndet.io.load import save_json, load_json
from nndet.io.paths import get_case_ids_from_dir
from nndet.io.prepare import sitk_copy_metadata
from nndet.io.transforms.instances import instances_to_segmentation_np
Pathlike = Union[str, Path]
class Exporter:
    """
    Helper to export an nnDetection-style dataset (instance segmentations
    plus per-case json meta data) into the nnU-Net dataset layout
    (semantic segmentations plus dataset.json).
    """
    def __init__(self,
                 data_info: dict,
                 tr_image_dir: Pathlike,
                 label_dir: Pathlike,
                 target_dir: Pathlike,
                 ts_image_dir: Union[Pathlike, None] = None,
                 export_stuff: bool = False,
                 processes: int = 6,
                 ):
        """
        Args:
            data_info: dataset information. See :method:`export_dataset_info`.
                Required keys: `modality`, `labels`
            tr_image_dir: training data dir
            label_dir: label data dir
            target_dir: target directory
            ts_image_dir: test data dir (optional)
            export_stuff: export stuff segmentations
            processes: number of worker processes for the label export;
                0 runs a plain for loop in the main process
        """
        self.data_info = data_info
        self.tr_image_dir = Path(tr_image_dir)
        self.label_dir = Path(label_dir)
        self.target_dir = Path(target_dir)
        self.export_stuff = export_stuff
        self.processes = processes
        if ts_image_dir is not None:
            self.ts_image_dir = Path(ts_image_dir)
        else:
            self.ts_image_dir = None

    def export(self):
        """
        Export entire dataset (images, labels and dataset.json)
        """
        self.export_images()
        self.export_labels()
        self.export_dataset_info()

    def export_images(self):
        """
        Export images by copying the image trees into the target dir
        """
        # data can be copied directly
        for img_dir in [self.tr_image_dir, self.ts_image_dir]:
            if img_dir is None:
                continue
            # keep the source directory name (e.g. imagesTr / imagesTs)
            image_target_dir = self.target_dir / img_dir.stem
            logger.info(f"Copy data from {img_dir} to {image_target_dir}")
            shutil.copytree(img_dir, image_target_dir)

    def export_labels(self):
        """
        Export labels: convert every instance segmentation to a semantic
        segmentation, optionally in a process pool
        """
        # case ids are derived from the per-case json meta files
        case_ids = get_case_ids_from_dir(self.label_dir, remove_modality=False, pattern="*.json")
        label_target_dir = self.target_dir / self.label_dir.stem
        label_target_dir.mkdir(exist_ok=True, parents=True)
        num_classes = len(self.data_info.get("labels", {}))
        if num_classes == 0:
            logger.warning(f"Did not find any fg classes.")
        logger.info(f"Found {len(case_ids)} to process.")
        logger.info(f"Export stuff: {self.export_stuff}")
        if self.processes == 0:
            logger.info("Using for loop to export labels")
            for cid in case_ids:
                self._export_label(cid, num_classes, label_target_dir)
        else:
            logger.info(f"Using pool with {self.processes} processes to export labels")
            with Pool(processes=self.processes) as p:
                p.starmap(self._export_label, zip(
                    case_ids, repeat(num_classes), repeat(label_target_dir)))
        # every case must have produced exactly one exported nifti
        assert len(get_case_ids_from_dir(
            label_target_dir, remove_modality=False, pattern="*.nii.gz")) == len(case_ids)

    def _export_label(self, cid: str, num_classes: int, target_dir: Path):
        """
        Convert the instance segmentation of a single case into a semantic
        segmentation and write it as nifti.

        Args:
            cid: case id to process
            num_classes: number of foreground classes (0 disables the
                class-id sanity check)
            target_dir: directory to write the converted label to
        """
        logger.info(f"Processing {cid}")
        meta = load_json(self.label_dir / f"{cid}.json")
        instance_seg_itk = sitk.ReadImage(str(self.label_dir / f"{cid}.nii.gz"))
        instance_seg = sitk.GetArrayFromImage(instance_seg_itk)
        if np.any(np.isnan(instance_seg)):
            logger.error(f"FOUND NAN IN {cid} LABEL")
        # instance classes start form 0 which is background in nnUNet
        seg = instances_to_segmentation_np(instance_seg,
                                           meta["instances"],
                                           add_background=True,
                                           )
        # sanity checks: class ids in range, fg/bg masks must be preserved
        if num_classes > 0:
            assert seg.max() <= num_classes, "Wrong class id, something went wrong."
        if instance_seg.max() > 0:
            assert seg.max() > 0, "Instance got lost, something went wrong"
        assert np.all((instance_seg > 0) == (seg > 0)), "Something wrong with foreground"
        assert np.all((instance_seg == 0) == (seg == 0)), "Something wrong with background"
        if self.export_stuff:
            # map stuff classes to: max(labels) + stuff_cls
            stuff_seg = load_sitk_as_array(self.label_dir / f"{cid}_stuff.nii.gz")[0]
            for i in range(1, stuff_seg.max() + 1):
                seg[stuff_seg == i] = num_classes + i
        # copy the spatial meta data from the source instance segmentation
        seg_itk = sitk.GetImageFromArray(seg)
        spacing = instance_seg_itk.GetSpacing()
        seg_itk.SetSpacing(spacing)
        origin = instance_seg_itk.GetOrigin()
        seg_itk.SetOrigin(origin)
        direction = instance_seg_itk.GetDirection()
        seg_itk.SetDirection(direction)
        sitk.WriteImage(seg_itk, str(target_dir / f"{cid}.nii.gz"))

    def export_dataset_info(self):
        """
        Export dataset settings (dataset.json for nnunet)
        """
        self.target_dir.mkdir(exist_ok=True, parents=True)
        dataset_info = {}
        dataset_info["name"] = self.data_info.get("name", "unknown")
        dataset_info["description"] = self.data_info.get("description", "unknown")
        dataset_info["reference"] = self.data_info.get("reference", "unknown")
        dataset_info["licence"] = self.data_info.get("licence", "unknown")
        dataset_info["release"] = self.data_info.get("release", "unknown")
        min_size = self.data_info.get("min_size", 0)
        min_vol = self.data_info.get("min_vol", 0)
        dataset_info["prep_info"] = f"min size: {min_size} ; min vol {min_vol}"
        dataset_info["tensorImageSize"] = f"{self.data_info.get('dim', 3)}D"
        # dataset_info["tensorImageSize"] = self.data_info.get("tensorImageSize", "4D")
        dataset_info["modality"] = self.data_info.get("modalities", {})
        if not dataset_info["modality"]:
            logger.error("Did not find any modalities for dataset")
        # +1 for seg classes because of background
        dataset_info["labels"] = {"0": "background"}
        instance_classes = self.data_info.get("labels", {})
        if not instance_classes:
            logger.error("Did not find any labels of dataset")
        for _id, _class in instance_classes.items():
            seg_id = int(_id) + 1
            dataset_info["labels"][str(seg_id)] = _class
        if self.export_stuff:
            stuff_classes = self.data_info.get("labels_stuff", {})
            num_instance_classes = len(instance_classes)
            # copy stuff classes into nnuent dataset.json
            stuff_classes = {
                str(int(key) + num_instance_classes): item
                for key, item in stuff_classes.items() if int(key) > 0
            }
            dataset_info["labels_stuff"] = stuff_classes
            dataset_info["labels"].update(stuff_classes)
        _case_ids = get_case_ids_from_dir(self.label_dir, remove_modality=False)
        case_ids_tr = get_case_ids_from_dir(self.tr_image_dir, remove_modality=True)
        assert len(set(_case_ids).union(case_ids_tr)) == len(_case_ids), "All training images need a label"
        dataset_info["numTraining"] = len(case_ids_tr)
        dataset_info["training"] = [
            {"image": f"./imagesTr/{cid}.nii.gz", "label": f"./labelsTr/{cid}.nii.gz"}
            for cid in case_ids_tr]
        if self.ts_image_dir is not None:
            case_ids_ts = get_case_ids_from_dir(self.ts_image_dir, remove_modality=True)
            dataset_info["numTest"] = len(case_ids_ts)
            dataset_info["test"] = [f"./imagesTs/{cid}.nii.gz" for cid in case_ids_ts]
        else:
            dataset_info["numTest"] = 0
            dataset_info["test"] = []
        save_json(dataset_info, self.target_dir / "dataset.json")
import inspect
import shutil
import os
from pathlib import Path
from typing import Callable
class Registry:
    """
    Simple function registry which also records the defining file of
    every registered function so the sources can be copied later.
    """

    def __init__(self):
        # name -> {"fn": callable, "path": defining source file}
        self.mapping = {}

    def __getitem__(self, key):
        return self.mapping[key]["fn"]

    def register(self, fn: Callable):
        """Decorator-style registration under the function's own name."""
        self._register(fn.__name__, fn, inspect.getfile(fn))
        return fn

    def _register(self, name: str, fn: Callable, path: str):
        if name in self.mapping:
            raise TypeError(f"Name {name} already in registry.")
        self.mapping[name] = {"fn": fn, "path": path}

    def get(self, name: str):
        """Look up a registered function by name."""
        return self.mapping[name]["fn"]

    def copy_registered(self, target: Path):
        """Copy the source files of all registered functions to target."""
        if not target.is_dir():
            target.mkdir(parents=True)
        # deduplicate: several functions may live in the same file
        unique_paths = list({entry["path"] for entry in self.mapping.values()})
        for src in unique_paths:
            # build a flat file name from the path below the 'nndet' root
            rel = src.split('nndet')[-1]
            flat = rel.replace(os.sep, '_').rsplit('.', 1)[0]
            shutil.copy(src, str(target / f"{flat[1:]}.py"))
from collections import defaultdict
import torch
import re
import numpy as np
from torch import Tensor
from torch._six import container_abcs, string_classes, int_classes
from typing import Sequence, Union, Any, Mapping, Callable, List
np_str_obj_array_pattern = re.compile(r'[SaUO]')
def make_onehot_batch(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    """
    Create onehot encoding of labels

    Args:
        labels: label tensor to encode [N, dims]
        n_classes: number of classes

    Returns:
        Tensor: onehot encoded tensor [N, C, dims]; N: batch size,
            C: number of classes, dims: spatial dimensions
    """
    # scatter_ requires integer indices
    idx = labels.to(dtype=torch.long)
    new_shape = [labels.shape[0], n_classes, *labels.shape[1:]]
    labels_onehot = torch.zeros(*new_shape, device=labels.device,
                                dtype=labels.dtype)
    # place a 1 at the class index along the new channel dimension
    labels_onehot.scatter_(1, idx.unsqueeze(dim=1), 1)
    return labels_onehot
def to_dtype(inp: Any, dtype: Callable) -> Any:
    """
    Convert an object, or every element of a sequence, with ``dtype``.

    Args:
        inp (Any): object convertible by dtype; for sequences the
            conversion is applied element-wise
        dtype (Callable): callable performing the conversion

    Returns:
        Any: converted input (sequences keep their container type)
    """
    if not isinstance(inp, Sequence):
        return dtype(inp)
    converted = [dtype(element) for element in inp]
    return type(inp)(converted)
def to_device(inp: Union[Sequence[torch.Tensor], torch.Tensor, Mapping[str, torch.Tensor]],
              device: Union[torch.device, str], detach: bool = False,
              **kwargs) -> \
        Union[Sequence[torch.Tensor], torch.Tensor, Mapping[str, torch.Tensor]]:
    """
    Recursively push tensors (possibly nested inside sequences and
    mappings) to a device.

    Args:
        inp (Union[Sequence[torch.Tensor], torch.Tensor]): tensor or
            (nested) sequence/mapping of tensors
        device (Union[torch.device, str]): target device
        detach: detach tensor before moving it to new device
        **kwargs: keyword arguments passed to `to` function

    Returns:
        Union[Sequence[torch.Tensor], torch.Tensor]: same structure with
            all tensors on the target device; non-tensor leaves unchanged
    """
    if isinstance(inp, torch.Tensor):
        tensor = inp.detach() if detach else inp
        return tensor.to(device=device, **kwargs)
    if isinstance(inp, Sequence):
        return type(inp)(
            [to_device(entry, device=device, detach=detach, **kwargs) for entry in inp])
    if isinstance(inp, Mapping):
        return type(inp)(
            {key: to_device(entry, device=device, detach=detach, **kwargs)
             for key, entry in inp.items()})
    return inp
def to_numpy(inp: Union[Sequence[torch.Tensor], torch.Tensor, Any]) -> \
        Union[Sequence[np.ndarray], np.ndarray, Any]:
    """
    Recursively convert tensors (possibly nested in lists/tuples/dicts)
    to numpy arrays.

    Args:
        inp (Union[Sequence[torch.Tensor], torch.Tensor]): tensor or
            (nested) sequence of tensors

    Returns:
        Union[Sequence[np.ndarray], np.ndarray]: same structure with
            tensors converted to arrays (non tensor entries are
            forwarded as they are; defaultdicts are passed through)
    """
    if isinstance(inp, torch.Tensor):
        return inp.detach().cpu().numpy()
    if isinstance(inp, dict) and not isinstance(inp, defaultdict):
        return type(inp)({key: to_numpy(value) for key, value in inp.items()})
    if isinstance(inp, (tuple, list)):
        return type(inp)([to_numpy(entry) for entry in inp])
    return inp
def to_tensor(inp: Any) -> Any:
    """
    Convert arrays, seq, mappings to torch tensor
    https://github.com/pytorch/pytorch/blob/f522bde1213fcc46f2857f79be7b1b01ddf302a6/torch/utils/data/_utils/collate.py

    Note: rewritten without ``torch._six`` (``container_abcs`` /
    ``string_classes`` were removed from torch); behavior is unchanged.

    Args:
        inp: np.ndarrays, mappings, sequences are converted to tensors,
            rest is passed through this function

    Returns:
        Any: converted data
    """
    # local import keeps this function self-contained
    from collections import abc

    elem_type = type(inp)
    if isinstance(inp, torch.Tensor):
        return inp
    if elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        # arrays of string/object dtype cannot be converted -> pass through
        # (dtype.kind check is equivalent to the old [SaUO] regex on dtype.str)
        if elem_type.__name__ == 'ndarray' and inp.dtype.kind in 'SaUO':
            return inp
        return torch.as_tensor(inp)
    if isinstance(inp, abc.Mapping):
        return {key: to_tensor(inp[key]) for key in inp}
    if isinstance(inp, tuple) and hasattr(inp, '_fields'):  # namedtuple
        return elem_type(*(to_tensor(d) for d in inp))
    if isinstance(inp, abc.Sequence) and not isinstance(inp, (str, bytes)):
        return [to_tensor(d) for d in inp]
    return inp
def cat(t: Union[List[Tensor], Tensor], *args, **kwargs):
    """
    Concatenate a list of tensors; a plain tensor or a single-element
    list is returned without calling ``torch.cat``.

    Args:
        t: tensor or list of tensors
        *args: forwarded to ``torch.cat``
        **kwargs: forwarded to ``torch.cat``

    Returns:
        Tensor: (concatenated) tensor

    Raises:
        ValueError: if ``t`` is neither a list nor a tensor
    """
    if isinstance(t, Tensor):
        return t
    if isinstance(t, list):
        return t[0] if len(t) == 1 else torch.cat(t, *args, **kwargs)
    raise ValueError(f"Can only concatenate lists and tensors.")
import time
from loguru import logger
class Timer:
    """
    Context manager measuring wall-clock time of the enclosed block via
    ``time.perf_counter``. The elapsed time is stored in ``self.dif``.
    """

    def __init__(self, msg: str = "", verbose: bool = True):
        """
        Args:
            msg: message describing the timed operation (used for logging)
            verbose: log the elapsed time on exit
        """
        self.verbose = verbose
        self.msg = msg
        # timestamps in seconds; populated by __enter__/__exit__
        self.tic: float = None
        self.toc: float = None
        self.dif: float = None

    def __enter__(self):
        self.tic = time.perf_counter()
        # return self so `with Timer(...) as t:` exposes the timer object
        # (previously returned None implicitly)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.toc = time.perf_counter()
        self.dif = self.toc - self.tic
        if self.verbose:
            logger.info(f"Operation '{self.msg}' took: {self.toc - self.tic} sec")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment