# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Centernet task definition."""

from typing import Any, List, Optional, Tuple

from absl import logging
import tensorflow as tf

from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.projects.centernet.configs import centernet as exp_cfg
from official.projects.centernet.dataloaders import centernet_input
from official.projects.centernet.losses import centernet_losses
from official.projects.centernet.modeling import centernet_model
from official.projects.centernet.modeling.heads import centernet_head
from official.projects.centernet.modeling.layers import detection_generator
from official.projects.centernet.ops import loss_ops
from official.projects.centernet.ops import target_assigner
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.modeling.backbones import factory


def _masked_localization_loss(prediction_maps, targets, box_indices, box_mask,
                              num_boxes, loss_fn):
  """Sums a masked per-box loss over prediction maps and normalizes it.

  Shared by the scale and offset losses, which differ only in which
  prediction maps and targets they compare.

  Args:
    prediction_maps: a list of prediction tensors; one loss term is
      accumulated per entry.
    targets: the ground-truth tensor compared against the gathered
      predictions.
    box_indices: indices handed to
      `loss_ops.get_batch_predictions_from_indices` to gather the predictions
      at the ground-truth box locations.
    box_mask: float tensor multiplied into the per-box loss so that only boxes
      present in the ground truth contribute.
    num_boxes: float tensor used, together with `len(prediction_maps)`, as the
      normalizer.
    loss_fn: callable accepting `target_tensor` and `prediction_tensor`
      keyword arguments.

  Returns:
    The summed masked loss divided by `len(prediction_maps) * num_boxes`.
  """
  total_loss = 0.0
  for prediction_map in prediction_maps:
    predictions = loss_ops.get_batch_predictions_from_indices(
        prediction_map, box_indices)
    predictions = tf.cast(predictions, tf.float32)
    # Only apply loss for boxes that appear in the ground truth.
    total_loss += tf.reduce_sum(
        loss_fn(target_tensor=targets, prediction_tensor=predictions),
        axis=-1) * box_mask
  # `num_boxes` is already a float tensor, so the product is a float tensor
  # and needs no Python-level `float()` conversion.
  return tf.reduce_sum(total_loss) / (len(prediction_maps) * num_boxes)


@task_factory.register_task_cls(exp_cfg.CenterNetTask)
class CenterNetTask(base_task.Task):
  """Task definition for centernet."""

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Build input dataset.

    Args:
      params: the data config. Selects the decoder (TFDS decoder when
        `tfds_name` is set, otherwise by `decoder.type`) and supplies the
        parser/augmentation settings.
      input_context: optional distribution input context forwarded to the
        input reader.

    Returns:
      The dataset produced by `input_reader.InputReader.read`.

    Raises:
      ValueError: if `params.decoder.type` is neither 'simple_decoder' nor
        'label_map_decoder' (and no TFDS name is given).
    """
    if params.tfds_name:
      decoder = tfds_factory.get_detection_decoder(params.tfds_name)
    else:
      decoder_cfg = params.decoder.get()
      if params.decoder.type == 'simple_decoder':
        decoder = tf_example_decoder.TfExampleDecoder(
            regenerate_source_id=decoder_cfg.regenerate_source_id)
      elif params.decoder.type == 'label_map_decoder':
        decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
            label_map=decoder_cfg.label_map,
            regenerate_source_id=decoder_cfg.regenerate_source_id)
      else:
        raise ValueError('Unknown decoder type: {}!'.format(
            params.decoder.type))

    parser = centernet_input.CenterNetParser(
        output_height=self.task_config.model.input_size[0],
        output_width=self.task_config.model.input_size[1],
        max_num_instances=self.task_config.model.max_num_instances,
        bgr_ordering=params.parser.bgr_ordering,
        channel_means=params.parser.channel_means,
        channel_stds=params.parser.channel_stds,
        aug_rand_hflip=params.parser.aug_rand_hflip,
        aug_scale_min=params.parser.aug_scale_min,
        aug_scale_max=params.parser.aug_scale_max,
        aug_rand_hue=params.parser.aug_rand_hue,
        aug_rand_brightness=params.parser.aug_rand_brightness,
        aug_rand_contrast=params.parser.aug_rand_contrast,
        aug_rand_saturation=params.parser.aug_rand_saturation,
        odapi_augmentation=params.parser.odapi_augmentation,
        dtype=params.dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_model(self):
    """Builds the CenterNet model (backbone, head and detection generator).

    Also stores `self._net_down_scale`, the ratio between the input height and
    the height of the backbone output at the last configured head input
    level; `build_losses` reads it.

    Returns:
      A `centernet_model.CenterNetModel` instance.

    Raises:
      ValueError: if the backbone output spec at the last head input level
        has neither 3 nor 4 dimensions.
    """
    model_config = self.task_config.model
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + model_config.input_size)

    l2_weight_decay = self.task_config.weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    backbone = factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=model_config.norm_activation,
        l2_regularizer=l2_regularizer)

    task_outputs = self.task_config.get_output_length_dict()
    head_config = model_config.head
    head = centernet_head.CenterNetHead(
        input_specs=backbone.output_specs,
        task_outputs=task_outputs,
        input_levels=head_config.input_levels,
        heatmap_bias=head_config.heatmap_bias)

    # output_specs is a dict
    last_level = head_config.input_levels[-1]
    backbone_output_spec = backbone.output_specs[last_level]
    # A 4-entry spec is read with the height at index 1, a 3-entry spec with
    # the height at index 0 (i.e. with or without a leading batch dimension).
    if len(backbone_output_spec) == 4:
      bb_output_height = backbone_output_spec[1]
    elif len(backbone_output_spec) == 3:
      bb_output_height = backbone_output_spec[0]
    else:
      raise ValueError(
          'Backbone output spec at level {} must have 3 or 4 dimensions, '
          'got: {}.'.format(last_level, backbone_output_spec))

    self._net_down_scale = int(model_config.input_size[0] / bb_output_height)
    dg_config = model_config.detection_generator
    detect_generator_obj = detection_generator.CenterNetDetectionGenerator(
        max_detections=dg_config.max_detections,
        peak_error=dg_config.peak_error,
        peak_extract_kernel_size=dg_config.peak_extract_kernel_size,
        class_offset=dg_config.class_offset,
        net_down_scale=self._net_down_scale,
        input_image_dims=model_config.input_size[0],
        use_nms=dg_config.use_nms,
        nms_pre_thresh=dg_config.nms_pre_thresh,
        nms_thresh=dg_config.nms_thresh)

    model = centernet_model.CenterNetModel(
        backbone=backbone,
        head=head,
        detection_generator=detect_generator_obj)

    return model

  def initialize(self, model: tf.keras.Model):
    """Loads a pretrained checkpoint into the model.

    Does nothing when `task_config.init_checkpoint` is empty. When the
    configured path is a directory, the latest checkpoint in it is used.

    Args:
      model: the model to restore into. With `init_checkpoint_modules` set to
        'all' its `checkpoint_items` are restored; with 'backbone' only
        `model.backbone` is restored.

    Raises:
      ValueError: if `init_checkpoint_modules` is neither 'all' nor
        'backbone'.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint

    # Restoring checkpoint.
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    elif self.task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_losses(self, outputs, labels, aux_losses=None):
    """Build losses.

    Requires `build_model` to have been called first, because it reads
    `self._net_down_scale`.

    Args:
      outputs: dict of raw model outputs. The keys 'ct_heatmaps', 'ct_size'
        and 'ct_offset' are read; each maps to a list of prediction tensors.
      labels: dict of labels. It is mapped per example through
        `target_assigner.assign_centernet_targets`, and its
        'unpad_image_shapes' entry is used to weight valid heatmap locations.
      aux_losses: optional list of auxiliary loss tensors added to the total
        loss.

    Returns:
      A dict with the keys 'ct_loss', 'scale_loss', 'ct_offset_loss' and
      'total_loss'.
    """
    model_config = self.task_config.model
    losses_config = self.task_config.losses
    num_classes = model_config.num_classes
    max_num_instances = model_config.max_num_instances

    input_size = model_config.input_size[0:2]
    output_size = outputs['ct_heatmaps'][0].get_shape().as_list()[1:3]

    gt_label = tf.map_fn(
        # pylint: disable=g-long-lambda
        fn=lambda x: target_assigner.assign_centernet_targets(
            labels=x,
            input_size=input_size,
            output_size=output_size,
            num_classes=num_classes,
            max_num_instances=max_num_instances,
            gaussian_iou=losses_config.gaussian_iou,
            class_offset=losses_config.class_offset),
        elems=labels,
        fn_output_signature={
            'ct_heatmaps': tf.TensorSpec(
                shape=[output_size[0], output_size[1], num_classes],
                dtype=tf.float32),
            'ct_offset': tf.TensorSpec(
                shape=[max_num_instances, 2],
                dtype=tf.float32),
            'size': tf.TensorSpec(
                shape=[max_num_instances, 2],
                dtype=tf.float32),
            'box_mask': tf.TensorSpec(
                shape=[max_num_instances],
                dtype=tf.int32),
            'box_indices': tf.TensorSpec(
                shape=[max_num_instances, 2],
                dtype=tf.int32),
        }
    )

    losses = {}

    # Create loss functions
    object_center_loss_fn = centernet_losses.PenaltyReducedLogisticFocalLoss()
    localization_loss_fn = centernet_losses.L1LocalizationLoss()

    # Set up box indices so that they have a batch element as well
    box_indices = loss_ops.add_batch_to_indices(gt_label['box_indices'])

    box_mask = tf.cast(gt_label['box_mask'], dtype=tf.float32)
    num_boxes = tf.cast(
        loss_ops.get_num_instances_from_weights(gt_label['box_mask']),
        dtype=tf.float32)

    # Calculate center heatmap loss
    output_unpad_image_shapes = tf.math.ceil(
        tf.cast(labels['unpad_image_shapes'],
                tf.float32) / self._net_down_scale)
    valid_anchor_weights = loss_ops.get_valid_anchor_weights_in_flattened_image(
        output_unpad_image_shapes, output_size[0], output_size[1])
    valid_anchor_weights = tf.expand_dims(valid_anchor_weights, 2)

    pred_ct_heatmap_list = outputs['ct_heatmaps']
    true_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
        gt_label['ct_heatmaps'])
    true_flattened_ct_heatmap = tf.cast(true_flattened_ct_heatmap, tf.float32)

    total_center_loss = 0.0
    for ct_heatmap in pred_ct_heatmap_list:
      pred_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
          ct_heatmap)
      pred_flattened_ct_heatmap = tf.cast(pred_flattened_ct_heatmap, tf.float32)
      total_center_loss += object_center_loss_fn(
          target_tensor=true_flattened_ct_heatmap,
          prediction_tensor=pred_flattened_ct_heatmap,
          weights=valid_anchor_weights)

    # `num_boxes` is a float32 tensor, so the normalizer is already a float
    # tensor; no Python-level `float()` conversion is needed.
    center_loss = tf.reduce_sum(total_center_loss) / (
        len(pred_ct_heatmap_list) * num_boxes)
    losses['ct_loss'] = center_loss

    # Calculate scale loss
    scale_loss = _masked_localization_loss(
        prediction_maps=outputs['ct_size'],
        targets=tf.cast(gt_label['size'], tf.float32),
        box_indices=box_indices,
        box_mask=box_mask,
        num_boxes=num_boxes,
        loss_fn=localization_loss_fn)
    losses['scale_loss'] = scale_loss

    # Calculate offset loss
    offset_loss = _masked_localization_loss(
        prediction_maps=outputs['ct_offset'],
        targets=tf.cast(gt_label['ct_offset'], tf.float32),
        box_indices=box_indices,
        box_mask=box_mask,
        num_boxes=num_boxes,
        loss_fn=localization_loss_fn)
    losses['ct_offset_loss'] = offset_loss

    # Aggregate and finalize loss
    loss_weights = losses_config.detection
    total_loss = (loss_weights.object_center_weight * center_loss +
                  loss_weights.scale_weight * scale_loss +
                  loss_weights.offset_weight * offset_loss)

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    losses['total_loss'] = total_loss
    return losses

  def build_metrics(self, training=True):
    """Builds the loss metrics and, for evaluation, the COCO evaluator.

    Args:
      training: whether the metrics are for training. When False,
        `self.coco_metric` is also created.

    Returns:
      A list of `tf.keras.metrics.Mean` metrics named 'total_loss',
      'ct_loss', 'scale_loss' and 'ct_offset_loss'.

    Raises:
      ValueError: if evaluating with both a TFDS validation dataset and an
        annotation file configured.
    """
    metric_names = ['total_loss', 'ct_loss', 'scale_loss', 'ct_offset_loss']
    metrics = [
        tf.keras.metrics.Mean(name, dtype=tf.float32) for name in metric_names
    ]

    if not training:
      if (self.task_config.validation_data.tfds_name
          and self.task_config.annotation_file):
        raise ValueError(
            "Can't evaluate using annotation file when TFDS is used.")
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=self.task_config.annotation_file,
          include_mask=False,
          per_category_metrics=self.task_config.per_category_metrics)

    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a tuple of (features, labels).
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a list of metrics objects; each one is updated with the loss
        entry matching its name.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

      losses = self.build_losses(outputs['raw_output'], labels)

      scaled_loss = losses['total_loss'] / num_replicas
      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    # compute the gradient
    tvars = model.trainable_variables
    gradients = tape.gradient(scaled_loss, tvars)

    # Undo the loss scaling on the gradients if the scaled loss was used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      gradients = optimizer.get_unscaled_gradients(gradients)

    if self.task_config.gradient_clip_norm > 0.0:
      gradients, _ = tf.clip_by_global_norm(gradients,
                                            self.task_config.gradient_clip_norm)

    optimizer.apply_gradients(list(zip(gradients, tvars)))

    logs = {self.loss: losses['total_loss']}

    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])
        logs[m.name] = m.result()

    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf.keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Requires `build_metrics(training=False)` to have been called first,
    because it reads `self.coco_metric`.

    Args:
      inputs: a tuple of (features, labels).
      model: the keras.Model.
      metrics: a list of metrics objects; each one is updated with the loss
        entry matching its name.

    Returns:
      A dictionary of logs. Besides the losses, it maps
      `self.coco_metric.name` to a (groundtruths, detections) pair that
      `aggregate_logs` feeds to the COCO evaluator.
    """
    features, labels = inputs

    outputs = model(features, training=False)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

    losses = self.build_losses(outputs['raw_output'], labels)

    logs = {self.loss: losses['total_loss']}

    coco_model_outputs = {
        'detection_boxes': outputs['boxes'],
        'detection_scores': outputs['confidence'],
        'detection_classes': outputs['classes'],
        'num_detections': outputs['num_detections'],
        'source_id': labels['groundtruths']['source_id'],
        'image_info': labels['image_info']
    }

    logs[self.coco_metric.name] = (labels['groundtruths'], coco_model_outputs)

    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])
        logs[m.name] = m.result()

    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Feeds one validation step's outputs into the COCO evaluator.

    Args:
      state: the aggregation state from the previous call, or None on the
        first call, in which case the evaluator is reset and becomes the
        state.
      step_outputs: the logs returned by `validation_step`; the entry under
        `self.coco_metric.name` holds the (groundtruths, detections) pair.

    Returns:
      The aggregation state.
    """
    if state is None:
      self.coco_metric.reset_states()
      state = self.coco_metric

    groundtruths, predictions = step_outputs[self.coco_metric.name][:2]
    self.coco_metric.update_state(groundtruths, predictions)
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Returns the result of the COCO evaluator.

    Args:
      aggregated_logs: unused; the results are read from `self.coco_metric`.
      global_step: unused.

    Returns:
      The value of `self.coco_metric.result()`.
    """
    return self.coco_metric.result()