Commit 71ab0c31 authored by Rebecca Chen's avatar Rebecca Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398612564
parent b63376e6
......@@ -302,7 +302,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
return attention_output
def _build_from_signature(self, query, value, key=None):
super()._build_from_signature(query=query, value=value, key=key)
super()._build_from_signature(query=query, value=value, key=key) # pytype: disable=attribute-error # typed-keras
if self._begin_kernel > 0:
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
......
......@@ -120,7 +120,7 @@ class MultiChannelAttention(tf.keras.layers.MultiHeadAttention):
"""
def _build_attention(self, rank):
  """Builds the dot-product attention sublayers with a channel-aware softmax.

  Delegates construction of the attention computation to the Keras
  `MultiHeadAttention` base class, then replaces the default softmax with a
  `MaskedSoftmax` that expands the attention mask on axis 2 — presumably the
  channel axis of multi-channel inputs (TODO: confirm against callers).

  Args:
    rank: Rank of the query tensor, used by the base class to build the
      einsum equations for the attention scores.
  """
  # NOTE: the scraped diff showed this call duplicated (pre- and post-change
  # lines rendered together); only the single, pytype-annotated call belongs
  # in the source — invoking it twice would rebuild the sublayers twice.
  # pytype cannot resolve the private base-class method on the tf.keras
  # re-export, hence the suppression.
  super(MultiChannelAttention, self)._build_attention(rank)  # pytype: disable=attribute-error # typed-keras
  self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self,
......
......@@ -114,7 +114,7 @@ class EfficientNet(tf.keras.Model):
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
norm_epsilon: float = 0.001, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Initializes an EfficientNet model.
......@@ -299,7 +299,7 @@ def build_efficientnet(
input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds EfficientNet backbone from a config."""
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
......
......@@ -88,7 +88,7 @@ def build_backbone(input_specs: Union[tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None,
**kwargs) -> tf.keras.Model:
**kwargs) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds backbone from a config.
Args:
......
......@@ -407,7 +407,7 @@ def build_resnet(
input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds ResNet backbone from a config."""
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
......
......@@ -343,7 +343,7 @@ def build_dilated_resnet(
input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds ResNet backbone from a config."""
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
......
......@@ -216,7 +216,7 @@ def build_revnet(
input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds RevNet backbone from a config."""
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
......
......@@ -705,7 +705,7 @@ def build_movinet(
input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds MoViNet backbone from a config."""
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
......
......@@ -103,7 +103,7 @@ class MobileConv2D(tf.keras.layers.Layer):
bias_constraint: Optional[tf.keras.constraints.Constraint] = None,
use_depthwise: bool = False,
use_temporal: bool = False,
use_buffered_input: bool = False,
use_buffered_input: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): # pylint: disable=g-doc-args
"""Initializes mobile conv2d.
......@@ -270,7 +270,7 @@ class ConvBlock(tf.keras.layers.Layer):
batch_norm_epsilon: float = 1e-3,
activation: Optional[Any] = None,
conv_type: str = '3d',
use_buffered_input: bool = False,
use_buffered_input: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Initializes a conv block.
......@@ -553,7 +553,7 @@ class StreamConvBlock(ConvBlock):
batch_norm_epsilon: float = 1e-3,
activation: Optional[Any] = None,
conv_type: str = '3d',
state_prefix: Optional[str] = None,
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Initializes a stream conv block.
......@@ -678,7 +678,7 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = tf.keras
.regularizers.L2(KERNEL_WEIGHT_DECAY),
use_positional_encoding: bool = False,
state_prefix: Optional[str] = None,
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Implementation for squeeze and excitation.
......@@ -917,7 +917,7 @@ class SkipBlock(tf.keras.layers.Layer):
batch_norm_layer: tf.keras.layers.Layer =
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
batch_norm_epsilon: float = 1e-3, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Implementation for skip block.
......@@ -1035,7 +1035,7 @@ class MovinetBlock(tf.keras.layers.Layer):
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
state_prefix: Optional[str] = None,
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Implementation for MoViNet block.
......@@ -1235,7 +1235,7 @@ class Stem(tf.keras.layers.Layer):
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
state_prefix: Optional[str] = None,
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Implementation for video model stem.
......@@ -1343,7 +1343,7 @@ class Head(tf.keras.layers.Layer):
tf.keras.layers.BatchNormalization,
batch_norm_momentum: float = 0.99,
batch_norm_epsilon: float = 1e-3,
state_prefix: Optional[str] = None,
state_prefix: Optional[str] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Implementation for video model head.
......@@ -1442,7 +1442,7 @@ class ClassifierHead(tf.keras.layers.Layer):
max_pool_predictions: bool = False,
kernel_initializer: tf.keras.initializers.Initializer = 'HeNormal',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] =
tf.keras.regularizers.L2(KERNEL_WEIGHT_DECAY),
tf.keras.regularizers.L2(KERNEL_WEIGHT_DECAY), # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Implementation for video model classifier head.
......
......@@ -93,7 +93,7 @@ class MovinetClassifier(tf.keras.Model):
backbone: tf.keras.Model,
input_specs: Mapping[str, tf.keras.layers.InputSpec],
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras
str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]:
"""Builds the model network.
......
......@@ -27,7 +27,7 @@ from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_ma
def build_panoptic_maskrcnn(
input_specs: tf.keras.layers.InputSpec,
model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Panoptic Mask R-CNN model.
This factory function builds the mask rcnn first, builds the non-shared
......
......@@ -48,7 +48,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
max_level: Optional[int] = None,
num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None,
anchor_size: Optional[float] = None, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Initializes the Panoptic Mask R-CNN model.
......
......@@ -92,7 +92,7 @@ def build_decoder(
input_specs: Mapping[str, tf.TensorShape],
model_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None,
**kwargs) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]:
**kwargs) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds decoder from a config.
Args:
......
......@@ -45,7 +45,7 @@ class UNet3DDecoder(tf.keras.Model):
norm_epsilon: float = 0.001,
use_sync_bn: bool = False,
use_batch_normalization: bool = False,
use_deconvolution: bool = False,
use_deconvolution: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""3D UNet decoder initialization function.
......
......@@ -28,7 +28,7 @@ from official.vision.beta.projects.volumetric_models.modeling.heads import segme
def build_segmentation_model_3d(
input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Segmentation model."""
norm_activation_config = model_config.norm_activation
backbone = backbone_factory.build_backbone(
......
......@@ -37,7 +37,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
use_batch_normalization: bool = False,
kernel_regularizer: tf.keras.regularizers.Regularizer = None,
bias_regularizer: tf.keras.regularizers.Regularizer = None,
output_logits: bool = True,
output_logits: bool = True, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Initialize params to build segmentation head.
......
......@@ -38,7 +38,7 @@ class BasicBlock3DVolume(tf.keras.layers.Layer):
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
use_batch_normalization: bool = False,
use_batch_normalization: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Creates a basic 3d convolution block applying one or more convolutions.
......
......@@ -664,7 +664,7 @@ def build_darknet(
input_specs: tf.keras.layers.InputSpec,
backbone_cfg: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds darknet."""
backbone_cfg = backbone_cfg.get()
......
......@@ -244,7 +244,7 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
if self.update_weights:
self.model.optimizer.assign_average_vars(self.model.variables)
return super()._save_model(epoch, logs)
return super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
else:
# Note: `model.get_weights()` gives us the weights (non-ref)
# whereas `model.variables` returns references to the variables.
......@@ -252,6 +252,6 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
self.model.optimizer.assign_average_vars(self.model.variables)
# result is currently None, since `super._save_model` doesn't
# return anything, but this may change in the future.
result = super()._save_model(epoch, logs)
result = super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
self.model.set_weights(non_avg_weights)
return result
......@@ -87,7 +87,7 @@ def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization:
if batch_norm_type == 'tpu':
return TpuBatchNormalization
return tf.keras.layers.BatchNormalization
return tf.keras.layers.BatchNormalization # pytype: disable=bad-return-type # typed-keras
def count_params(model, trainable_only=True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment