Commit 06eec91c authored by Abdullah Rashwan, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 406973172
parent c4ebfef2
...@@ -17,10 +17,12 @@ ...@@ -17,10 +17,12 @@
The trainer derives from the Orbit `StandardTrainer` class. The trainer derives from the Orbit `StandardTrainer` class.
""" """
from typing import Union from typing import Union
import gin import gin
import orbit import orbit
import tensorflow as tf import tensorflow as tf
from official.modeling import optimization
from official.modeling.multitask import base_model from official.modeling.multitask import base_model
from official.modeling.multitask import multitask from official.modeling.multitask import multitask
...@@ -45,6 +47,11 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer): ...@@ -45,6 +47,11 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer):
self._training_metrics = None self._training_metrics = None
self._global_step = orbit.utils.create_global_step() self._global_step = orbit.utils.create_global_step()
# Creates a shadow copy of the weights to store weights moving average.
if isinstance(self._optimizer, optimization.ExponentialMovingAverage
) and not self._optimizer.has_shadow_copy:
self._optimizer.shadow_copy(multi_task_model)
if hasattr(self.multi_task_model, "checkpoint_items"): if hasattr(self.multi_task_model, "checkpoint_items"):
checkpoint_items = self.multi_task_model.checkpoint_items checkpoint_items = self.multi_task_model.checkpoint_items
else: else:
......
...@@ -70,6 +70,7 @@ class ImageClassificationModel(hyperparams.Config): ...@@ -70,6 +70,7 @@ class ImageClassificationModel(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class Losses(hyperparams.Config): class Losses(hyperparams.Config):
loss_weight: float = 1.0
one_hot: bool = True one_hot: bool = True
label_smoothing: float = 0.0 label_smoothing: float = 0.0
l2_weight_decay: float = 0.0 l2_weight_decay: float = 0.0
......
...@@ -185,6 +185,7 @@ class MaskRCNN(hyperparams.Config): ...@@ -185,6 +185,7 @@ class MaskRCNN(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class Losses(hyperparams.Config): class Losses(hyperparams.Config):
loss_weight: float = 1.0
rpn_huber_loss_delta: float = 1. / 9. rpn_huber_loss_delta: float = 1. / 9.
frcnn_huber_loss_delta: float = 1. frcnn_huber_loss_delta: float = 1.
l2_weight_decay: float = 0.0 l2_weight_decay: float = 0.0
......
...@@ -83,6 +83,7 @@ class Anchor(hyperparams.Config): ...@@ -83,6 +83,7 @@ class Anchor(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class Losses(hyperparams.Config): class Losses(hyperparams.Config):
loss_weight: float = 1.0
focal_loss_alpha: float = 0.25 focal_loss_alpha: float = 0.25
focal_loss_gamma: float = 1.5 focal_loss_gamma: float = 1.5
huber_loss_delta: float = 0.1 huber_loss_delta: float = 0.1
......
...@@ -92,6 +92,7 @@ class SemanticSegmentationModel(hyperparams.Config): ...@@ -92,6 +92,7 @@ class SemanticSegmentationModel(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class Losses(hyperparams.Config): class Losses(hyperparams.Config):
loss_weight: float = 1.0
label_smoothing: float = 0.0 label_smoothing: float = 0.0
ignore_label: int = 255 ignore_label: int = 255
class_weights: List[float] = dataclasses.field(default_factory=list) class_weights: List[float] = dataclasses.field(default_factory=list)
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
"""Factory methods to build models.""" """Factory methods to build models."""
# Import libraries from typing import Optional
import tensorflow as tf import tensorflow as tf
...@@ -41,10 +41,12 @@ from official.vision.beta.modeling.layers import roi_sampler ...@@ -41,10 +41,12 @@ from official.vision.beta.modeling.layers import roi_sampler
def build_classification_model( def build_classification_model(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: classification_cfg.ImageClassificationModel, model_config: classification_cfg.ImageClassificationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None, l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
skip_logits_layer: bool = False) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras skip_logits_layer: bool = False,
backbone: Optional[tf.keras.Model] = None) -> tf.keras.Model:
"""Builds the classification model.""" """Builds the classification model."""
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
if not backbone:
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
backbone_config=model_config.backbone, backbone_config=model_config.backbone,
...@@ -66,12 +68,15 @@ def build_classification_model( ...@@ -66,12 +68,15 @@ def build_classification_model(
return model return model
def build_maskrcnn( def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
input_specs: tf.keras.layers.InputSpec,
model_config: maskrcnn_cfg.MaskRCNN, model_config: maskrcnn_cfg.MaskRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras l2_regularizer: Optional[
tf.keras.regularizers.Regularizer] = None,
backbone: Optional[tf.keras.Model] = None,
decoder: Optional[tf.keras.Model] = None) -> tf.keras.Model:
"""Builds Mask R-CNN model.""" """Builds Mask R-CNN model."""
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
if not backbone:
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
backbone_config=model_config.backbone, backbone_config=model_config.backbone,
...@@ -79,6 +84,7 @@ def build_maskrcnn( ...@@ -79,6 +84,7 @@ def build_maskrcnn(
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
backbone_features = backbone(tf.keras.Input(input_specs.shape[1:])) backbone_features = backbone(tf.keras.Input(input_specs.shape[1:]))
if not decoder:
decoder = decoders.factory.build_decoder( decoder = decoders.factory.build_decoder(
input_specs=backbone.output_specs, input_specs=backbone.output_specs,
model_config=model_config, model_config=model_config,
...@@ -121,7 +127,6 @@ def build_maskrcnn( ...@@ -121,7 +127,6 @@ def build_maskrcnn(
kernel_regularizer=l2_regularizer, kernel_regularizer=l2_regularizer,
name='detection_head') name='detection_head')
# Builds decoder and region proposal network:
if decoder: if decoder:
decoder_features = decoder(backbone_features) decoder_features = decoder(backbone_features)
rpn_head(decoder_features) rpn_head(decoder_features)
...@@ -253,9 +258,13 @@ def build_maskrcnn( ...@@ -253,9 +258,13 @@ def build_maskrcnn(
def build_retinanet( def build_retinanet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: retinanet_cfg.RetinaNet, model_config: retinanet_cfg.RetinaNet,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
backbone: Optional[tf.keras.Model] = None,
decoder: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
"""Builds RetinaNet model.""" """Builds RetinaNet model."""
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
if not backbone:
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
backbone_config=model_config.backbone, backbone_config=model_config.backbone,
...@@ -263,6 +272,7 @@ def build_retinanet( ...@@ -263,6 +272,7 @@ def build_retinanet(
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
backbone_features = backbone(tf.keras.Input(input_specs.shape[1:])) backbone_features = backbone(tf.keras.Input(input_specs.shape[1:]))
if not decoder:
decoder = decoders.factory.build_decoder( decoder = decoders.factory.build_decoder(
input_specs=backbone.output_specs, input_specs=backbone.output_specs,
model_config=model_config, model_config=model_config,
...@@ -321,15 +331,20 @@ def build_retinanet( ...@@ -321,15 +331,20 @@ def build_retinanet(
def build_segmentation_model( def build_segmentation_model(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: segmentation_cfg.SemanticSegmentationModel, model_config: segmentation_cfg.SemanticSegmentationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
backbone: Optional[tf.keras.regularizers.Regularizer] = None,
decoder: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
"""Builds Segmentation model.""" """Builds Segmentation model."""
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
if not backbone:
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
backbone_config=model_config.backbone, backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config, norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
if not decoder:
decoder = decoders.factory.build_decoder( decoder = decoders.factory.build_decoder(
input_specs=backbone.output_specs, input_specs=backbone.output_specs,
model_config=model_config, model_config=model_config,
......
...@@ -49,6 +49,7 @@ class ImageClassificationModel(hyperparams.Config): ...@@ -49,6 +49,7 @@ class ImageClassificationModel(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class Losses(hyperparams.Config): class Losses(hyperparams.Config):
loss_weight: float = 1.0
one_hot: bool = True one_hot: bool = True
label_smoothing: float = 0.0 label_smoothing: float = 0.0
l2_weight_decay: float = 0.0 l2_weight_decay: float = 0.0
......
...@@ -169,6 +169,7 @@ class ImageClassificationTask(base_task.Task): ...@@ -169,6 +169,7 @@ class ImageClassificationTask(base_task.Task):
if aux_losses: if aux_losses:
total_loss += tf.add_n(aux_losses) total_loss += tf.add_n(aux_losses)
total_loss = losses_config.loss_weight * total_loss
return total_loss return total_loss
def build_metrics(self, def build_metrics(self,
......
...@@ -236,6 +236,7 @@ class MaskRCNNTask(base_task.Task): ...@@ -236,6 +236,7 @@ class MaskRCNNTask(base_task.Task):
reg_loss = tf.reduce_sum(aux_losses) reg_loss = tf.reduce_sum(aux_losses)
total_loss = model_loss + reg_loss total_loss = model_loss + reg_loss
total_loss = params.losses.loss_weight * total_loss
losses = { losses = {
'total_loss': total_loss, 'total_loss': total_loss,
'rpn_score_loss': rpn_score_loss, 'rpn_score_loss': rpn_score_loss,
......
...@@ -220,6 +220,8 @@ class RetinaNetTask(base_task.Task): ...@@ -220,6 +220,8 @@ class RetinaNetTask(base_task.Task):
reg_loss = tf.reduce_sum(aux_losses) reg_loss = tf.reduce_sum(aux_losses)
total_loss = model_loss + reg_loss total_loss = model_loss + reg_loss
total_loss = params.losses.loss_weight * total_loss
return total_loss, cls_loss, box_loss, model_loss return total_loss, cls_loss, box_loss, model_loss
def build_metrics(self, training: bool = True): def build_metrics(self, training: bool = True):
......
...@@ -140,6 +140,8 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -140,6 +140,8 @@ class SemanticSegmentationTask(base_task.Task):
if aux_losses: if aux_losses:
total_loss += tf.add_n(aux_losses) total_loss += tf.add_n(aux_losses)
total_loss = loss_params.loss_weight * total_loss
return total_loss return total_loss
def build_metrics(self, training: bool = True): def build_metrics(self, training: bool = True):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment