Commit 8607cb0f authored by Baumgartner, Michael's avatar Baumgartner, Michael
Browse files

Merge remote-tracking branch 'origin/0000_project' into main

parents 1044ace5 ca7e0f11
......@@ -15,11 +15,16 @@ limitations under the License.
"""
import torch
from loguru import logger
from torch import Tensor
from torch.cuda.amp import autocast
from torchvision.ops.boxes import nms as nms_2d
from nndet._C import nms as nms_gpu
try:
from nndet._C import nms as nms_gpu
except ImportError:
logger.warning("nnDetection was not build with GPU support!")
nms_gpu = None
from nndet.core.boxes.ops import box_iou
......
......@@ -35,6 +35,7 @@ class DetectionEvaluator(AbstractEvaluator):
def __init__(self,
metrics: Sequence[DetectionMetric],
iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] = box_iou_np,
match_fn: Callable = matching_batch,
max_detections: int = 100,
):
"""
......@@ -46,6 +47,7 @@ class DetectionEvaluator(AbstractEvaluator):
max_detections (int): number of maximum detections per image (reduces computation)
"""
self.iou_fn = iou_fn
self.match_fn = match_fn
self.max_detections = max_detections
self.metrics = metrics
self.results_list = [] # store results of each image
......@@ -99,7 +101,7 @@ class DetectionEvaluator(AbstractEvaluator):
n = [0 if gt_boxes_img.size == 0 else gt_boxes_img.shape[0] for gt_boxes_img in gt_boxes]
gt_ignore = [np.zeros(_n).reshape(-1) for _n in n]
self.results_list.extend(matching_batch(
self.results_list.extend(self.match_fn(
self.iou_fn, self.iou_thresholds, pred_boxes=pred_boxes, pred_classes=pred_classes,
pred_scores=pred_scores, gt_boxes=gt_boxes, gt_classes=gt_classes, gt_ignore=gt_ignore,
max_detections=self.max_detections))
......
......@@ -22,7 +22,6 @@ from typing import Sequence, Optional
import torch
from loguru import logger
from nndet.ptmodule import MODULE_REGISTRY
from nndet.io.paths import Pathlike
......@@ -80,6 +79,8 @@ def load_final_model(
`model`: loaded model
`rank`: rank is always 0
"""
from nndet.ptmodule import MODULE_REGISTRY
assert num_models == 1, f"load_final_model only supports num_models=1, found {num_models}"
logger.info(f"Loading {identifier} model")
......@@ -123,6 +124,8 @@ def load_all_models(
`model`: loaded model
`rank`: rank of model
"""
from nndet.ptmodule import MODULE_REGISTRY
model_names = list(source_models.glob('*.ckpt'))
if not model_names:
raise RuntimeError(f"Did not find any models in {source_models}")
......
......@@ -57,8 +57,13 @@ class BaseModule(pl.LightningDataModule):
self.fold = fold
self.preprocessed_dir = self.data_dir.parent.parent
self.splits_file = self.augment_cfg.get(
"splits_final", "splits_final.pkl")
if "splits" in self.augment_cfg:
self.splits_file = self.augment_cfg["splits"]
elif "splits_final" in self.augment_cfg:
self.splits_file = self.augment_cfg["splits_final"]
else:
self.splits_file = "splits_final"
self.dataset_tr = {}
self.dataset_val = {}
......
......@@ -171,7 +171,11 @@ def get_case_id_from_file(file_name: str, remove_modality: bool = True) -> str:
Returns:
str: name of file without ending
"""
file_name = file_name.split('.')[0]
if file_name.endswith(".nii.gz"):
file_name = file_name.rsplit(".", 2)[0]
else:
file_name = file_name.rsplit(".", 1)[0]
if remove_modality:
file_name = file_name[:-5]
return file_name
......
......@@ -18,6 +18,7 @@ import os
import sys
import socket
import argparse
import importlib
from pathlib import Path
from datetime import datetime
from typing import List
......@@ -347,6 +348,10 @@ def _sweep(
cfg = OmegaConf.load(str(train_dir / "config.yaml"))
os.chdir(str(train_dir))
for imp in cfg.get("additional_imports", []):
print(f"Additional import found {imp}")
importlib.import_module(imp)
logger.remove()
logger.add(sys.stdout, format="{level} {message}", level="INFO")
log_file = Path(os.getcwd()) / "sweep.log"
......
......@@ -186,6 +186,19 @@ def unpack():
unpack_dataset(p, num_processes, False)
def hydra_searchpath():
from hydra import compose as hydra_compose
from hydra import initialize_config_module
initialize_config_module(config_module="nndet.conf")
cfg = hydra_compose("config.yaml", return_hydra_config=True)
print("Found config sources::")
print("----------------------")
for s in cfg.hydra.runtime.config_sources:
print(s)
def env():
import os
import torch
......
......@@ -129,6 +129,7 @@ setup(
'nndet_seg2nii = scripts.utils:seg2nii',
'nndet_unpack = scripts.utils:unpack',
'nndet_env = scripts.utils:env',
'nndet_searchpath = scripts.utils:hydra_searchpath'
]
},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment