Commit e5fbe328 authored by Rebecca Chen's avatar Rebecca Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398612420
parent 58644b96
......@@ -121,7 +121,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
Returns:
A model instance.
"""
""" # pytype: disable=bad-return-type # typed-keras
@abc.abstractmethod
def build_inputs(self,
......
......@@ -69,7 +69,7 @@ class ProgressivePolicy:
shape=[])
self._volatiles.reassign_trackable(
optimizer=self.get_optimizer(stage_id),
model=self.get_model(stage_id, old_model=None))
model=self.get_model(stage_id, old_model=None)) # pytype: disable=wrong-arg-types # typed-keras
streamz_counters.progressive_policy_creation_counter.get_cell(
).increase_by(1)
......@@ -96,7 +96,7 @@ class ProgressivePolicy:
@abc.abstractmethod
def get_model(self,
stage_id: int,
old_model: tf.keras.Model = None) -> tf.keras.Model:
old_model: tf.keras.Model = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Return model for this stage. For initialization, `old_model` = None."""
pass
......
......@@ -252,6 +252,6 @@ def run_experiment_with_multitask_eval(
if run_post_eval:
return model, evaluator.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps))
tf.convert_to_tensor(params.trainer.validation_steps)) # pytype: disable=bad-return-type # typed-keras
else:
return model, {}
return model, {} # pytype: disable=bad-return-type # typed-keras
......@@ -181,7 +181,7 @@ class XLNetClassifier(tf.keras.Model):
initializer: tf.keras.initializers.Initializer = 'random_normal',
summary_type: str = 'last',
dropout_rate: float = 0.1,
head_name: str = 'sentence_prediction',
head_name: str = 'sentence_prediction', # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
super().__init__(**kwargs)
self._network = network
......@@ -271,7 +271,7 @@ class XLNetSpanLabeler(tf.keras.Model):
end_n_top: int = 5,
dropout_rate: float = 0.1,
span_labeling_activation: tf.keras.initializers.Initializer = 'tanh',
initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
initializer: tf.keras.initializers.Initializer = 'glorot_uniform', # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
super().__init__(**kwargs)
self._config = {
......
......@@ -232,7 +232,7 @@ class EmbeddingPostprocessor(tf.keras.layers.Layer):
def __call__(self, word_embeddings, token_type_ids=None, **kwargs):
inputs = tf_utils.pack_inputs([word_embeddings, token_type_ids])
return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs)
return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs) # pytype: disable=attribute-error # typed-keras
def call(self, inputs):
"""Implements call() for the layer."""
......
......@@ -147,7 +147,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
config, custom_objects=custom_objects)
def _prepare_local(self, var_device, var_dtype, apply_state):
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, # pytype: disable=attribute-error # typed-keras
apply_state)
apply_state[(var_device, var_dtype)]['weight_decay_rate'] = tf.constant(
self.weight_decay_rate, name='adam_weight_decay_rate')
......@@ -197,14 +197,14 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay,
self)._resource_apply_dense(grad, var, **kwargs)
self)._resource_apply_dense(grad, var, **kwargs) # pytype: disable=attribute-error # typed-keras
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay,
self)._resource_apply_sparse(grad, var, indices, **kwargs)
self)._resource_apply_sparse(grad, var, indices, **kwargs) # pytype: disable=attribute-error # typed-keras
def get_config(self):
config = super(AdamWeightDecay, self).get_config()
......
......@@ -265,7 +265,7 @@ class BASNetEncoder(tf.keras.Model):
def build_basnet_encoder(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds BASNet Encoder backbone from a config."""
backbone_type = model_config.backbone.type
norm_activation_config = model_config.norm_activation
......
......@@ -106,7 +106,7 @@ def build_decoder(
input_specs: Mapping[str, tf.TensorShape],
model_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None,
**kwargs) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]:
**kwargs) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds decoder from a config.
A decoder can be a keras.Model, a keras.layers.Layer, or None. If it is not
......
......@@ -42,7 +42,7 @@ def build_classification_model(
input_specs: tf.keras.layers.InputSpec,
model_config: classification_cfg.ImageClassificationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None,
skip_logits_layer: bool = False) -> tf.keras.Model:
skip_logits_layer: bool = False) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds the classification model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
......@@ -69,7 +69,7 @@ def build_classification_model(
def build_maskrcnn(
input_specs: tf.keras.layers.InputSpec,
model_config: maskrcnn_cfg.MaskRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Mask R-CNN model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
......@@ -252,7 +252,7 @@ def build_maskrcnn(
def build_retinanet(
input_specs: tf.keras.layers.InputSpec,
model_config: retinanet_cfg.RetinaNet,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds RetinaNet model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
......@@ -319,7 +319,7 @@ def build_retinanet(
def build_segmentation_model(
input_specs: tf.keras.layers.InputSpec,
model_config: segmentation_cfg.SemanticSegmentationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Segmentation model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
......
......@@ -35,7 +35,7 @@ from official.vision.beta.tasks import maskrcnn
# Taken from modeling/factory.py
def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
l2_regularizer: tf.keras.regularizers.Regularizer = None): # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Mask R-CNN model."""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
......
......@@ -46,7 +46,7 @@ class UNet3D(tf.keras.Model):
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
use_sync_bn: bool = False,
use_batch_normalization: bool = False,
use_batch_normalization: bool = False, # type: ignore # typed-keras
**kwargs):
"""3D UNet backbone initialization function.
......@@ -156,7 +156,7 @@ def build_unet3d(
input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds 3D UNet backbone from a config."""
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment