Commit 0cdad98d authored by mibaumgartner's avatar mibaumgartner
Browse files

utils

parent 9853d3e4
...@@ -35,6 +35,7 @@ class DetectionEvaluator(AbstractEvaluator): ...@@ -35,6 +35,7 @@ class DetectionEvaluator(AbstractEvaluator):
def __init__(self, def __init__(self,
metrics: Sequence[DetectionMetric], metrics: Sequence[DetectionMetric],
iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] = box_iou_np, iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] = box_iou_np,
match_fn: Callable = matching_batch,
max_detections: int = 100, max_detections: int = 100,
): ):
""" """
...@@ -46,6 +47,7 @@ class DetectionEvaluator(AbstractEvaluator): ...@@ -46,6 +47,7 @@ class DetectionEvaluator(AbstractEvaluator):
max_detections (int): number of maximum detections per image (reduces computation) max_detections (int): number of maximum detections per image (reduces computation)
""" """
self.iou_fn = iou_fn self.iou_fn = iou_fn
self.match_fn = match_fn
self.max_detections = max_detections self.max_detections = max_detections
self.metrics = metrics self.metrics = metrics
self.results_list = [] # store results of each image self.results_list = [] # store results of each image
...@@ -99,7 +101,7 @@ class DetectionEvaluator(AbstractEvaluator): ...@@ -99,7 +101,7 @@ class DetectionEvaluator(AbstractEvaluator):
n = [0 if gt_boxes_img.size == 0 else gt_boxes_img.shape[0] for gt_boxes_img in gt_boxes] n = [0 if gt_boxes_img.size == 0 else gt_boxes_img.shape[0] for gt_boxes_img in gt_boxes]
gt_ignore = [np.zeros(_n).reshape(-1) for _n in n] gt_ignore = [np.zeros(_n).reshape(-1) for _n in n]
self.results_list.extend(matching_batch( self.results_list.extend(self.match_fn(
self.iou_fn, self.iou_thresholds, pred_boxes=pred_boxes, pred_classes=pred_classes, self.iou_fn, self.iou_thresholds, pred_boxes=pred_boxes, pred_classes=pred_classes,
pred_scores=pred_scores, gt_boxes=gt_boxes, gt_classes=gt_classes, gt_ignore=gt_ignore, pred_scores=pred_scores, gt_boxes=gt_boxes, gt_classes=gt_classes, gt_ignore=gt_ignore,
max_detections=self.max_detections)) max_detections=self.max_detections))
......
...@@ -22,7 +22,6 @@ from typing import Sequence, Optional ...@@ -22,7 +22,6 @@ from typing import Sequence, Optional
import torch import torch
from loguru import logger from loguru import logger
from nndet.ptmodule import MODULE_REGISTRY
from nndet.io.paths import Pathlike from nndet.io.paths import Pathlike
...@@ -80,6 +79,8 @@ def load_final_model( ...@@ -80,6 +79,8 @@ def load_final_model(
`model`: loaded model `model`: loaded model
`rank`: rank is always 0 `rank`: rank is always 0
""" """
from nndet.ptmodule import MODULE_REGISTRY
assert num_models == 1, f"load_final_model only supports num_models=1, found {num_models}" assert num_models == 1, f"load_final_model only supports num_models=1, found {num_models}"
logger.info(f"Loading {identifier} model") logger.info(f"Loading {identifier} model")
...@@ -123,6 +124,8 @@ def load_all_models( ...@@ -123,6 +124,8 @@ def load_all_models(
`model`: loaded model `model`: loaded model
`rank`: rank of model `rank`: rank of model
""" """
from nndet.ptmodule import MODULE_REGISTRY
model_names = list(source_models.glob('*.ckpt')) model_names = list(source_models.glob('*.ckpt'))
if not model_names: if not model_names:
raise RuntimeError(f"Did not find any models in {source_models}") raise RuntimeError(f"Did not find any models in {source_models}")
......
...@@ -18,6 +18,7 @@ import os ...@@ -18,6 +18,7 @@ import os
import sys import sys
import socket import socket
import argparse import argparse
import importlib
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
from typing import List from typing import List
...@@ -347,6 +348,10 @@ def _sweep( ...@@ -347,6 +348,10 @@ def _sweep(
cfg = OmegaConf.load(str(train_dir / "config.yaml")) cfg = OmegaConf.load(str(train_dir / "config.yaml"))
os.chdir(str(train_dir)) os.chdir(str(train_dir))
for imp in cfg.get("additional_imports", []):
print(f"Additional import found {imp}")
importlib.import_module(imp)
logger.remove() logger.remove()
logger.add(sys.stdout, format="{level} {message}", level="INFO") logger.add(sys.stdout, format="{level} {message}", level="INFO")
log_file = Path(os.getcwd()) / "sweep.log" log_file = Path(os.getcwd()) / "sweep.log"
......
...@@ -186,6 +186,19 @@ def unpack(): ...@@ -186,6 +186,19 @@ def unpack():
unpack_dataset(p, num_processes, False) unpack_dataset(p, num_processes, False)
def hydra_searchpath():
from hydra import compose as hydra_compose
from hydra import initialize_config_module
initialize_config_module(config_module="nndet.conf")
cfg = hydra_compose("config.yaml", return_hydra_config=True)
print("Found config sources::")
print("----------------------")
for s in cfg.hydra.runtime.config_sources:
print(s)
def env(): def env():
import os import os
import torch import torch
......
...@@ -129,6 +129,7 @@ setup( ...@@ -129,6 +129,7 @@ setup(
'nndet_seg2nii = scripts.utils:seg2nii', 'nndet_seg2nii = scripts.utils:seg2nii',
'nndet_unpack = scripts.utils:unpack', 'nndet_unpack = scripts.utils:unpack',
'nndet_env = scripts.utils:env', 'nndet_env = scripts.utils:env',
'nndet_searchpath = scripts.utils:hydra_searchpath'
] ]
}, },
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment