Commit 3446e932 authored by Andres Martinez Mora's avatar Andres Martinez Mora
Browse files

Remove model inference summary printout

parent e708c342
...@@ -68,7 +68,7 @@ from nndet.io.transforms import ( ...@@ -68,7 +68,7 @@ from nndet.io.transforms import (
Instances2Boxes, Instances2Boxes,
Instances2Segmentation, Instances2Segmentation,
FindInstances, FindInstances,
) )
class RetinaUNetModule(LightningBaseModuleSWA): class RetinaUNetModule(LightningBaseModuleSWA):
...@@ -84,15 +84,10 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -84,15 +84,10 @@ class RetinaUNetModule(LightningBaseModuleSWA):
head_sampler_cls = HardNegativeSamplerBatched head_sampler_cls = HardNegativeSamplerBatched
segmenter_cls = DiCESegmenter segmenter_cls = DiCESegmenter
def __init__(self, def __init__(self, model_cfg: dict, trainer_cfg: dict, plan: dict, **kwargs):
model_cfg: dict,
trainer_cfg: dict,
plan: dict,
**kwargs
):
""" """
RetinaUNet Lightning Module Skeleton RetinaUNet Lightning Module Skeleton
Args: Args:
model_cfg: model configuration. Check :method:`from_config_plan` model_cfg: model configuration. Check :method:`from_config_plan`
for more information for more information
...@@ -106,32 +101,34 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -106,32 +101,34 @@ class RetinaUNetModule(LightningBaseModuleSWA):
plan=plan, plan=plan,
) )
_classes = [f"class{c}" for c in range(plan["architecture"]["classifier_classes"])] _classes = [
f"class{c}" for c in range(plan["architecture"]["classifier_classes"])
]
self.box_evaluator = BoxEvaluator.create( self.box_evaluator = BoxEvaluator.create(
classes=_classes, classes=_classes,
fast=True, fast=True,
save_dir=None, save_dir=None,
) )
self.seg_evaluator = SegmentationEvaluator.create() self.seg_evaluator = SegmentationEvaluator.create()
self.pre_trafo = Compose( self.pre_trafo = Compose(
FindInstances( FindInstances(
instance_key="target", instance_key="target",
save_key="present_instances", save_key="present_instances",
), ),
Instances2Boxes( Instances2Boxes(
instance_key="target", instance_key="target",
map_key="instance_mapping", map_key="instance_mapping",
box_key="boxes", box_key="boxes",
class_key="classes", class_key="classes",
present_instances="present_instances", present_instances="present_instances",
), ),
Instances2Segmentation( Instances2Segmentation(
instance_key="target", instance_key="target",
map_key="instance_mapping", map_key="instance_mapping",
present_instances="present_instances", present_instances="present_instances",
) ),
) )
self.eval_score_key = "mAP_IoU_0.10_0.50_0.05_MaxDet_100" self.eval_score_key = "mAP_IoU_0.10_0.50_0.05_MaxDet_100"
...@@ -148,8 +145,8 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -148,8 +145,8 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets={ targets={
"target_boxes": batch["boxes"], "target_boxes": batch["boxes"],
"target_classes": batch["classes"], "target_classes": batch["classes"],
"target_seg": batch['target'][:, 0] # Remove channel dimension "target_seg": batch["target"][:, 0], # Remove channel dimension
}, },
evaluation=False, evaluation=False,
batch_num=batch_idx, batch_num=batch_idx,
) )
...@@ -165,10 +162,10 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -165,10 +162,10 @@ class RetinaUNetModule(LightningBaseModuleSWA):
with torch.no_grad(): with torch.no_grad():
batch = self.pre_trafo(**batch) batch = self.pre_trafo(**batch)
targets = { targets = {
"target_boxes": batch["boxes"], "target_boxes": batch["boxes"],
"target_classes": batch["classes"], "target_classes": batch["classes"],
"target_seg": batch['target'][:, 0] # Remove channel dimension "target_seg": batch["target"][:, 0], # Remove channel dimension
} }
losses, prediction = self.model.train_step( losses, prediction = self.model.train_step(
images=batch["data"], images=batch["data"],
targets=targets, targets=targets,
...@@ -178,8 +175,10 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -178,8 +175,10 @@ class RetinaUNetModule(LightningBaseModuleSWA):
loss = sum(losses.values()) loss = sum(losses.values())
self.evaluation_step(prediction=prediction, targets=targets) self.evaluation_step(prediction=prediction, targets=targets)
return {"loss": loss.detach().item(), return {
**{key: l.detach().item() for key, l in losses.items()}} "loss": loss.detach().item(),
**{key: l.detach().item() for key, l in losses.items()},
}
def evaluation_step( def evaluation_step(
self, self,
...@@ -223,7 +222,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -223,7 +222,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
gt_boxes=gt_boxes, gt_boxes=gt_boxes,
gt_classes=gt_classes, gt_classes=gt_classes,
gt_ignore=gt_ignore, gt_ignore=gt_ignore,
) )
pred_seg = to_numpy(prediction["pred_seg"]) pred_seg = to_numpy(prediction["pred_seg"])
gt_seg = to_numpy(targets["target_seg"]) gt_seg = to_numpy(targets["target_seg"])
...@@ -231,7 +230,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -231,7 +230,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
self.seg_evaluator.run_online_evaluation( self.seg_evaluator.run_online_evaluation(
seg_probs=pred_seg, seg_probs=pred_seg,
target=gt_seg, target=gt_seg,
) )
def training_epoch_end(self, training_step_outputs): def training_epoch_end(self, training_step_outputs):
""" """
...@@ -281,9 +280,11 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -281,9 +280,11 @@ class RetinaUNetModule(LightningBaseModuleSWA):
metric_scores, _ = self.box_evaluator.finish_online_evaluation() metric_scores, _ = self.box_evaluator.finish_online_evaluation()
self.box_evaluator.reset() self.box_evaluator.reset()
logger.info(f"mAP@0.1:0.5:0.05: {metric_scores['mAP_IoU_0.10_0.50_0.05_MaxDet_100']:0.3f} " logger.info(
f"AP@0.1: {metric_scores['AP_IoU_0.10_MaxDet_100']:0.3f} " f"mAP@0.1:0.5:0.05: {metric_scores['mAP_IoU_0.10_0.50_0.05_MaxDet_100']:0.3f} "
f"AP@0.5: {metric_scores['AP_IoU_0.50_MaxDet_100']:0.3f}") f"AP@0.1: {metric_scores['AP_IoU_0.10_MaxDet_100']:0.3f} "
f"AP@0.5: {metric_scores['AP_IoU_0.50_MaxDet_100']:0.3f}"
)
seg_scores, _ = self.seg_evaluator.finish_online_evaluation() seg_scores, _ = self.seg_evaluator.finish_online_evaluation()
self.seg_evaluator.reset() self.seg_evaluator.reset()
...@@ -292,7 +293,9 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -292,7 +293,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
logger.info(f"Proxy FG Dice: {seg_scores['seg_dice']:0.3f}") logger.info(f"Proxy FG Dice: {seg_scores['seg_dice']:0.3f}")
for key, item in metric_scores.items(): for key, item in metric_scores.items():
self.log(f'{key}', item, on_step=None, on_epoch=True, prog_bar=False, logger=True) self.log(
f"{key}", item, on_step=None, on_epoch=True, prog_bar=False, logger=True
)
def configure_optimizers(self): def configure_optimizers(self):
""" """
...@@ -301,39 +304,46 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -301,39 +304,46 @@ class RetinaUNetModule(LightningBaseModuleSWA):
schedule schedule
""" """
# configure optimizer # configure optimizer
logger.info(f"Running: initial_lr {self.trainer_cfg['initial_lr']} " logger.info(
f"weight_decay {self.trainer_cfg['weight_decay']} " f"Running: initial_lr {self.trainer_cfg['initial_lr']} "
f"SGD with momentum {self.trainer_cfg['sgd_momentum']} and " f"weight_decay {self.trainer_cfg['weight_decay']} "
f"nesterov {self.trainer_cfg['sgd_nesterov']}") f"SGD with momentum {self.trainer_cfg['sgd_momentum']} and "
wd_groups = get_params_no_wd_on_norm(self, weight_decay=self.trainer_cfg['weight_decay']) f"nesterov {self.trainer_cfg['sgd_nesterov']}"
)
wd_groups = get_params_no_wd_on_norm(
self, weight_decay=self.trainer_cfg["weight_decay"]
)
optimizer = torch.optim.SGD( optimizer = torch.optim.SGD(
wd_groups, wd_groups,
self.trainer_cfg["initial_lr"], self.trainer_cfg["initial_lr"],
weight_decay=self.trainer_cfg["weight_decay"], weight_decay=self.trainer_cfg["weight_decay"],
momentum=self.trainer_cfg["sgd_momentum"], momentum=self.trainer_cfg["sgd_momentum"],
nesterov=self.trainer_cfg["sgd_nesterov"], nesterov=self.trainer_cfg["sgd_nesterov"],
) )
# configure lr scheduler # configure lr scheduler
num_iterations = self.trainer_cfg["max_num_epochs"] * \ num_iterations = (
self.trainer_cfg["num_train_batches_per_epoch"] self.trainer_cfg["max_num_epochs"]
* self.trainer_cfg["num_train_batches_per_epoch"]
)
scheduler = LinearWarmupPolyLR( scheduler = LinearWarmupPolyLR(
optimizer=optimizer, optimizer=optimizer,
warm_iterations=self.trainer_cfg["warm_iterations"], warm_iterations=self.trainer_cfg["warm_iterations"],
warm_lr=self.trainer_cfg["warm_lr"], warm_lr=self.trainer_cfg["warm_lr"],
poly_gamma=self.trainer_cfg["poly_gamma"], poly_gamma=self.trainer_cfg["poly_gamma"],
num_iterations=num_iterations num_iterations=num_iterations,
) )
return [optimizer], {'scheduler': scheduler, 'interval': 'step'} return [optimizer], {"scheduler": scheduler, "interval": "step"}
@classmethod @classmethod
def from_config_plan(cls, def from_config_plan(
model_cfg: dict, cls,
plan_arch: dict, model_cfg: dict,
plan_anchors: dict, plan_arch: dict,
log_num_anchors: str = None, plan_anchors: dict,
**kwargs, log_num_anchors: str = None,
): **kwargs,
):
""" """
Create Configurable RetinaUNet Create Configurable RetinaUNet
...@@ -359,35 +369,46 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -359,35 +369,46 @@ class RetinaUNetModule(LightningBaseModuleSWA):
will be performed will be performed
**kwargs: **kwargs:
""" """
logger.info(f"Architecture overwrites: {model_cfg['plan_arch_overwrites']} " logger.info(
f"Anchor overwrites: {model_cfg['plan_anchors_overwrites']}") f"Architecture overwrites: {model_cfg['plan_arch_overwrites']} "
logger.info(f"Building architecture according to plan of {plan_arch.get('arch_name', 'not_found')}") f"Anchor overwrites: {model_cfg['plan_anchors_overwrites']}"
)
logger.info(
f"Building architecture according to plan of {plan_arch.get('arch_name', 'not_found')}"
)
plan_arch.update(model_cfg["plan_arch_overwrites"]) plan_arch.update(model_cfg["plan_arch_overwrites"])
plan_anchors.update(model_cfg["plan_anchors_overwrites"]) plan_anchors.update(model_cfg["plan_anchors_overwrites"])
logger.info(f"Start channels: {plan_arch['start_channels']}; " logger.info(
f"head channels: {plan_arch['head_channels']}; " f"Start channels: {plan_arch['start_channels']}; "
f"fpn channels: {plan_arch['fpn_channels']}") f"head channels: {plan_arch['head_channels']}; "
f"fpn channels: {plan_arch['fpn_channels']}"
)
_plan_anchors = copy.deepcopy(plan_anchors) _plan_anchors = copy.deepcopy(plan_anchors)
coder = BoxCoderND(weights=(1.,) * (plan_arch["dim"] * 2)) coder = BoxCoderND(weights=(1.0,) * (plan_arch["dim"] * 2))
s_param = False if ("aspect_ratios" in _plan_anchors) and \ s_param = (
(_plan_anchors["aspect_ratios"] is not None) else True False
anchor_generator = get_anchor_generator( if ("aspect_ratios" in _plan_anchors)
plan_arch["dim"], s_param=s_param)(**_plan_anchors) and (_plan_anchors["aspect_ratios"] is not None)
else True
)
anchor_generator = get_anchor_generator(plan_arch["dim"], s_param=s_param)(
**_plan_anchors
)
encoder = cls._build_encoder( encoder = cls._build_encoder(
plan_arch=plan_arch, plan_arch=plan_arch,
model_cfg=model_cfg, model_cfg=model_cfg,
) )
decoder = cls._build_decoder( decoder = cls._build_decoder(
encoder=encoder, encoder=encoder,
plan_arch=plan_arch, plan_arch=plan_arch,
model_cfg=model_cfg, model_cfg=model_cfg,
) )
matcher = cls.matcher_cls( matcher = cls.matcher_cls(
similarity_fn=box_iou, similarity_fn=box_iou,
**model_cfg["matcher_kwargs"], **model_cfg["matcher_kwargs"],
) )
classifier = cls._build_head_classifier( classifier = cls._build_head_classifier(
plan_arch=plan_arch, plan_arch=plan_arch,
...@@ -404,7 +425,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -404,7 +425,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
model_cfg=model_cfg, model_cfg=model_cfg,
classifier=classifier, classifier=classifier,
regressor=regressor, regressor=regressor,
coder=coder coder=coder,
) )
segmenter = cls._build_segmenter( segmenter = cls._build_segmenter(
plan_arch=plan_arch, plan_arch=plan_arch,
...@@ -418,13 +439,13 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -418,13 +439,13 @@ class RetinaUNetModule(LightningBaseModuleSWA):
remove_small_boxes = plan_arch.get("remove_small_boxes", 0.01) remove_small_boxes = plan_arch.get("remove_small_boxes", 0.01)
nms_thresh = plan_arch.get("nms_thresh", 0.6) nms_thresh = plan_arch.get("nms_thresh", 0.6)
logger.info(f"Model Inference Summary: \n" # logger.info(f"Model Inference Summary: \n"
f"detections_per_img: {detections_per_img} \n" # f"detections_per_img: {detections_per_img} \n"
f"score_thresh: {score_thresh} \n" # f"score_thresh: {score_thresh} \n"
f"topk_candidates: {topk_candidates} \n" # f"topk_candidates: {topk_candidates} \n"
f"remove_small_boxes: {remove_small_boxes} \n" # f"remove_small_boxes: {remove_small_boxes} \n"
f"nms_thresh: {nms_thresh}", # f"nms_thresh: {nms_thresh}",
) # )
return BaseRetinaNet( return BaseRetinaNet(
dim=plan_arch["dim"], dim=plan_arch["dim"],
...@@ -461,7 +482,9 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -461,7 +482,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
EncoderType: encoder instance EncoderType: encoder instance
""" """
conv = Generator(cls.base_conv_cls, plan_arch["dim"]) conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: encoder {cls.encoder_cls.__name__}: {model_cfg['encoder_kwargs']} ") logger.info(
f"Building:: encoder {cls.encoder_cls.__name__}: {model_cfg['encoder_kwargs']} "
)
encoder = cls.encoder_cls( encoder = cls.encoder_cls(
conv=conv, conv=conv,
conv_kernels=plan_arch["conv_kernels"], conv_kernels=plan_arch["conv_kernels"],
...@@ -471,7 +494,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -471,7 +494,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
start_channels=plan_arch["start_channels"], start_channels=plan_arch["start_channels"],
stage_kwargs=None, stage_kwargs=None,
max_channels=plan_arch.get("max_channels", 320), max_channels=plan_arch.get("max_channels", 320),
**model_cfg['encoder_kwargs'], **model_cfg["encoder_kwargs"],
) )
return encoder return encoder
...@@ -493,7 +516,9 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -493,7 +516,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
DecoderType: decoder instance DecoderType: decoder instance
""" """
conv = Generator(cls.base_conv_cls, plan_arch["dim"]) conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: decoder {cls.decoder_cls.__name__}: {model_cfg['decoder_kwargs']}") logger.info(
f"Building:: decoder {cls.decoder_cls.__name__}: {model_cfg['decoder_kwargs']}"
)
decoder = cls.decoder_cls( decoder = cls.decoder_cls(
conv=conv, conv=conv,
conv_kernels=plan_arch["conv_kernels"], conv_kernels=plan_arch["conv_kernels"],
...@@ -501,7 +526,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -501,7 +526,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
in_channels=encoder.get_channels(), in_channels=encoder.get_channels(),
decoder_levels=plan_arch["decoder_levels"], decoder_levels=plan_arch["decoder_levels"],
fixed_out_channels=plan_arch["fpn_channels"], fixed_out_channels=plan_arch["fpn_channels"],
**model_cfg['decoder_kwargs'], **model_cfg["decoder_kwargs"],
) )
return decoder return decoder
...@@ -525,7 +550,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -525,7 +550,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
""" """
conv = Generator(cls.head_conv_cls, plan_arch["dim"]) conv = Generator(cls.head_conv_cls, plan_arch["dim"])
name = cls.head_classifier_cls.__name__ name = cls.head_classifier_cls.__name__
kwargs = model_cfg['head_classifier_kwargs'] kwargs = model_cfg["head_classifier_kwargs"]
logger.info(f"Building:: classifier {name}: {kwargs}") logger.info(f"Building:: classifier {name}: {kwargs}")
classifier = cls.head_classifier_cls( classifier = cls.head_classifier_cls(
...@@ -559,7 +584,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -559,7 +584,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
""" """
conv = Generator(cls.head_conv_cls, plan_arch["dim"]) conv = Generator(cls.head_conv_cls, plan_arch["dim"])
name = cls.head_regressor_cls.__name__ name = cls.head_regressor_cls.__name__
kwargs = model_cfg['head_regressor_kwargs'] kwargs = model_cfg["head_regressor_kwargs"]
logger.info(f"Building:: regressor {name}: {kwargs}") logger.info(f"Building:: regressor {name}: {kwargs}")
regressor = cls.head_regressor_cls( regressor = cls.head_regressor_cls(
...@@ -595,12 +620,14 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -595,12 +620,14 @@ class RetinaUNetModule(LightningBaseModuleSWA):
HeadType: instantiated head HeadType: instantiated head
""" """
head_name = cls.head_cls.__name__ head_name = cls.head_cls.__name__
head_kwargs = model_cfg['head_kwargs'] head_kwargs = model_cfg["head_kwargs"]
sampler_name = cls.head_sampler_cls.__name__ sampler_name = cls.head_sampler_cls.__name__
sampler_kwargs = model_cfg['head_sampler_kwargs'] sampler_kwargs = model_cfg["head_sampler_kwargs"]
logger.info(f"Building:: head {head_name}: {head_kwargs} " logger.info(
f"sampler {sampler_name}: {sampler_kwargs}") f"Building:: head {head_name}: {head_kwargs} "
f"sampler {sampler_name}: {sampler_kwargs}"
)
sampler = cls.head_sampler_cls(**sampler_kwargs) sampler = cls.head_sampler_cls(**sampler_kwargs)
head = cls.head_cls( head = cls.head_cls(
classifier=classifier, classifier=classifier,
...@@ -632,7 +659,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -632,7 +659,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
""" """
if cls.segmenter_cls is not None: if cls.segmenter_cls is not None:
name = cls.segmenter_cls.__name__ name = cls.segmenter_cls.__name__
kwargs = model_cfg['segmenter_kwargs'] kwargs = model_cfg["segmenter_kwargs"]
conv = Generator(cls.base_conv_cls, plan_arch["dim"]) conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: segmenter {name} {kwargs}") logger.info(f"Building:: segmenter {name} {kwargs}")
...@@ -661,20 +688,21 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -661,20 +688,21 @@ class RetinaUNetModule(LightningBaseModuleSWA):
3: { 3: {
"boxes": BoxEnsemblerSelective, "boxes": BoxEnsemblerSelective,
"seg": SegmentationEnsembler, "seg": SegmentationEnsembler,
} },
} }
if dim == 2: if dim == 2:
raise NotImplementedError raise NotImplementedError
return _lookup[dim][key] return _lookup[dim][key]
@classmethod @classmethod
def get_predictor(cls, def get_predictor(
plan: Dict, cls,
models: Sequence[RetinaUNetModule], plan: Dict,
num_tta_transforms: int = None, models: Sequence[RetinaUNetModule],
do_seg: bool = False, num_tta_transforms: int = None,
**kwargs, do_seg: bool = False,
) -> Predictor: **kwargs,
) -> Predictor:
# process plan # process plan
crop_size = plan["patch_size"] crop_size = plan["patch_size"]
batch_size = plan["batch_size"] batch_size = plan["batch_size"]
...@@ -684,14 +712,19 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -684,14 +712,19 @@ class RetinaUNetModule(LightningBaseModuleSWA):
num_tta_transforms = 8 if plan["network_dim"] == 3 else 4 num_tta_transforms = 8 if plan["network_dim"] == 3 else 4
# setup # setup
tta_transforms, tta_inverse_transforms = \ tta_transforms, tta_inverse_transforms = get_tta_transforms(
get_tta_transforms(num_tta_transforms, True) num_tta_transforms, True
logger.info(f"Using {len(tta_transforms)} tta transformations for prediction (one dummy trafo).") )
logger.info(
ensembler = {"boxes": partial( f"Using {len(tta_transforms)} tta transformations for prediction (one dummy trafo)."
cls.get_ensembler_cls(key="boxes", dim=plan["network_dim"]).from_case, )
parameters=inferene_plan,
)} ensembler = {
"boxes": partial(
cls.get_ensembler_cls(key="boxes", dim=plan["network_dim"]).from_case,
parameters=inferene_plan,
)
}
if do_seg: if do_seg:
ensembler["seg"] = partial( ensembler["seg"] = partial(
cls.get_ensembler_cls(key="seg", dim=plan["network_dim"]).from_case, cls.get_ensembler_cls(key="seg", dim=plan["network_dim"]).from_case,
...@@ -705,20 +738,21 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -705,20 +738,21 @@ class RetinaUNetModule(LightningBaseModuleSWA):
tta_inverse_transforms=tta_inverse_transforms, tta_inverse_transforms=tta_inverse_transforms,
batch_size=batch_size, batch_size=batch_size,
**kwargs, **kwargs,
) )
if plan["network_dim"] == 2: if plan["network_dim"] == 2:
raise NotImplementedError raise NotImplementedError
predictor.pre_transform = Inference2D(["data"]) predictor.pre_transform = Inference2D(["data"])
return predictor return predictor
def sweep(self, def sweep(
cfg: dict, self,
save_dir: os.PathLike, cfg: dict,
train_data_dir: os.PathLike, save_dir: os.PathLike,
case_ids: Sequence[str], train_data_dir: os.PathLike,
run_prediction: bool = True, case_ids: Sequence[str],
**kwargs, run_prediction: bool = True,
) -> Dict[str, Any]: **kwargs,
) -> Dict[str, Any]:
""" """
Sweep detection parameters to find the best predictions Sweep detection parameters to find the best predictions
...@@ -764,10 +798,12 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -764,10 +798,12 @@ class RetinaUNetModule(LightningBaseModuleSWA):
save_state=True, save_state=True,
model_fn=get_loader_fn(mode=self.trainer_cfg.get("sweep_ckpt", "last")), model_fn=get_loader_fn(mode=self.trainer_cfg.get("sweep_ckpt", "last")),
**kwargs, **kwargs,
) )
logger.info("Start parameter sweep...") logger.info("Start parameter sweep...")
ensembler_cls = self.get_ensembler_cls(key="boxes", dim=self.plan["network_dim"]) ensembler_cls = self.get_ensembler_cls(
key="boxes", dim=self.plan["network_dim"]
)
sweeper = BoxSweeper( sweeper = BoxSweeper(
classes=[item for _, item in cfg["data"]["labels"].items()], classes=[item for _, item in cfg["data"]["labels"].items()],
pred_dir=prediction_dir, pred_dir=prediction_dir,
...@@ -775,6 +811,6 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -775,6 +811,6 @@ class RetinaUNetModule(LightningBaseModuleSWA):
target_metric=self.eval_score_key, target_metric=self.eval_score_key,
ensembler_cls=ensembler_cls, ensembler_cls=ensembler_cls,
save_dir=_save_dir, save_dir=_save_dir,
) )
inference_plan = sweeper.run_postprocessing_sweep() inference_plan = sweeper.run_postprocessing_sweep()
return inference_plan return inference_plan
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment