Commit b63376e6 authored by Rebecca Chen's avatar Rebecca Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398612420
parent c5ae4110
...@@ -121,7 +121,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta): ...@@ -121,7 +121,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
Returns: Returns:
A model instance. A model instance.
""" """ # pytype: disable=bad-return-type # typed-keras
@abc.abstractmethod @abc.abstractmethod
def build_inputs(self, def build_inputs(self,
......
...@@ -69,7 +69,7 @@ class ProgressivePolicy: ...@@ -69,7 +69,7 @@ class ProgressivePolicy:
shape=[]) shape=[])
self._volatiles.reassign_trackable( self._volatiles.reassign_trackable(
optimizer=self.get_optimizer(stage_id), optimizer=self.get_optimizer(stage_id),
model=self.get_model(stage_id, old_model=None)) model=self.get_model(stage_id, old_model=None)) # pytype: disable=wrong-arg-types # typed-keras
streamz_counters.progressive_policy_creation_counter.get_cell( streamz_counters.progressive_policy_creation_counter.get_cell(
).increase_by(1) ).increase_by(1)
...@@ -96,7 +96,7 @@ class ProgressivePolicy: ...@@ -96,7 +96,7 @@ class ProgressivePolicy:
@abc.abstractmethod @abc.abstractmethod
def get_model(self, def get_model(self,
stage_id: int, stage_id: int,
old_model: tf.keras.Model = None) -> tf.keras.Model: old_model: tf.keras.Model = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Return model for this stage. For initialization, `old_model` = None.""" """Return model for this stage. For initialization, `old_model` = None."""
pass pass
......
...@@ -252,6 +252,6 @@ def run_experiment_with_multitask_eval( ...@@ -252,6 +252,6 @@ def run_experiment_with_multitask_eval(
if run_post_eval: if run_post_eval:
return model, evaluator.evaluate( return model, evaluator.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps)) tf.convert_to_tensor(params.trainer.validation_steps)) # pytype: disable=bad-return-type # typed-keras
else: else:
return model, {} return model, {} # pytype: disable=bad-return-type # typed-keras
...@@ -181,7 +181,7 @@ class XLNetClassifier(tf.keras.Model): ...@@ -181,7 +181,7 @@ class XLNetClassifier(tf.keras.Model):
initializer: tf.keras.initializers.Initializer = 'random_normal', initializer: tf.keras.initializers.Initializer = 'random_normal',
summary_type: str = 'last', summary_type: str = 'last',
dropout_rate: float = 0.1, dropout_rate: float = 0.1,
head_name: str = 'sentence_prediction', head_name: str = 'sentence_prediction', # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._network = network self._network = network
...@@ -271,7 +271,7 @@ class XLNetSpanLabeler(tf.keras.Model): ...@@ -271,7 +271,7 @@ class XLNetSpanLabeler(tf.keras.Model):
end_n_top: int = 5, end_n_top: int = 5,
dropout_rate: float = 0.1, dropout_rate: float = 0.1,
span_labeling_activation: tf.keras.initializers.Initializer = 'tanh', span_labeling_activation: tf.keras.initializers.Initializer = 'tanh',
initializer: tf.keras.initializers.Initializer = 'glorot_uniform', initializer: tf.keras.initializers.Initializer = 'glorot_uniform', # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs): **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._config = { self._config = {
......
...@@ -232,7 +232,7 @@ class EmbeddingPostprocessor(tf.keras.layers.Layer): ...@@ -232,7 +232,7 @@ class EmbeddingPostprocessor(tf.keras.layers.Layer):
def __call__(self, word_embeddings, token_type_ids=None, **kwargs): def __call__(self, word_embeddings, token_type_ids=None, **kwargs):
inputs = tf_utils.pack_inputs([word_embeddings, token_type_ids]) inputs = tf_utils.pack_inputs([word_embeddings, token_type_ids])
return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs) return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs) # pytype: disable=attribute-error # typed-keras
def call(self, inputs): def call(self, inputs):
"""Implements call() for the layer.""" """Implements call() for the layer."""
......
...@@ -147,7 +147,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam): ...@@ -147,7 +147,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
config, custom_objects=custom_objects) config, custom_objects=custom_objects)
def _prepare_local(self, var_device, var_dtype, apply_state): def _prepare_local(self, var_device, var_dtype, apply_state):
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, # pytype: disable=attribute-error # typed-keras
apply_state) apply_state)
apply_state[(var_device, var_dtype)]['weight_decay_rate'] = tf.constant( apply_state[(var_device, var_dtype)]['weight_decay_rate'] = tf.constant(
self.weight_decay_rate, name='adam_weight_decay_rate') self.weight_decay_rate, name='adam_weight_decay_rate')
...@@ -197,14 +197,14 @@ class AdamWeightDecay(tf.keras.optimizers.Adam): ...@@ -197,14 +197,14 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
decay = self._decay_weights_op(var, lr_t, apply_state) decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]): with tf.control_dependencies([decay]):
return super(AdamWeightDecay, return super(AdamWeightDecay,
self)._resource_apply_dense(grad, var, **kwargs) self)._resource_apply_dense(grad, var, **kwargs) # pytype: disable=attribute-error # typed-keras
def _resource_apply_sparse(self, grad, var, indices, apply_state=None): def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state) decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]): with tf.control_dependencies([decay]):
return super(AdamWeightDecay, return super(AdamWeightDecay,
self)._resource_apply_sparse(grad, var, indices, **kwargs) self)._resource_apply_sparse(grad, var, indices, **kwargs) # pytype: disable=attribute-error # typed-keras
def get_config(self): def get_config(self):
config = super(AdamWeightDecay, self).get_config() config = super(AdamWeightDecay, self).get_config()
......
...@@ -265,7 +265,7 @@ class BASNetEncoder(tf.keras.Model): ...@@ -265,7 +265,7 @@ class BASNetEncoder(tf.keras.Model):
def build_basnet_encoder( def build_basnet_encoder(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config, model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds BASNet Encoder backbone from a config.""" """Builds BASNet Encoder backbone from a config."""
backbone_type = model_config.backbone.type backbone_type = model_config.backbone.type
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
......
...@@ -106,7 +106,7 @@ def build_decoder( ...@@ -106,7 +106,7 @@ def build_decoder(
input_specs: Mapping[str, tf.TensorShape], input_specs: Mapping[str, tf.TensorShape],
model_config: hyperparams.Config, model_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None, l2_regularizer: tf.keras.regularizers.Regularizer = None,
**kwargs) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]: **kwargs) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds decoder from a config. """Builds decoder from a config.
A decoder can be a keras.Model, a keras.layers.Layer, or None. If it is not A decoder can be a keras.Model, a keras.layers.Layer, or None. If it is not
......
...@@ -42,7 +42,7 @@ def build_classification_model( ...@@ -42,7 +42,7 @@ def build_classification_model(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: classification_cfg.ImageClassificationModel, model_config: classification_cfg.ImageClassificationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None, l2_regularizer: tf.keras.regularizers.Regularizer = None,
skip_logits_layer: bool = False) -> tf.keras.Model: skip_logits_layer: bool = False) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds the classification model.""" """Builds the classification model."""
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
...@@ -69,7 +69,7 @@ def build_classification_model( ...@@ -69,7 +69,7 @@ def build_classification_model(
def build_maskrcnn( def build_maskrcnn(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: maskrcnn_cfg.MaskRCNN, model_config: maskrcnn_cfg.MaskRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Mask R-CNN model.""" """Builds Mask R-CNN model."""
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
...@@ -252,7 +252,7 @@ def build_maskrcnn( ...@@ -252,7 +252,7 @@ def build_maskrcnn(
def build_retinanet( def build_retinanet(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: retinanet_cfg.RetinaNet, model_config: retinanet_cfg.RetinaNet,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds RetinaNet model.""" """Builds RetinaNet model."""
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
...@@ -319,7 +319,7 @@ def build_retinanet( ...@@ -319,7 +319,7 @@ def build_retinanet(
def build_segmentation_model( def build_segmentation_model(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
model_config: segmentation_cfg.SemanticSegmentationModel, model_config: segmentation_cfg.SemanticSegmentationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Segmentation model.""" """Builds Segmentation model."""
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
......
...@@ -35,7 +35,7 @@ from official.vision.beta.tasks import maskrcnn ...@@ -35,7 +35,7 @@ from official.vision.beta.tasks import maskrcnn
# Taken from modeling/factory.py # Taken from modeling/factory.py
def build_maskrcnn(input_specs: tf.keras.layers.InputSpec, def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN, model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN,
l2_regularizer: tf.keras.regularizers.Regularizer = None): l2_regularizer: tf.keras.regularizers.Regularizer = None): # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Mask R-CNN model.""" """Builds Mask R-CNN model."""
norm_activation_config = model_config.norm_activation norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
......
...@@ -46,7 +46,7 @@ class UNet3D(tf.keras.Model): ...@@ -46,7 +46,7 @@ class UNet3D(tf.keras.Model):
norm_momentum: float = 0.99, norm_momentum: float = 0.99,
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001,
use_sync_bn: bool = False, use_sync_bn: bool = False,
use_batch_normalization: bool = False, use_batch_normalization: bool = False, # type: ignore # typed-keras
**kwargs): **kwargs):
"""3D UNet backbone initialization function. """3D UNet backbone initialization function.
...@@ -156,7 +156,7 @@ def build_unet3d( ...@@ -156,7 +156,7 @@ def build_unet3d(
input_specs: tf.keras.layers.InputSpec, input_specs: tf.keras.layers.InputSpec,
backbone_config: hyperparams.Config, backbone_config: hyperparams.Config,
norm_activation_config: hyperparams.Config, norm_activation_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds 3D UNet backbone from a config.""" """Builds 3D UNet backbone from a config."""
backbone_type = backbone_config.type backbone_type = backbone_config.type
backbone_cfg = backbone_config.get() backbone_cfg = backbone_config.get()
......
Markdown is supported
Attach a file by dragging &amp; dropping or selecting one.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment