Commit 24b2ccd4 authored by Rebecca Chen's avatar Rebecca Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398612564
parent e5fbe328
...@@ -302,7 +302,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention): ...@@ -302,7 +302,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
return attention_output return attention_output
def _build_from_signature(self, query, value, key=None): def _build_from_signature(self, query, value, key=None):
super()._build_from_signature(query=query, value=value, key=key) super()._build_from_signature(query=query, value=value, key=key) # pytype: disable=attribute-error # typed-keras
if self._begin_kernel > 0: if self._begin_kernel > 0:
common_kwargs = dict( common_kwargs = dict(
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
......
...@@ -120,7 +120,7 @@ class MultiChannelAttention(tf.keras.layers.MultiHeadAttention): ...@@ -120,7 +120,7 @@ class MultiChannelAttention(tf.keras.layers.MultiHeadAttention):
""" """
def _build_attention(self, rank): def _build_attention(self, rank):
super(MultiChannelAttention, self)._build_attention(rank) super(MultiChannelAttention, self)._build_attention(rank) # pytype: disable=attribute-error # typed-keras
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2]) self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self, def call(self,
......
...@@ -114,7 +114,7 @@ class EfficientNet(tf.keras.Model): ...@@ -114,7 +114,7 @@ class EfficientNet(tf.keras.Model):
activation: str = 'relu', activation: str = 'relu',
use_sync_bn: bool = False, use_sync_bn: bool = False,
norm_momentum: float = 0.99, norm_momentum: float = 0.99,
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Initializes an EfficientNet model. """Initializes an EfficientNet model.
...@@ -299,7 +299,7 @@ def build_efficientnet( ...@@ -299,7 +299,7 @@ def build_efficientnet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config, norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds EfficientNet backbone from a config.""" """Builds EfficientNet backbone from a config."""
backbone_type = backbone_config.type backbone_type = backbone_config.type
backbone_cfg = backbone_config.get() backbone_cfg = backbone_config.get()
......
...@@ -88,7 +88,7 @@ def build_backbone(input_specs: Union[tf.keras.layers.InputSpec, ...@@ -88,7 +88,7 @@ def build_backbone(input_specs: Union[tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config, norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None, l2_regularizer: tf.keras.regularizers.Regularizer = None,
**kwargs) -> tf.keras.Model: **kwargs) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds backbone from a config. """Builds backbone from a config.
Args: Args:
......
...@@ -407,7 +407,7 @@ def build_resnet( ...@@ -407,7 +407,7 @@ def build_resnet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config, norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds ResNet backbone from a config.""" """Builds ResNet backbone from a config."""
backbone_type = backbone_config.type backbone_type = backbone_config.type
backbone_cfg = backbone_config.get() backbone_cfg = backbone_config.get()
......
...@@ -343,7 +343,7 @@ def build_dilated_resnet( ...@@ -343,7 +343,7 @@ def build_dilated_resnet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config, norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds ResNet backbone from a config.""" """Builds ResNet backbone from a config."""
backbone_type = backbone_config.type backbone_type = backbone_config.type
backbone_cfg = backbone_config.get() backbone_cfg = backbone_config.get()
......
...@@ -216,7 +216,7 @@ def build_revnet( ...@@ -216,7 +216,7 @@ def build_revnet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config, norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds RevNet backbone from a config.""" """Builds RevNet backbone from a config."""
backbone_type = backbone_config.type backbone_type = backbone_config.type
backbone_cfg = backbone_config.get() backbone_cfg = backbone_config.get()
......
...@@ -705,7 +705,7 @@ def build_movinet( ...@@ -705,7 +705,7 @@ def build_movinet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config, norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds MoViNet backbone from a config.""" """Builds MoViNet backbone from a config."""
backbone_type = backbone_config.type backbone_type = backbone_config.type
backbone_cfg = backbone_config.get() backbone_cfg = backbone_config.get()
......
...@@ -103,7 +103,7 @@ class MobileConv2D(tf.keras.layers.Layer): ...@@ -103,7 +103,7 @@ class MobileConv2D(tf.keras.layers.Layer):
bias_constraint: Optional[tf.keras.constraints.Constraint] = None, bias_constraint: Optional[tf.keras.constraints.Constraint] = None,
use_depthwise: bool = False, use_depthwise: bool = False,
use_temporal: bool = False, use_temporal: bool = False,
use_buffered_input: bool = False, use_buffered_input: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): # pylint: disable=g-doc-args **kwargs): # pylint: disable=g-doc-args
"""Initializes mobile conv2d. """Initializes mobile conv2d.
...@@ -270,7 +270,7 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -270,7 +270,7 @@ class ConvBlock(tf.keras.layers.Layer):
batch_norm_epsilon: float = 1e-3, batch_norm_epsilon: float = 1e-3,
activation: Optional[Any] = None, activation: Optional[Any] = None,
conv_type: str = '3d', conv_type: str = '3d',
use_buffered_input: bool = False, use_buffered_input: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Initializes a conv block. """Initializes a conv block.
...@@ -553,7 +553,7 @@ class StreamConvBlock(ConvBlock): ...@@ -553,7 +553,7 @@ class StreamConvBlock(ConvBlock):
batch_norm_epsilon: float = 1e-3, batch_norm_epsilon: float = 1e-3,
activation: Optional[Any] = None, activation: Optional[Any] = None,
conv_type: str = '3d', conv_type: str = '3d',
state_prefix: Optional[str] = None, state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Initializes a stream conv block. """Initializes a stream conv block.
...@@ -678,7 +678,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer): ...@@ -678,7 +678,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
.regularizers.L2(KERNEL_WEIGHT_DECAY), .regularizers.L2(KERNEL_WEIGHT_DECAY),
use_positional_encoding: bool = False, use_positional_encoding: bool = False,
state_prefix: Optional[str] = None, state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Implementation for squeeze and excitation. """Implementation for squeeze and excitation.
...@@ -917,7 +917,7 @@ class SkipBlock(tf.keras.layers.Layer): ...@@ -917,7 +917,7 @@ class SkipBlock(tf.keras.layers.Layer):
batch_norm_layer: tf.keras.layers.Layer = batch_norm_layer: tf.keras.layers.Layer =
tf.keras.layers.BatchNormalization, tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99, batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3, batch_norm_epsilon: float = 1e-3, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Implementation for skip block. """Implementation for skip block.
...@@ -1035,7 +1035,7 @@ class MovinetBlock(tf.keras.layers.Layer): ...@@ -1035,7 +1035,7 @@ class MovinetBlock(tf.keras.layers.Layer):
tf.keras.layers.BatchNormalization, tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99, batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3, batch_norm_epsilon: float = 1e-3,
state_prefix: Optional[str] = None, state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Implementation for MoViNet block. """Implementation for MoViNet block.
...@@ -1235,7 +1235,7 @@ class Stem(tf.keras.layers.Layer): ...@@ -1235,7 +1235,7 @@ class Stem(tf.keras.layers.Layer):
tf.keras.layers.BatchNormalization, tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99, batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3, batch_norm_epsilon: float = 1e-3,
state_prefix: Optional[str] = None, state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Implementation for video model stem. """Implementation for video model stem.
...@@ -1343,7 +1343,7 @@ class Head(tf.keras.layers.Layer): ...@@ -1343,7 +1343,7 @@ class Head(tf.keras.layers.Layer):
tf.keras.layers.BatchNormalization, tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99, batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3, batch_norm_epsilon: float = 1e-3,
state_prefix: Optional[str] = None, state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Implementation for video model head. """Implementation for video model head.
...@@ -1442,7 +1442,7 @@ class ClassifierHead(tf.keras.layers.Layer): ...@@ -1442,7 +1442,7 @@ class ClassifierHead(tf.keras.layers.Layer):
max_pool_predictions: bool = False, max_pool_predictions: bool = False,
kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal', kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] =
tf.keras.regularizers.L2(KERNEL_WEIGHT_DECAY), tf.keras.regularizers.L2(KERNEL_WEIGHT_DECAY), # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Implementation for video model classifier head. """Implementation for video model classifier head.
......
...@@ -93,7 +93,7 @@ class MovinetClassifier(tf.keras.Model): ...@@ -93,7 +93,7 @@ class MovinetClassifier(tf.keras.Model):
backbone: tf.keras.Model, backbone: tf.keras.Model,
input_specs: Mapping[str, tf.keras.layers.InputSpec], input_specs: Mapping[str, tf.keras.layers.InputSpec],
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None, state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ ) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras
str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]: str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]:
"""Builds the model network. """Builds the model network.
......
...@@ -27,7 +27,7 @@ from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_ma ...@@ -27,7 +27,7 @@ from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_ma
def build_panoptic_maskrcnn( def build_panoptic_maskrcnn(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN, model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Panoptic Mask R-CNN model. """Builds Panoptic Mask R-CNN model.
This factory function builds the mask rcnn first, builds the non-shared This factory function builds the mask rcnn first, builds the non-shared
......
...@@ -48,7 +48,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -48,7 +48,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
max_level: Optional[int] = None, max_level: Optional[int] = None,
num_scales: Optional[int] = None, num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None, aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None, anchor_size: Optional[float] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Initializes the Panoptic Mask R-CNN model. """Initializes the Panoptic Mask R-CNN model.
......
...@@ -92,7 +92,7 @@ def build_decoder( ...@@ -92,7 +92,7 @@ def build_decoder(
input_specs: Mapping[str, tf.TensorShape], input_specs: Mapping[str, tf.TensorShape],
model_config: hyperparams.Config, model_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None, l2_regularizer: tf.keras.regularizers.Regularizer = None,
**kwargs) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]: **kwargs) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds decoder from a config. """Builds decoder from a config.
Args: Args:
......
...@@ -45,7 +45,7 @@ class UNet3DDecoder(tf.keras.Model): ...@@ -45,7 +45,7 @@ class UNet3DDecoder(tf.keras.Model):
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001,
use_sync_bn: bool = False, use_sync_bn: bool = False,
use_batch_normalization: bool = False, use_batch_normalization: bool = False,
use_deconvolution: bool = False, use_deconvolution: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""3D UNet decoder initialization function. """3D UNet decoder initialization function.
......
...@@ -28,7 +28,7 @@ from official.vision.beta.projects.volumetric_models.modeling.heads import segme ...@@ -28,7 +28,7 @@ from official.vision.beta.projects.volumetric_models.modeling.heads import segme
def build_segmentation_model_3d( def build_segmentation_model_3d(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config, model_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Segmentation model.""" """Builds Segmentation model."""
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
backbone = backbone_factory.build_backbone( backbone = backbone_factory.build_backbone(
......
...@@ -37,7 +37,7 @@ class SegmentationHead3D(tf.keras.layers.Layer): ...@@ -37,7 +37,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
use_batch_normalization: bool = False, use_batch_normalization: bool = False,
kernel_regularizer: tf.keras.regularizers.Regularizer = None, kernel_regularizer: tf.keras.regularizers.Regularizer = None,
bias_regularizer: tf.keras.regularizers.Regularizer = None, bias_regularizer: tf.keras.regularizers.Regularizer = None,
output_logits: bool = True, output_logits: bool = True, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Initialize params to build segmentation head. """Initialize params to build segmentation head.
......
...@@ -38,7 +38,7 @@ class BasicBlock3DVolume(tf.keras.layers.Layer): ...@@ -38,7 +38,7 @@ class BasicBlock3DVolume(tf.keras.layers.Layer):
use_sync_bn: bool = False, use_sync_bn: bool = False,
norm_momentum: float = 0.99, norm_momentum: float = 0.99,
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001,
use_batch_normalization: bool = False, use_batch_normalization: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
"""Creates a basic 3d convolution block applying one or more convolutions. """Creates a basic 3d convolution block applying one or more convolutions.
......
...@@ -664,7 +664,7 @@ def build_darknet( ...@@ -664,7 +664,7 @@ def build_darknet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
backbone_cfg: hyperparams.Config, backbone_cfg: hyperparams.Config,
norm_activation_config: hyperparams.Config, norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds darknet.""" """Builds darknet."""
backbone_cfg = backbone_cfg.get() backbone_cfg = backbone_cfg.get()
......
...@@ -244,7 +244,7 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint): ...@@ -244,7 +244,7 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
if self.update_weights: if self.update_weights:
self.model.optimizer.assign_average_vars(self.model.variables) self.model.optimizer.assign_average_vars(self.model.variables)
return super()._save_model(epoch, logs) return super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
else: else:
# Note: `model.get_weights()` gives us the weights (non-ref) # Note: `model.get_weights()` gives us the weights (non-ref)
# whereas `model.variables` returns references to the variables. # whereas `model.variables` returns references to the variables.
...@@ -252,6 +252,6 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint): ...@@ -252,6 +252,6 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
self.model.optimizer.assign_average_vars(self.model.variables) self.model.optimizer.assign_average_vars(self.model.variables)
# result is currently None, since `super._save_model` doesn't # result is currently None, since `super._save_model` doesn't
# return anything, but this may change in the future. # return anything, but this may change in the future.
result = super()._save_model(epoch, logs) result = super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
self.model.set_weights(non_avg_weights) self.model.set_weights(non_avg_weights)
return result return result
...@@ -87,7 +87,7 @@ def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization: ...@@ -87,7 +87,7 @@ def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization:
if batch_norm_type == 'tpu': if batch_norm_type == 'tpu':
return TpuBatchNormalization return TpuBatchNormalization
return tf.keras.layers.BatchNormalization return tf.keras.layers.BatchNormalization # pytype: disable=bad-return-type # typed-keras
def count_params(model, trainable_only=True): def count_params(model, trainable_only=True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment