Commit e39241d6, authored by Yeqing Li and committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398095616
parent b883ceb8
......@@ -14,8 +14,8 @@
# Lint as: python3
"""Video classification configuration definition."""
from typing import Optional, Tuple
import dataclasses
from typing import Optional, Tuple
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
......@@ -121,6 +121,7 @@ class VideoClassificationModel(hyperparams.Config):
use_sync_bn=False)
dropout_rate: float = 0.2
aggregate_endpoints: bool = False
require_endpoints: Optional[Tuple[str, ...]] = None
@dataclasses.dataclass
......@@ -146,6 +147,10 @@ class VideoClassificationTask(cfg.TaskConfig):
metrics: Metrics = Metrics()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
# Spatial Partitioning fields. See go/tf2-spatial-partition-api-examples
# for explanation of the technique.
train_input_partition_dims: Optional[Tuple[int, ...]] = None
eval_input_partition_dims: Optional[Tuple[int, ...]] = None
def add_trainer(experiment: cfg.ExperimentConfig,
......
......@@ -98,5 +98,6 @@ def build_video_classification_model(
input_specs=input_specs_dict,
dropout_rate=model_config.dropout_rate,
aggregate_endpoints=model_config.aggregate_endpoints,
kernel_regularizer=l2_regularizer)
kernel_regularizer=l2_regularizer,
require_endpoints=model_config.require_endpoints)
return model
......@@ -13,7 +13,8 @@
# limitations under the License.
"""Build video classification models."""
from typing import Any, Mapping, Optional, Union
from typing import Any, Mapping, Optional, Union, List, Text
import tensorflow as tf
layers = tf.keras.layers
......@@ -33,6 +34,7 @@ class VideoClassificationModel(tf.keras.Model):
kernel_initializer: str = 'random_uniform',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
require_endpoints: Optional[List[Text]] = None,
**kwargs):
"""Video Classification initialization function.
......@@ -48,6 +50,8 @@ class VideoClassificationModel(tf.keras.Model):
None.
bias_regularizer: tf.keras.regularizers.Regularizer object. Default to
None.
require_endpoints: the required endpoints for prediction. If None or
empty, then only uses the final endpoint.
**kwargs: keyword arguments to be passed.
"""
if not input_specs:
......@@ -64,6 +68,7 @@ class VideoClassificationModel(tf.keras.Model):
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'require_endpoints': require_endpoints,
}
self._input_specs = input_specs
self._kernel_regularizer = kernel_regularizer
......@@ -82,8 +87,18 @@ class VideoClassificationModel(tf.keras.Model):
pooled_feats.append(x_pool)
x = tf.concat(pooled_feats, axis=1)
else:
x = endpoints[max(endpoints.keys())]
x = tf.keras.layers.GlobalAveragePooling3D()(x)
if not require_endpoints:
# Uses the last endpoint for prediction.
x = endpoints[max(endpoints.keys())]
x = tf.keras.layers.GlobalAveragePooling3D()(x)
else:
# Concats all the required endpoints for prediction.
outputs = []
for name in require_endpoints:
x = endpoints[name]
x = tf.keras.layers.GlobalAveragePooling3D()(x)
outputs.append(x)
x = tf.concat(outputs, axis=1)
x = tf.keras.layers.Dropout(dropout_rate)(x)
x = tf.keras.layers.Dense(
......
......@@ -255,6 +255,11 @@ class VideoClassificationTask(base_task.Task):
A dictionary of logs.
"""
features, labels = inputs
input_partition_dims = self.task_config.train_input_partition_dims
if input_partition_dims:
strategy = tf.distribute.get_strategy()
features['image'] = strategy.experimental_split_to_logical_devices(
features['image'], input_partition_dims)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
......@@ -314,6 +319,11 @@ class VideoClassificationTask(base_task.Task):
A dictionary of logs.
"""
features, labels = inputs
input_partition_dims = self.task_config.eval_input_partition_dims
if input_partition_dims:
strategy = tf.distribute.get_strategy()
features['image'] = strategy.experimental_split_to_logical_devices(
features['image'], input_partition_dims)
outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment