Commit 1407e274 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Small changes in video classification task.

PiperOrigin-RevId: 335548133
parent 75c6e3bc
...@@ -192,6 +192,7 @@ class Parser(parser.Parser): ...@@ -192,6 +192,7 @@ class Parser(parser.Parser):
self._num_classes = input_params.num_classes self._num_classes = input_params.num_classes
self._image_key = image_key self._image_key = image_key
self._label_key = label_key self._label_key = label_key
self._dtype = tf.dtypes.as_dtype(input_params.dtype)
def _parse_train_data( def _parse_train_data(
self, decoded_tensors: Dict[str, tf.Tensor] self, decoded_tensors: Dict[str, tf.Tensor]
...@@ -208,6 +209,7 @@ class Parser(parser.Parser): ...@@ -208,6 +209,7 @@ class Parser(parser.Parser):
num_test_clips=self._num_test_clips, num_test_clips=self._num_test_clips,
min_resize=self._min_resize, min_resize=self._min_resize,
crop_size=self._crop_size) crop_size=self._crop_size)
image = tf.cast(image, dtype=self._dtype)
label = _process_label(label, self._one_hot_label, self._num_classes) label = _process_label(label, self._one_hot_label, self._num_classes)
return {'image': image}, label return {'image': image}, label
...@@ -226,6 +228,7 @@ class Parser(parser.Parser): ...@@ -226,6 +228,7 @@ class Parser(parser.Parser):
num_test_clips=self._num_test_clips, num_test_clips=self._num_test_clips,
min_resize=self._min_resize, min_resize=self._min_resize,
crop_size=self._crop_size) crop_size=self._crop_size)
image = tf.cast(image, dtype=self._dtype)
label = _process_label(label, self._one_hot_label, self._num_classes) label = _process_label(label, self._one_hot_label, self._num_classes)
return {'image': image}, label return {'image': image}, label
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Video classification task definition.""" """Video classification task definition."""
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.core import base_task from official.core import base_task
from official.core import input_reader from official.core import input_reader
...@@ -30,7 +31,13 @@ class VideoClassificationTask(base_task.Task): ...@@ -30,7 +31,13 @@ class VideoClassificationTask(base_task.Task):
def build_model(self): def build_model(self):
"""Builds video classification model.""" """Builds video classification model."""
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, None, 3]) common_input_shape = [
d1 if d1 == d2 else None
for d1, d2 in zip(self.task_config.train_data.feature_shape,
self.task_config.validation_data.feature_shape)
]
input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
logging.info('Build model input %r', common_input_shape)
l2_weight_decay = self.task_config.losses.l2_weight_decay l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss. # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment