Commit 06eec91c authored by Abdullah Rashwan, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 406973172
parent c4ebfef2
......@@ -17,10 +17,12 @@
The trainer derives from the Orbit `StandardTrainer` class.
"""
from typing import Union
import gin
import orbit
import tensorflow as tf
from official.modeling import optimization
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
......@@ -45,6 +47,11 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer):
self._training_metrics = None
self._global_step = orbit.utils.create_global_step()
# Creates a shadow copy of the weights to store weights moving average.
if isinstance(self._optimizer, optimization.ExponentialMovingAverage
) and not self._optimizer.has_shadow_copy:
self._optimizer.shadow_copy(multi_task_model)
if hasattr(self.multi_task_model, "checkpoint_items"):
checkpoint_items = self.multi_task_model.checkpoint_items
else:
......
......@@ -70,6 +70,7 @@ class ImageClassificationModel(hyperparams.Config):
@dataclasses.dataclass
class Losses(hyperparams.Config):
loss_weight: float = 1.0
one_hot: bool = True
label_smoothing: float = 0.0
l2_weight_decay: float = 0.0
......
......@@ -185,6 +185,7 @@ class MaskRCNN(hyperparams.Config):
@dataclasses.dataclass
class Losses(hyperparams.Config):
loss_weight: float = 1.0
rpn_huber_loss_delta: float = 1. / 9.
frcnn_huber_loss_delta: float = 1.
l2_weight_decay: float = 0.0
......
......@@ -83,6 +83,7 @@ class Anchor(hyperparams.Config):
@dataclasses.dataclass
class Losses(hyperparams.Config):
loss_weight: float = 1.0
focal_loss_alpha: float = 0.25
focal_loss_gamma: float = 1.5
huber_loss_delta: float = 0.1
......
......@@ -92,6 +92,7 @@ class SemanticSegmentationModel(hyperparams.Config):
@dataclasses.dataclass
class Losses(hyperparams.Config):
loss_weight: float = 1.0
label_smoothing: float = 0.0
ignore_label: int = 255
class_weights: List[float] = dataclasses.field(default_factory=list)
......
......@@ -14,7 +14,7 @@
"""Factory methods to build models."""
# Import libraries
from typing import Optional
import tensorflow as tf
......@@ -41,10 +41,12 @@ from official.vision.beta.modeling.layers import roi_sampler
def build_classification_model(
input_specs: tf.keras.layers.InputSpec,
model_config: classification_cfg.ImageClassificationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None,
skip_logits_layer: bool = False) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
skip_logits_layer: bool = False,
backbone: Optional[tf.keras.Model] = None) -> tf.keras.Model:
"""Builds the classification model."""
norm_activation_config = model_config.norm_activation
if not backbone:
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
......@@ -66,12 +68,15 @@ def build_classification_model(
return model
def build_maskrcnn(
input_specs: tf.keras.layers.InputSpec,
def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
model_config: maskrcnn_cfg.MaskRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
l2_regularizer: Optional[
tf.keras.regularizers.Regularizer] = None,
backbone: Optional[tf.keras.Model] = None,
decoder: Optional[tf.keras.Model] = None) -> tf.keras.Model:
"""Builds Mask R-CNN model."""
norm_activation_config = model_config.norm_activation
if not backbone:
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
......@@ -79,6 +84,7 @@ def build_maskrcnn(
l2_regularizer=l2_regularizer)
backbone_features = backbone(tf.keras.Input(input_specs.shape[1:]))
if not decoder:
decoder = decoders.factory.build_decoder(
input_specs=backbone.output_specs,
model_config=model_config,
......@@ -121,7 +127,6 @@ def build_maskrcnn(
kernel_regularizer=l2_regularizer,
name='detection_head')
# Builds decoder and region proposal network:
if decoder:
decoder_features = decoder(backbone_features)
rpn_head(decoder_features)
......@@ -253,9 +258,13 @@ def build_maskrcnn(
def build_retinanet(
input_specs: tf.keras.layers.InputSpec,
model_config: retinanet_cfg.RetinaNet,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
backbone: Optional[tf.keras.Model] = None,
decoder: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
"""Builds RetinaNet model."""
norm_activation_config = model_config.norm_activation
if not backbone:
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
......@@ -263,6 +272,7 @@ def build_retinanet(
l2_regularizer=l2_regularizer)
backbone_features = backbone(tf.keras.Input(input_specs.shape[1:]))
if not decoder:
decoder = decoders.factory.build_decoder(
input_specs=backbone.output_specs,
model_config=model_config,
......@@ -321,15 +331,20 @@ def build_retinanet(
def build_segmentation_model(
input_specs: tf.keras.layers.InputSpec,
model_config: segmentation_cfg.SemanticSegmentationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
backbone: Optional[tf.keras.regularizers.Regularizer] = None,
decoder: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
"""Builds Segmentation model."""
norm_activation_config = model_config.norm_activation
if not backbone:
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
if not decoder:
decoder = decoders.factory.build_decoder(
input_specs=backbone.output_specs,
model_config=model_config,
......
......@@ -49,6 +49,7 @@ class ImageClassificationModel(hyperparams.Config):
@dataclasses.dataclass
class Losses(hyperparams.Config):
loss_weight: float = 1.0
one_hot: bool = True
label_smoothing: float = 0.0
l2_weight_decay: float = 0.0
......
......@@ -169,6 +169,7 @@ class ImageClassificationTask(base_task.Task):
if aux_losses:
total_loss += tf.add_n(aux_losses)
total_loss = losses_config.loss_weight * total_loss
return total_loss
def build_metrics(self,
......
......@@ -236,6 +236,7 @@ class MaskRCNNTask(base_task.Task):
reg_loss = tf.reduce_sum(aux_losses)
total_loss = model_loss + reg_loss
total_loss = params.losses.loss_weight * total_loss
losses = {
'total_loss': total_loss,
'rpn_score_loss': rpn_score_loss,
......
......@@ -220,6 +220,8 @@ class RetinaNetTask(base_task.Task):
reg_loss = tf.reduce_sum(aux_losses)
total_loss = model_loss + reg_loss
total_loss = params.losses.loss_weight * total_loss
return total_loss, cls_loss, box_loss, model_loss
def build_metrics(self, training: bool = True):
......
......@@ -140,6 +140,8 @@ class SemanticSegmentationTask(base_task.Task):
if aux_losses:
total_loss += tf.add_n(aux_losses)
total_loss = loss_params.loss_weight * total_loss
return total_loss
def build_metrics(self, training: bool = True):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment