Commit 8607cb0f authored by Michael Baumgartner
Browse files

Merge remote-tracking branch 'origin/0000_project' into main

parents 1044ace5 ca7e0f11
@@ -15,11 +15,16 @@ limitations under the License.
 """
 import torch
+from loguru import logger
 from torch import Tensor
 from torch.cuda.amp import autocast
 from torchvision.ops.boxes import nms as nms_2d
-from nndet._C import nms as nms_gpu
+try:
+    from nndet._C import nms as nms_gpu
+except ImportError:
+    logger.warning("nnDetection was not build with GPU support!")
+    nms_gpu = None
 from nndet.core.boxes.ops import box_iou
...
@@ -35,6 +35,7 @@ class DetectionEvaluator(AbstractEvaluator):
     def __init__(self,
                  metrics: Sequence[DetectionMetric],
                  iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] = box_iou_np,
+                 match_fn: Callable = matching_batch,
                  max_detections: int = 100,
                  ):
         """
@@ -46,6 +47,7 @@ class DetectionEvaluator(AbstractEvaluator):
             max_detections (int): number of maximum detections per image (reduces computation)
         """
         self.iou_fn = iou_fn
+        self.match_fn = match_fn
         self.max_detections = max_detections
         self.metrics = metrics
         self.results_list = []  # store results of each image
@@ -99,7 +101,7 @@ class DetectionEvaluator(AbstractEvaluator):
         n = [0 if gt_boxes_img.size == 0 else gt_boxes_img.shape[0] for gt_boxes_img in gt_boxes]
         gt_ignore = [np.zeros(_n).reshape(-1) for _n in n]
-        self.results_list.extend(matching_batch(
+        self.results_list.extend(self.match_fn(
             self.iou_fn, self.iou_thresholds, pred_boxes=pred_boxes, pred_classes=pred_classes,
             pred_scores=pred_scores, gt_boxes=gt_boxes, gt_classes=gt_classes, gt_ignore=gt_ignore,
             max_detections=self.max_detections))
...
@@ -22,7 +22,6 @@ from typing import Sequence, Optional
 import torch
 from loguru import logger
-from nndet.ptmodule import MODULE_REGISTRY
 from nndet.io.paths import Pathlike
@@ -80,6 +79,8 @@ def load_final_model(
         `model`: loaded model
         `rank`: rank is always 0
     """
+    from nndet.ptmodule import MODULE_REGISTRY
+
     assert num_models == 1, f"load_final_model only supports num_models=1, found {num_models}"
     logger.info(f"Loading {identifier} model")
@@ -123,6 +124,8 @@ def load_all_models(
         `model`: loaded model
         `rank`: rank of model
     """
+    from nndet.ptmodule import MODULE_REGISTRY
+
     model_names = list(source_models.glob('*.ckpt'))
     if not model_names:
         raise RuntimeError(f"Did not find any models in {source_models}")
...
@@ -57,8 +57,13 @@ class BaseModule(pl.LightningDataModule):
         self.fold = fold
         self.preprocessed_dir = self.data_dir.parent.parent
-        self.splits_file = self.augment_cfg.get(
-            "splits_final", "splits_final.pkl")
+        if "splits" in self.augment_cfg:
+            self.splits_file = self.augment_cfg["splits"]
+        elif "splits_final" in self.augment_cfg:
+            self.splits_file = self.augment_cfg["splits_final"]
+        else:
+            self.splits_file = "splits_final"
         self.dataset_tr = {}
         self.dataset_val = {}
...
@@ -171,7 +171,11 @@ def get_case_id_from_file(file_name: str, remove_modality: bool = True) -> str:
     Returns:
         str: name of file without ending
     """
-    file_name = file_name.split('.')[0]
+    if file_name.endswith(".nii.gz"):
+        file_name = file_name.rsplit(".", 2)[0]
+    else:
+        file_name = file_name.rsplit(".", 1)[0]
     if remove_modality:
         file_name = file_name[:-5]
     return file_name
...
@@ -18,6 +18,7 @@ import os
 import sys
 import socket
 import argparse
+import importlib
 from pathlib import Path
 from datetime import datetime
 from typing import List
@@ -347,6 +348,10 @@ def _sweep(
     cfg = OmegaConf.load(str(train_dir / "config.yaml"))
     os.chdir(str(train_dir))
+    for imp in cfg.get("additional_imports", []):
+        print(f"Additional import found {imp}")
+        importlib.import_module(imp)
+
     logger.remove()
     logger.add(sys.stdout, format="{level} {message}", level="INFO")
     log_file = Path(os.getcwd()) / "sweep.log"
...
@@ -186,6 +186,19 @@ def unpack():
     unpack_dataset(p, num_processes, False)
def hydra_searchpath():
    """Print all config sources on Hydra's search path for ``nndet.conf``.

    Debugging helper exposed as the ``nndet_searchpath`` console script
    (see the ``setup.py`` entry point added in this commit).  It composes
    the package's ``config.yaml`` with ``return_hydra_config=True`` so the
    resolved ``hydra.runtime.config_sources`` list is available, then
    prints each source.
    """
    # hydra is imported lazily so merely importing scripts.utils does not
    # require it to be installed.
    from hydra import compose as hydra_compose
    from hydra import initialize_config_module
    initialize_config_module(config_module="nndet.conf")
    cfg = hydra_compose("config.yaml", return_hydra_config=True)
    print("Found config sources::")
    print("----------------------")
    for s in cfg.hydra.runtime.config_sources:
        print(s)
 def env():
     import os
     import torch
...
@@ -129,6 +129,7 @@ setup(
             'nndet_seg2nii = scripts.utils:seg2nii',
             'nndet_unpack = scripts.utils:unpack',
             'nndet_env = scripts.utils:env',
+            'nndet_searchpath = scripts.utils:hydra_searchpath'
         ]
     },
 )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment