Commit 0cdad98d authored by mibaumgartner's avatar mibaumgartner
Browse files

utils

parent 9853d3e4
......@@ -35,6 +35,7 @@ class DetectionEvaluator(AbstractEvaluator):
def __init__(self,
metrics: Sequence[DetectionMetric],
iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] = box_iou_np,
match_fn: Callable = matching_batch,
max_detections: int = 100,
):
"""
......@@ -46,6 +47,7 @@ class DetectionEvaluator(AbstractEvaluator):
max_detections (int): number of maximum detections per image (reduces computation)
"""
self.iou_fn = iou_fn
self.match_fn = match_fn
self.max_detections = max_detections
self.metrics = metrics
self.results_list = [] # store results of each image
......@@ -99,7 +101,7 @@ class DetectionEvaluator(AbstractEvaluator):
n = [0 if gt_boxes_img.size == 0 else gt_boxes_img.shape[0] for gt_boxes_img in gt_boxes]
gt_ignore = [np.zeros(_n).reshape(-1) for _n in n]
self.results_list.extend(matching_batch(
self.results_list.extend(self.match_fn(
self.iou_fn, self.iou_thresholds, pred_boxes=pred_boxes, pred_classes=pred_classes,
pred_scores=pred_scores, gt_boxes=gt_boxes, gt_classes=gt_classes, gt_ignore=gt_ignore,
max_detections=self.max_detections))
......
......@@ -22,7 +22,6 @@ from typing import Sequence, Optional
import torch
from loguru import logger
from nndet.ptmodule import MODULE_REGISTRY
from nndet.io.paths import Pathlike
......@@ -80,6 +79,8 @@ def load_final_model(
`model`: loaded model
`rank`: rank is always 0
"""
from nndet.ptmodule import MODULE_REGISTRY
assert num_models == 1, f"load_final_model only supports num_models=1, found {num_models}"
logger.info(f"Loading {identifier} model")
......@@ -123,6 +124,8 @@ def load_all_models(
`model`: loaded model
`rank`: rank of model
"""
from nndet.ptmodule import MODULE_REGISTRY
model_names = list(source_models.glob('*.ckpt'))
if not model_names:
raise RuntimeError(f"Did not find any models in {source_models}")
......
......@@ -18,6 +18,7 @@ import os
import sys
import socket
import argparse
import importlib
from pathlib import Path
from datetime import datetime
from typing import List
......@@ -347,6 +348,10 @@ def _sweep(
cfg = OmegaConf.load(str(train_dir / "config.yaml"))
os.chdir(str(train_dir))
for imp in cfg.get("additional_imports", []):
print(f"Additional import found {imp}")
importlib.import_module(imp)
logger.remove()
logger.add(sys.stdout, format="{level} {message}", level="INFO")
log_file = Path(os.getcwd()) / "sweep.log"
......
......@@ -186,6 +186,19 @@ def unpack():
unpack_dataset(p, num_processes, False)
def hydra_searchpath():
from hydra import compose as hydra_compose
from hydra import initialize_config_module
initialize_config_module(config_module="nndet.conf")
cfg = hydra_compose("config.yaml", return_hydra_config=True)
print("Found config sources::")
print("----------------------")
for s in cfg.hydra.runtime.config_sources:
print(s)
def env():
import os
import torch
......
......@@ -129,6 +129,7 @@ setup(
'nndet_seg2nii = scripts.utils:seg2nii',
'nndet_unpack = scripts.utils:unpack',
'nndet_env = scripts.utils:env',
'nndet_searchpath = scripts.utils:hydra_searchpath'
]
},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment