Commit 499b61bf authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

internal change.

PiperOrigin-RevId: 422665603
parent 06d2681c
...@@ -88,14 +88,13 @@ class MovinetClassifier(tf.keras.Model): ...@@ -88,14 +88,13 @@ class MovinetClassifier(tf.keras.Model):
# Move backbone after super() call so Keras is happy # Move backbone after super() call so Keras is happy
self._backbone = backbone self._backbone = backbone
def _build_network( def _build_backbone(
self, self,
backbone: tf.keras.Model, backbone: tf.keras.Model,
input_specs: Mapping[str, tf.keras.layers.InputSpec], input_specs: Mapping[str, tf.keras.layers.InputSpec],
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None, state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras ) -> Tuple[Mapping[str, Any], Any, Any]:
str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]: """Builds the backbone network and gets states and endpoints.
"""Builds the model network.
Args: Args:
backbone: the model backbone. backbone: the model backbone.
...@@ -104,9 +103,9 @@ class MovinetClassifier(tf.keras.Model): ...@@ -104,9 +103,9 @@ class MovinetClassifier(tf.keras.Model):
layer, will overwrite the contents of the buffer(s). layer, will overwrite the contents of the buffer(s).
Returns: Returns:
Inputs and outputs as a tuple. Inputs are expected to be a dict with inputs: a dict of input specs.
base input and states. Outputs are expected to be a dict of endpoints endpoints: a dict of model endpoints.
and (optionally) output states. states: a dict of model states.
""" """
state_specs = state_specs if state_specs is not None else {} state_specs = state_specs if state_specs is not None else {}
...@@ -145,7 +144,30 @@ class MovinetClassifier(tf.keras.Model): ...@@ -145,7 +144,30 @@ class MovinetClassifier(tf.keras.Model):
mismatched_shapes)) mismatched_shapes))
else: else:
endpoints, states = backbone(inputs) endpoints, states = backbone(inputs)
return inputs, endpoints, states
def _build_network(
self,
backbone: tf.keras.Model,
input_specs: Mapping[str, tf.keras.layers.InputSpec],
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras
str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]:
"""Builds the model network.
Args:
backbone: the model backbone.
input_specs: the model input spec to use.
state_specs: a dict of states such that, if any of the keys match for a
layer, will overwrite the contents of the buffer(s).
Returns:
Inputs and outputs as a tuple. Inputs are expected to be a dict with
base input and states. Outputs are expected to be a dict of endpoints
and (optionally) output states.
"""
inputs, endpoints, states = self._build_backbone(
backbone=backbone, input_specs=input_specs, state_specs=state_specs)
x = endpoints['head'] x = endpoints['head']
x = movinet_layers.ClassifierHead( x = movinet_layers.ClassifierHead(
......
...@@ -46,7 +46,8 @@ from official.modeling import performance ...@@ -46,7 +46,8 @@ from official.modeling import performance
# Import movinet libraries to register the backbone and model into tf.vision # Import movinet libraries to register the backbone and model into tf.vision
# model garden factory. # model garden factory.
# pylint: disable=unused-import # pylint: disable=unused-import
# the followings are the necessary imports. from official.projects.movinet.google.configs import movinet_google
from official.projects.movinet.google.modeling import movinet_model_google
from official.projects.movinet.modeling import movinet from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model from official.projects.movinet.modeling import movinet_model
# pylint: enable=unused-import # pylint: enable=unused-import
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment