Commit 3446e932 authored by Andres Martinez Mora's avatar Andres Martinez Mora
Browse files

Remove model inference summary printout

parent e708c342
......@@ -68,7 +68,7 @@ from nndet.io.transforms import (
Instances2Boxes,
Instances2Segmentation,
FindInstances,
)
)
class RetinaUNetModule(LightningBaseModuleSWA):
......@@ -84,12 +84,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
head_sampler_cls = HardNegativeSamplerBatched
segmenter_cls = DiCESegmenter
def __init__(self,
model_cfg: dict,
trainer_cfg: dict,
plan: dict,
**kwargs
):
def __init__(self, model_cfg: dict, trainer_cfg: dict, plan: dict, **kwargs):
"""
RetinaUNet Lightning Module Skeleton
......@@ -106,7 +101,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
plan=plan,
)
_classes = [f"class{c}" for c in range(plan["architecture"]["classifier_classes"])]
_classes = [
f"class{c}" for c in range(plan["architecture"]["classifier_classes"])
]
self.box_evaluator = BoxEvaluator.create(
classes=_classes,
fast=True,
......@@ -130,7 +127,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
instance_key="target",
map_key="instance_mapping",
present_instances="present_instances",
)
),
)
self.eval_score_key = "mAP_IoU_0.10_0.50_0.05_MaxDet_100"
......@@ -148,7 +145,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets={
"target_boxes": batch["boxes"],
"target_classes": batch["classes"],
"target_seg": batch['target'][:, 0] # Remove channel dimension
"target_seg": batch["target"][:, 0], # Remove channel dimension
},
evaluation=False,
batch_num=batch_idx,
......@@ -167,7 +164,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
targets = {
"target_boxes": batch["boxes"],
"target_classes": batch["classes"],
"target_seg": batch['target'][:, 0] # Remove channel dimension
"target_seg": batch["target"][:, 0], # Remove channel dimension
}
losses, prediction = self.model.train_step(
images=batch["data"],
......@@ -178,8 +175,10 @@ class RetinaUNetModule(LightningBaseModuleSWA):
loss = sum(losses.values())
self.evaluation_step(prediction=prediction, targets=targets)
return {"loss": loss.detach().item(),
**{key: l.detach().item() for key, l in losses.items()}}
return {
"loss": loss.detach().item(),
**{key: l.detach().item() for key, l in losses.items()},
}
def evaluation_step(
self,
......@@ -281,9 +280,11 @@ class RetinaUNetModule(LightningBaseModuleSWA):
metric_scores, _ = self.box_evaluator.finish_online_evaluation()
self.box_evaluator.reset()
logger.info(f"mAP@0.1:0.5:0.05: {metric_scores['mAP_IoU_0.10_0.50_0.05_MaxDet_100']:0.3f} "
logger.info(
f"mAP@0.1:0.5:0.05: {metric_scores['mAP_IoU_0.10_0.50_0.05_MaxDet_100']:0.3f} "
f"AP@0.1: {metric_scores['AP_IoU_0.10_MaxDet_100']:0.3f} "
f"AP@0.5: {metric_scores['AP_IoU_0.50_MaxDet_100']:0.3f}")
f"AP@0.5: {metric_scores['AP_IoU_0.50_MaxDet_100']:0.3f}"
)
seg_scores, _ = self.seg_evaluator.finish_online_evaluation()
self.seg_evaluator.reset()
......@@ -292,7 +293,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
logger.info(f"Proxy FG Dice: {seg_scores['seg_dice']:0.3f}")
for key, item in metric_scores.items():
self.log(f'{key}', item, on_step=None, on_epoch=True, prog_bar=False, logger=True)
self.log(
f"{key}", item, on_step=None, on_epoch=True, prog_bar=False, logger=True
)
def configure_optimizers(self):
"""
......@@ -301,11 +304,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
schedule
"""
# configure optimizer
logger.info(f"Running: initial_lr {self.trainer_cfg['initial_lr']} "
logger.info(
f"Running: initial_lr {self.trainer_cfg['initial_lr']} "
f"weight_decay {self.trainer_cfg['weight_decay']} "
f"SGD with momentum {self.trainer_cfg['sgd_momentum']} and "
f"nesterov {self.trainer_cfg['sgd_nesterov']}")
wd_groups = get_params_no_wd_on_norm(self, weight_decay=self.trainer_cfg['weight_decay'])
f"nesterov {self.trainer_cfg['sgd_nesterov']}"
)
wd_groups = get_params_no_wd_on_norm(
self, weight_decay=self.trainer_cfg["weight_decay"]
)
optimizer = torch.optim.SGD(
wd_groups,
self.trainer_cfg["initial_lr"],
......@@ -315,19 +322,22 @@ class RetinaUNetModule(LightningBaseModuleSWA):
)
# configure lr scheduler
num_iterations = self.trainer_cfg["max_num_epochs"] * \
self.trainer_cfg["num_train_batches_per_epoch"]
num_iterations = (
self.trainer_cfg["max_num_epochs"]
* self.trainer_cfg["num_train_batches_per_epoch"]
)
scheduler = LinearWarmupPolyLR(
optimizer=optimizer,
warm_iterations=self.trainer_cfg["warm_iterations"],
warm_lr=self.trainer_cfg["warm_lr"],
poly_gamma=self.trainer_cfg["poly_gamma"],
num_iterations=num_iterations
num_iterations=num_iterations,
)
return [optimizer], {'scheduler': scheduler, 'interval': 'step'}
return [optimizer], {"scheduler": scheduler, "interval": "step"}
@classmethod
def from_config_plan(cls,
def from_config_plan(
cls,
model_cfg: dict,
plan_arch: dict,
plan_anchors: dict,
......@@ -359,21 +369,32 @@ class RetinaUNetModule(LightningBaseModuleSWA):
will be performed
**kwargs:
"""
logger.info(f"Architecture overwrites: {model_cfg['plan_arch_overwrites']} "
f"Anchor overwrites: {model_cfg['plan_anchors_overwrites']}")
logger.info(f"Building architecture according to plan of {plan_arch.get('arch_name', 'not_found')}")
logger.info(
f"Architecture overwrites: {model_cfg['plan_arch_overwrites']} "
f"Anchor overwrites: {model_cfg['plan_anchors_overwrites']}"
)
logger.info(
f"Building architecture according to plan of {plan_arch.get('arch_name', 'not_found')}"
)
plan_arch.update(model_cfg["plan_arch_overwrites"])
plan_anchors.update(model_cfg["plan_anchors_overwrites"])
logger.info(f"Start channels: {plan_arch['start_channels']}; "
logger.info(
f"Start channels: {plan_arch['start_channels']}; "
f"head channels: {plan_arch['head_channels']}; "
f"fpn channels: {plan_arch['fpn_channels']}")
f"fpn channels: {plan_arch['fpn_channels']}"
)
_plan_anchors = copy.deepcopy(plan_anchors)
coder = BoxCoderND(weights=(1.,) * (plan_arch["dim"] * 2))
s_param = False if ("aspect_ratios" in _plan_anchors) and \
(_plan_anchors["aspect_ratios"] is not None) else True
anchor_generator = get_anchor_generator(
plan_arch["dim"], s_param=s_param)(**_plan_anchors)
coder = BoxCoderND(weights=(1.0,) * (plan_arch["dim"] * 2))
s_param = (
False
if ("aspect_ratios" in _plan_anchors)
and (_plan_anchors["aspect_ratios"] is not None)
else True
)
anchor_generator = get_anchor_generator(plan_arch["dim"], s_param=s_param)(
**_plan_anchors
)
encoder = cls._build_encoder(
plan_arch=plan_arch,
......@@ -404,7 +425,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
model_cfg=model_cfg,
classifier=classifier,
regressor=regressor,
coder=coder
coder=coder,
)
segmenter = cls._build_segmenter(
plan_arch=plan_arch,
......@@ -418,13 +439,13 @@ class RetinaUNetModule(LightningBaseModuleSWA):
remove_small_boxes = plan_arch.get("remove_small_boxes", 0.01)
nms_thresh = plan_arch.get("nms_thresh", 0.6)
logger.info(f"Model Inference Summary: \n"
f"detections_per_img: {detections_per_img} \n"
f"score_thresh: {score_thresh} \n"
f"topk_candidates: {topk_candidates} \n"
f"remove_small_boxes: {remove_small_boxes} \n"
f"nms_thresh: {nms_thresh}",
)
# logger.info(f"Model Inference Summary: \n"
# f"detections_per_img: {detections_per_img} \n"
# f"score_thresh: {score_thresh} \n"
# f"topk_candidates: {topk_candidates} \n"
# f"remove_small_boxes: {remove_small_boxes} \n"
# f"nms_thresh: {nms_thresh}",
# )
return BaseRetinaNet(
dim=plan_arch["dim"],
......@@ -461,7 +482,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
EncoderType: encoder instance
"""
conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: encoder {cls.encoder_cls.__name__}: {model_cfg['encoder_kwargs']} ")
logger.info(
f"Building:: encoder {cls.encoder_cls.__name__}: {model_cfg['encoder_kwargs']} "
)
encoder = cls.encoder_cls(
conv=conv,
conv_kernels=plan_arch["conv_kernels"],
......@@ -471,7 +494,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
start_channels=plan_arch["start_channels"],
stage_kwargs=None,
max_channels=plan_arch.get("max_channels", 320),
**model_cfg['encoder_kwargs'],
**model_cfg["encoder_kwargs"],
)
return encoder
......@@ -493,7 +516,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
DecoderType: decoder instance
"""
conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: decoder {cls.decoder_cls.__name__}: {model_cfg['decoder_kwargs']}")
logger.info(
f"Building:: decoder {cls.decoder_cls.__name__}: {model_cfg['decoder_kwargs']}"
)
decoder = cls.decoder_cls(
conv=conv,
conv_kernels=plan_arch["conv_kernels"],
......@@ -501,7 +526,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
in_channels=encoder.get_channels(),
decoder_levels=plan_arch["decoder_levels"],
fixed_out_channels=plan_arch["fpn_channels"],
**model_cfg['decoder_kwargs'],
**model_cfg["decoder_kwargs"],
)
return decoder
......@@ -525,7 +550,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
conv = Generator(cls.head_conv_cls, plan_arch["dim"])
name = cls.head_classifier_cls.__name__
kwargs = model_cfg['head_classifier_kwargs']
kwargs = model_cfg["head_classifier_kwargs"]
logger.info(f"Building:: classifier {name}: {kwargs}")
classifier = cls.head_classifier_cls(
......@@ -559,7 +584,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
conv = Generator(cls.head_conv_cls, plan_arch["dim"])
name = cls.head_regressor_cls.__name__
kwargs = model_cfg['head_regressor_kwargs']
kwargs = model_cfg["head_regressor_kwargs"]
logger.info(f"Building:: regressor {name}: {kwargs}")
regressor = cls.head_regressor_cls(
......@@ -595,12 +620,14 @@ class RetinaUNetModule(LightningBaseModuleSWA):
HeadType: instantiated head
"""
head_name = cls.head_cls.__name__
head_kwargs = model_cfg['head_kwargs']
head_kwargs = model_cfg["head_kwargs"]
sampler_name = cls.head_sampler_cls.__name__
sampler_kwargs = model_cfg['head_sampler_kwargs']
sampler_kwargs = model_cfg["head_sampler_kwargs"]
logger.info(f"Building:: head {head_name}: {head_kwargs} "
f"sampler {sampler_name}: {sampler_kwargs}")
logger.info(
f"Building:: head {head_name}: {head_kwargs} "
f"sampler {sampler_name}: {sampler_kwargs}"
)
sampler = cls.head_sampler_cls(**sampler_kwargs)
head = cls.head_cls(
classifier=classifier,
......@@ -632,7 +659,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
if cls.segmenter_cls is not None:
name = cls.segmenter_cls.__name__
kwargs = model_cfg['segmenter_kwargs']
kwargs = model_cfg["segmenter_kwargs"]
conv = Generator(cls.base_conv_cls, plan_arch["dim"])
logger.info(f"Building:: segmenter {name} {kwargs}")
......@@ -661,14 +688,15 @@ class RetinaUNetModule(LightningBaseModuleSWA):
3: {
"boxes": BoxEnsemblerSelective,
"seg": SegmentationEnsembler,
}
},
}
if dim == 2:
raise NotImplementedError
return _lookup[dim][key]
@classmethod
def get_predictor(cls,
def get_predictor(
cls,
plan: Dict,
models: Sequence[RetinaUNetModule],
num_tta_transforms: int = None,
......@@ -684,14 +712,19 @@ class RetinaUNetModule(LightningBaseModuleSWA):
num_tta_transforms = 8 if plan["network_dim"] == 3 else 4
# setup
tta_transforms, tta_inverse_transforms = \
get_tta_transforms(num_tta_transforms, True)
logger.info(f"Using {len(tta_transforms)} tta transformations for prediction (one dummy trafo).")
tta_transforms, tta_inverse_transforms = get_tta_transforms(
num_tta_transforms, True
)
logger.info(
f"Using {len(tta_transforms)} tta transformations for prediction (one dummy trafo)."
)
ensembler = {"boxes": partial(
ensembler = {
"boxes": partial(
cls.get_ensembler_cls(key="boxes", dim=plan["network_dim"]).from_case,
parameters=inferene_plan,
)}
)
}
if do_seg:
ensembler["seg"] = partial(
cls.get_ensembler_cls(key="seg", dim=plan["network_dim"]).from_case,
......@@ -711,7 +744,8 @@ class RetinaUNetModule(LightningBaseModuleSWA):
predictor.pre_transform = Inference2D(["data"])
return predictor
def sweep(self,
def sweep(
self,
cfg: dict,
save_dir: os.PathLike,
train_data_dir: os.PathLike,
......@@ -767,7 +801,9 @@ class RetinaUNetModule(LightningBaseModuleSWA):
)
logger.info("Start parameter sweep...")
ensembler_cls = self.get_ensembler_cls(key="boxes", dim=self.plan["network_dim"])
ensembler_cls = self.get_ensembler_cls(
key="boxes", dim=self.plan["network_dim"]
)
sweeper = BoxSweeper(
classes=[item for _, item in cfg["data"]["labels"].items()],
pred_dir=prediction_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment