Commit 3446e932 authored by Andres Martinez Mora's avatar Andres Martinez Mora
Browse files

Remove model inference summary printout

parent e708c342
...@@ -68,7 +68,7 @@ from nndet.io.transforms import ( ...@@ -68,7 +68,7 @@ from nndet.io.transforms import (
Instances2Boxes, Instances2Boxes,
Instances2Segmentation, Instances2Segmentation,
FindInstances, FindInstances,
) )
class RetinaUNetModule(LightningBaseModuleSWA): class RetinaUNetModule(LightningBaseModuleSWA):
...@@ -84,12 +84,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -84,12 +84,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
head_sampler_cls = HardNegativeSamplerBatched head_sampler_cls = HardNegativeSamplerBatched
segmenter_cls = DiCESegmenter segmenter_cls = DiCESegmenter
def __init__(self, def __init__(self, model_cfg: dict, trainer_cfg: dict, plan: dict, **kwargs):
model_cfg: dict,
trainer_cfg: dict,
plan: dict,
**kwargs
):
""" """
RetinaUNet Lightning Module Skeleton RetinaUNet Lightning Module Skeleton
...@@ -106,7 +101,9 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -106,7 +101,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
plan=plan, plan=plan,
) )
_classes = [f"class{c}" for c in range(plan["architecture"]["classifier_classes"])] _classes = [
f"class{c}" for c in range(plan["architecture"]["classifier_classes"])
]
self.box_evaluator = BoxEvaluator.create( self.box_evaluator = BoxEvaluator.create(
classes=_classes, classes=_classes,
fast=True, fast=True,
...@@ -130,7 +127,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -130,7 +127,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
instance_key="target", instance_key="target",
map_key="instance_mapping", map_key="instance_mapping",
present_instances="present_instances", present_instances="present_instances",
) ),
) )
self.eval_score_key = "mAP_IoU_0.10_0.50_0.05_MaxDet_100" self.eval_score_key = "mAP_IoU_0.10_0.50_0.05_MaxDet_100"
...@@ -148,7 +145,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -148,7 +145,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets={ targets={
"target_boxes": batch["boxes"], "target_boxes": batch["boxes"],
"target_classes": batch["classes"], "target_classes": batch["classes"],
"target_seg": batch['target'][:, 0] # Remove channel dimension "target_seg": batch["target"][:, 0], # Remove channel dimension
}, },
evaluation=False, evaluation=False,
batch_num=batch_idx, batch_num=batch_idx,
...@@ -167,7 +164,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -167,7 +164,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets = { targets = {
"target_boxes": batch["boxes"], "target_boxes": batch["boxes"],
"target_classes": batch["classes"], "target_classes": batch["classes"],
"target_seg": batch['target'][:, 0] # Remove channel dimension "target_seg": batch["target"][:, 0], # Remove channel dimension
} }
losses, prediction = self.model.train_step( losses, prediction = self.model.train_step(
images=batch["data"], images=batch["data"],
...@@ -178,8 +175,10 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -178,8 +175,10 @@ class RetinaUNetModule(LightningBaseModuleSWA):
loss = sum(losses.values()) loss = sum(losses.values())
self.evaluation_step(prediction=prediction, targets=targets) self.evaluation_step(prediction=prediction, targets=targets)
return {"loss": loss.detach().item(), return {
**{key: l.detach().item() for key, l in losses.items()}} "loss": loss.detach().item(),
**{key: l.detach().item() for key, l in losses.items()},
}
def evaluation_step( def evaluation_step(
self, self,
...@@ -281,9 +280,11 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -281,9 +280,11 @@ class RetinaUNetModule(LightningBaseModuleSWA):
metric_scores, _ = self.box_evaluator.finish_online_evaluation() metric_scores, _ = self.box_evaluator.finish_online_evaluation()
self.box_evaluator.reset() self.box_evaluator.reset()
logger.info(f"mAP@0.1:0.5:0.05: {metric_scores['mAP_IoU_0.10_0.50_0.05_MaxDet_100']:0.3f} " logger.info(
f"mAP@0.1:0.5:0.05: {metric_scores['mAP_IoU_0.10_0.50_0.05_MaxDet_100']:0.3f} "
f"AP@0.1: {metric_scores['AP_IoU_0.10_MaxDet_100']:0.3f} " f"AP@0.1: {metric_scores['AP_IoU_0.10_MaxDet_100']:0.3f} "
f"AP@0.5: {metric_scores['AP_IoU_0.50_MaxDet_100']:0.3f}") f"AP@0.5: {metric_scores['AP_IoU_0.50_MaxDet_100']:0.3f}"
)
seg_scores, _ = self.seg_evaluator.finish_online_evaluation() seg_scores, _ = self.seg_evaluator.finish_online_evaluation()
self.seg_evaluator.reset() self.seg_evaluator.reset()
...@@ -292,7 +293,9 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -292,7 +293,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
logger.info(f"Proxy FG Dice: {seg_scores['seg_dice']:0.3f}") logger.info(f"Proxy FG Dice: {seg_scores['seg_dice']:0.3f}")
for key, item in metric_scores.items(): for key, item in metric_scores.items():
self.log(f'{key}', item, on_step=None, on_epoch=True, prog_bar=False, logger=True) self.log(
f"{key}", item, on_step=None, on_epoch=True, prog_bar=False, logger=True
)
def configure_optimizers(self): def configure_optimizers(self):
""" """
...@@ -301,11 +304,15 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -301,11 +304,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
schedule schedule
""" """
# configure optimizer # configure optimizer
logger.info(f"Running: initial_lr {self.trainer_cfg['initial_lr']} " logger.info(
f"Running: initial_lr {self.trainer_cfg['initial_lr']} "
f"weight_decay {self.trainer_cfg['weight_decay']} " f"weight_decay {self.trainer_cfg['weight_decay']} "
f"SGD with momentum {self.trainer_cfg['sgd_momentum']} and " f"SGD with momentum {self.trainer_cfg['sgd_momentum']} and "
f"nesterov {self.trainer_cfg['sgd_nesterov']}") f"nesterov {self.trainer_cfg['sgd_nesterov']}"
wd_groups = get_params_no_wd_on_norm(self, weight_decay=self.trainer_cfg['weight_decay']) )
wd_groups = get_params_no_wd_on_norm(
self, weight_decay=self.trainer_cfg["weight_decay"]
)
optimizer = torch.optim.SGD( optimizer = torch.optim.SGD(
wd_groups, wd_groups,
self.trainer_cfg["initial_lr"], self.trainer_cfg["initial_lr"],
...@@ -315,19 +322,22 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -315,19 +322,22 @@ class RetinaUNetModule(LightningBaseModuleSWA):
) )
# configure lr scheduler # configure lr scheduler
num_iterations = self.trainer_cfg["max_num_epochs"] * \ num_iterations = (
self.trainer_cfg["num_train_batches_per_epoch"] self.trainer_cfg["max_num_epochs"]
* self.trainer_cfg["num_train_batches_per_epoch"]
)
scheduler = LinearWarmupPolyLR( scheduler = LinearWarmupPolyLR(
optimizer=optimizer, optimizer=optimizer,
warm_iterations=self.trainer_cfg["warm_iterations"], warm_iterations=self.trainer_cfg["warm_iterations"],
warm_lr=self.trainer_cfg["warm_lr"], warm_lr=self.trainer_cfg["warm_lr"],
poly_gamma=self.trainer_cfg["poly_gamma"], poly_gamma=self.trainer_cfg["poly_gamma"],
num_iterations=num_iterations num_iterations=num_iterations,
) )
return [optimizer], {'scheduler': scheduler, 'interval': 'step'} return [optimizer], {"scheduler": scheduler, "interval": "step"}
@classmethod @classmethod
def from_config_plan(cls, def from_config_plan(
cls,
model_cfg: dict, model_cfg: dict,
plan_arch: dict, plan_arch: dict,
plan_anchors: dict, plan_anchors: dict,
...@@ -359,21 +369,32 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -359,21 +369,32 @@ class RetinaUNetModule(LightningBaseModuleSWA):
will be performed will be performed
**kwargs: **kwargs:
""" """
logger.info(f"Architecture overwrites: {model_cfg['plan_arch_overwrites']} " logger.info(
f"Anchor overwrites: {model_cfg['plan_anchors_overwrites']}") f"Architecture overwrites: {model_cfg['plan_arch_overwrites']} "
logger.info(f"Building architecture according to plan of {plan_arch.get('arch_name', 'not_found')}") f"Anchor overwrites: {model_cfg['plan_anchors_overwrites']}"
)
logger.info(
f"Building architecture according to plan of {plan_arch.get('arch_name', 'not_found')}"
)
plan_arch.update(model_cfg["plan_arch_overwrites"]) plan_arch.update(model_cfg["plan_arch_overwrites"])
plan_anchors.update(model_cfg["plan_anchors_overwrites"]) plan_anchors.update(model_cfg["plan_anchors_overwrites"])
logger.info(f"Start channels: {plan_arch['start_channels']}; " logger.info(
f"Start channels: {plan_arch['start_channels']}; "
f"head channels: {plan_arch['head_channels']}; " f"head channels: {plan_arch['head_channels']}; "
f"fpn channels: {plan_arch['fpn_channels']}") f"fpn channels: {plan_arch['fpn_channels']}"
)
_plan_anchors = copy.deepcopy(plan_anchors) _plan_anchors = copy.deepcopy(plan_anchors)
coder = BoxCoderND(weights=(1.,) * (plan_arch["dim"] * 2)) coder = BoxCoderND(weights=(1.0,) * (plan_arch["dim"] * 2))
s_param = False if ("aspect_ratios" in _plan_anchors) and \ s_param = (
(_plan_anchors["aspect_ratios"] is not None) else True False
anchor_generator = get_anchor_generator( if ("aspect_ratios" in _plan_anchors)
plan_arch["dim"], s_param=s_param)(**_plan_anchors) and (_plan_anchors["aspect_ratios"] is not None)
else True
)
anchor_generator = get_anchor_generator(plan_arch["dim"], s_param=s_param)(
**_plan_anchors
)
encoder = cls._build_encoder( encoder = cls._build_encoder(
plan_arch=plan_arch, plan_arch=plan_arch,
...@@ -404,7 +425,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -404,7 +425,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
model_cfg=model_cfg, model_cfg=model_cfg,
classifier=classifier, classifier=classifier,
regressor=regressor, regressor=regressor,
coder=coder coder=coder,
) )
segmenter = cls._build_segmenter( segmenter = cls._build_segmenter(
plan_arch=plan_arch, plan_arch=plan_arch,
...@@ -418,13 +439,13 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -418,13 +439,13 @@ class RetinaUNetModule(LightningBaseModuleSWA):
remove_small_boxes = plan_arch.get("remove_small_boxes", 0.01) remove_small_boxes = plan_arch.get("remove_small_boxes", 0.01)
nms_thresh = plan_arch.get("nms_thresh", 0.6) nms_thresh = plan_arch.get("nms_thresh", 0.6)
logger.info(f"Model Inference Summary: \n" # logger.info(f"Model Inference Summary: \n"
f"detections_per_img: {detections_per_img} \n" # f"detections_per_img: {detections_per_img} \n"
f"score_thresh: {score_thresh} \n" # f"score_thresh: {score_thresh} \n"
f"topk_candidates: {topk_candidates} \n" # f"topk_candidates: {topk_candidates} \n"
f"remove_small_boxes: {remove_small_boxes} \n" # f"remove_small_boxes: {remove_small_boxes} \n"
f"nms_thresh: {nms_thresh}", # f"nms_thresh: {nms_thresh}",
) # )
return BaseRetinaNet( return BaseRetinaNet(
dim=plan_arch["dim"], dim=plan_arch["dim"],
...@@ -461,7 +482,9 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -461,7 +482,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
EncoderType: encoder instance EncoderType: encoder instance
""" """
conv = Generator(cls.base_conv_cls, plan_arch["dim"]) conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: encoder {cls.encoder_cls.__name__}: {model_cfg['encoder_kwargs']} ") logger.info(
f"Building:: encoder {cls.encoder_cls.__name__}: {model_cfg['encoder_kwargs']} "
)
encoder = cls.encoder_cls( encoder = cls.encoder_cls(
conv=conv, conv=conv,
conv_kernels=plan_arch["conv_kernels"], conv_kernels=plan_arch["conv_kernels"],
...@@ -471,7 +494,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -471,7 +494,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
start_channels=plan_arch["start_channels"], start_channels=plan_arch["start_channels"],
stage_kwargs=None, stage_kwargs=None,
max_channels=plan_arch.get("max_channels", 320), max_channels=plan_arch.get("max_channels", 320),
**model_cfg['encoder_kwargs'], **model_cfg["encoder_kwargs"],
) )
return encoder return encoder
...@@ -493,7 +516,9 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -493,7 +516,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
DecoderType: decoder instance DecoderType: decoder instance
""" """
conv = Generator(cls.base_conv_cls, plan_arch["dim"]) conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: decoder {cls.decoder_cls.__name__}: {model_cfg['decoder_kwargs']}") logger.info(
f"Building:: decoder {cls.decoder_cls.__name__}: {model_cfg['decoder_kwargs']}"
)
decoder = cls.decoder_cls( decoder = cls.decoder_cls(
conv=conv, conv=conv,
conv_kernels=plan_arch["conv_kernels"], conv_kernels=plan_arch["conv_kernels"],
...@@ -501,7 +526,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -501,7 +526,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
in_channels=encoder.get_channels(), in_channels=encoder.get_channels(),
decoder_levels=plan_arch["decoder_levels"], decoder_levels=plan_arch["decoder_levels"],
fixed_out_channels=plan_arch["fpn_channels"], fixed_out_channels=plan_arch["fpn_channels"],
**model_cfg['decoder_kwargs'], **model_cfg["decoder_kwargs"],
) )
return decoder return decoder
...@@ -525,7 +550,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -525,7 +550,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
""" """
conv = Generator(cls.head_conv_cls, plan_arch["dim"]) conv = Generator(cls.head_conv_cls, plan_arch["dim"])
name = cls.head_classifier_cls.__name__ name = cls.head_classifier_cls.__name__
kwargs = model_cfg['head_classifier_kwargs'] kwargs = model_cfg["head_classifier_kwargs"]
logger.info(f"Building:: classifier {name}: {kwargs}") logger.info(f"Building:: classifier {name}: {kwargs}")
classifier = cls.head_classifier_cls( classifier = cls.head_classifier_cls(
...@@ -559,7 +584,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -559,7 +584,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
""" """
conv = Generator(cls.head_conv_cls, plan_arch["dim"]) conv = Generator(cls.head_conv_cls, plan_arch["dim"])
name = cls.head_regressor_cls.__name__ name = cls.head_regressor_cls.__name__
kwargs = model_cfg['head_regressor_kwargs'] kwargs = model_cfg["head_regressor_kwargs"]
logger.info(f"Building:: regressor {name}: {kwargs}") logger.info(f"Building:: regressor {name}: {kwargs}")
regressor = cls.head_regressor_cls( regressor = cls.head_regressor_cls(
...@@ -595,12 +620,14 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -595,12 +620,14 @@ class RetinaUNetModule(LightningBaseModuleSWA):
HeadType: instantiated head HeadType: instantiated head
""" """
head_name = cls.head_cls.__name__ head_name = cls.head_cls.__name__
head_kwargs = model_cfg['head_kwargs'] head_kwargs = model_cfg["head_kwargs"]
sampler_name = cls.head_sampler_cls.__name__ sampler_name = cls.head_sampler_cls.__name__
sampler_kwargs = model_cfg['head_sampler_kwargs'] sampler_kwargs = model_cfg["head_sampler_kwargs"]
logger.info(f"Building:: head {head_name}: {head_kwargs} " logger.info(
f"sampler {sampler_name}: {sampler_kwargs}") f"Building:: head {head_name}: {head_kwargs} "
f"sampler {sampler_name}: {sampler_kwargs}"
)
sampler = cls.head_sampler_cls(**sampler_kwargs) sampler = cls.head_sampler_cls(**sampler_kwargs)
head = cls.head_cls( head = cls.head_cls(
classifier=classifier, classifier=classifier,
...@@ -632,7 +659,7 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -632,7 +659,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
""" """
if cls.segmenter_cls is not None: if cls.segmenter_cls is not None:
name = cls.segmenter_cls.__name__ name = cls.segmenter_cls.__name__
kwargs = model_cfg['segmenter_kwargs'] kwargs = model_cfg["segmenter_kwargs"]
conv = Generator(cls.base_conv_cls, plan_arch["dim"]) conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: segmenter {name} {kwargs}") logger.info(f"Building:: segmenter {name} {kwargs}")
...@@ -661,14 +688,15 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -661,14 +688,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
3: { 3: {
"boxes": BoxEnsemblerSelective, "boxes": BoxEnsemblerSelective,
"seg": SegmentationEnsembler, "seg": SegmentationEnsembler,
} },
} }
if dim == 2: if dim == 2:
raise NotImplementedError raise NotImplementedError
return _lookup[dim][key] return _lookup[dim][key]
@classmethod @classmethod
def get_predictor(cls, def get_predictor(
cls,
plan: Dict, plan: Dict,
models: Sequence[RetinaUNetModule], models: Sequence[RetinaUNetModule],
num_tta_transforms: int = None, num_tta_transforms: int = None,
...@@ -684,14 +712,19 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -684,14 +712,19 @@ class RetinaUNetModule(LightningBaseModuleSWA):
num_tta_transforms = 8 if plan["network_dim"] == 3 else 4 num_tta_transforms = 8 if plan["network_dim"] == 3 else 4
# setup # setup
tta_transforms, tta_inverse_transforms = \ tta_transforms, tta_inverse_transforms = get_tta_transforms(
get_tta_transforms(num_tta_transforms, True) num_tta_transforms, True
logger.info(f"Using {len(tta_transforms)} tta transformations for prediction (one dummy trafo).") )
logger.info(
f"Using {len(tta_transforms)} tta transformations for prediction (one dummy trafo)."
)
ensembler = {"boxes": partial( ensembler = {
"boxes": partial(
cls.get_ensembler_cls(key="boxes", dim=plan["network_dim"]).from_case, cls.get_ensembler_cls(key="boxes", dim=plan["network_dim"]).from_case,
parameters=inferene_plan, parameters=inferene_plan,
)} )
}
if do_seg: if do_seg:
ensembler["seg"] = partial( ensembler["seg"] = partial(
cls.get_ensembler_cls(key="seg", dim=plan["network_dim"]).from_case, cls.get_ensembler_cls(key="seg", dim=plan["network_dim"]).from_case,
...@@ -711,7 +744,8 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -711,7 +744,8 @@ class RetinaUNetModule(LightningBaseModuleSWA):
predictor.pre_transform = Inference2D(["data"]) predictor.pre_transform = Inference2D(["data"])
return predictor return predictor
def sweep(self, def sweep(
self,
cfg: dict, cfg: dict,
save_dir: os.PathLike, save_dir: os.PathLike,
train_data_dir: os.PathLike, train_data_dir: os.PathLike,
...@@ -767,7 +801,9 @@ class RetinaUNetModule(LightningBaseModuleSWA): ...@@ -767,7 +801,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
) )
logger.info("Start parameter sweep...") logger.info("Start parameter sweep...")
ensembler_cls = self.get_ensembler_cls(key="boxes", dim=self.plan["network_dim"]) ensembler_cls = self.get_ensembler_cls(
key="boxes", dim=self.plan["network_dim"]
)
sweeper = BoxSweeper( sweeper = BoxSweeper(
classes=[item for _, item in cfg["data"]["labels"].items()], classes=[item for _, item in cfg["data"]["labels"].items()],
pred_dir=prediction_dir, pred_dir=prediction_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment