"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "b59551d88d7181c93825ce7224afcddf70c89e23"
Commit e39241d6 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398095616
parent b883ceb8
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
# Lint as: python3 # Lint as: python3
"""Video classification configuration definition.""" """Video classification configuration definition."""
from typing import Optional, Tuple
import dataclasses import dataclasses
from typing import Optional, Tuple
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
...@@ -121,6 +121,7 @@ class VideoClassificationModel(hyperparams.Config): ...@@ -121,6 +121,7 @@ class VideoClassificationModel(hyperparams.Config):
use_sync_bn=False) use_sync_bn=False)
dropout_rate: float = 0.2 dropout_rate: float = 0.2
aggregate_endpoints: bool = False aggregate_endpoints: bool = False
require_endpoints: Optional[Tuple[str, ...]] = None
@dataclasses.dataclass @dataclasses.dataclass
...@@ -146,6 +147,10 @@ class VideoClassificationTask(cfg.TaskConfig): ...@@ -146,6 +147,10 @@ class VideoClassificationTask(cfg.TaskConfig):
metrics: Metrics = Metrics() metrics: Metrics = Metrics()
init_checkpoint: Optional[str] = None init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone init_checkpoint_modules: str = 'all' # all or backbone
# Spatial Partitioning fields. See go/tf2-spatial-partition-api-examples
# for explanation of the technique.
train_input_partition_dims: Optional[Tuple[int, ...]] = None
eval_input_partition_dims: Optional[Tuple[int, ...]] = None
def add_trainer(experiment: cfg.ExperimentConfig, def add_trainer(experiment: cfg.ExperimentConfig,
......
...@@ -98,5 +98,6 @@ def build_video_classification_model( ...@@ -98,5 +98,6 @@ def build_video_classification_model(
input_specs=input_specs_dict, input_specs=input_specs_dict,
dropout_rate=model_config.dropout_rate, dropout_rate=model_config.dropout_rate,
aggregate_endpoints=model_config.aggregate_endpoints, aggregate_endpoints=model_config.aggregate_endpoints,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer,
require_endpoints=model_config.require_endpoints)
return model return model
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
"""Build video classification models.""" """Build video classification models."""
from typing import Any, Mapping, Optional, Union from typing import Any, Mapping, Optional, Union, List, Text
import tensorflow as tf import tensorflow as tf
layers = tf.keras.layers layers = tf.keras.layers
...@@ -33,6 +34,7 @@ class VideoClassificationModel(tf.keras.Model): ...@@ -33,6 +34,7 @@ class VideoClassificationModel(tf.keras.Model):
kernel_initializer: str = 'random_uniform', kernel_initializer: str = 'random_uniform',
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
require_endpoints: Optional[List[Text]] = None,
**kwargs): **kwargs):
"""Video Classification initialization function. """Video Classification initialization function.
...@@ -48,6 +50,8 @@ class VideoClassificationModel(tf.keras.Model): ...@@ -48,6 +50,8 @@ class VideoClassificationModel(tf.keras.Model):
None. None.
bias_regularizer: tf.keras.regularizers.Regularizer object. Default to bias_regularizer: tf.keras.regularizers.Regularizer object. Default to
None. None.
require_endpoints: the required endpoints for prediction. If None or
empty, then only uses the final endpoint.
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
if not input_specs: if not input_specs:
...@@ -64,6 +68,7 @@ class VideoClassificationModel(tf.keras.Model): ...@@ -64,6 +68,7 @@ class VideoClassificationModel(tf.keras.Model):
'kernel_initializer': kernel_initializer, 'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer, 'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer, 'bias_regularizer': bias_regularizer,
'require_endpoints': require_endpoints,
} }
self._input_specs = input_specs self._input_specs = input_specs
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
...@@ -82,8 +87,18 @@ class VideoClassificationModel(tf.keras.Model): ...@@ -82,8 +87,18 @@ class VideoClassificationModel(tf.keras.Model):
pooled_feats.append(x_pool) pooled_feats.append(x_pool)
x = tf.concat(pooled_feats, axis=1) x = tf.concat(pooled_feats, axis=1)
else: else:
x = endpoints[max(endpoints.keys())] if not require_endpoints:
x = tf.keras.layers.GlobalAveragePooling3D()(x) # Uses the last endpoint for prediction.
x = endpoints[max(endpoints.keys())]
x = tf.keras.layers.GlobalAveragePooling3D()(x)
else:
# Concats all the required endpoints for prediction.
outputs = []
for name in require_endpoints:
x = endpoints[name]
x = tf.keras.layers.GlobalAveragePooling3D()(x)
outputs.append(x)
x = tf.concat(outputs, axis=1)
x = tf.keras.layers.Dropout(dropout_rate)(x) x = tf.keras.layers.Dropout(dropout_rate)(x)
x = tf.keras.layers.Dense( x = tf.keras.layers.Dense(
......
...@@ -255,6 +255,11 @@ class VideoClassificationTask(base_task.Task): ...@@ -255,6 +255,11 @@ class VideoClassificationTask(base_task.Task):
A dictionary of logs. A dictionary of logs.
""" """
features, labels = inputs features, labels = inputs
input_partition_dims = self.task_config.train_input_partition_dims
if input_partition_dims:
strategy = tf.distribute.get_strategy()
features['image'] = strategy.experimental_split_to_logical_devices(
features['image'], input_partition_dims)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
...@@ -314,6 +319,11 @@ class VideoClassificationTask(base_task.Task): ...@@ -314,6 +319,11 @@ class VideoClassificationTask(base_task.Task):
A dictionary of logs. A dictionary of logs.
""" """
features, labels = inputs features, labels = inputs
input_partition_dims = self.task_config.eval_input_partition_dims
if input_partition_dims:
strategy = tf.distribute.get_strategy()
features['image'] = strategy.experimental_split_to_logical_devices(
features['image'], input_partition_dims)
outputs = self.inference_step(features, model) outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs) outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment