"docs/vscode:/vscode.git/clone" did not exist on "a64c6749b53fbd423e73ff42bf82bdecf6e0a219"
Commit 1fd7aaaf authored by Yeqing Li, committed by A. Unique TensorFlower

Makes the name of the label field read from tf.SequenceExample **configurable**.

PiperOrigin-RevId: 446008537
parent 43d232e5
@@ -45,6 +45,7 @@ class DataConfig(cfg.DataConfig):
     feature_sources: whether each feature is from 'context' or 'features'.
     feature_dtypes: dtype of decoded feature.
     feature_from_bytes: decode feature from bytes or as dtype list.
+    label_field: name of the label field to read from tf.SequenceExample.
     segment_size: Number of frames in each segment.
     segment_labels: Use segment level label. Default: False, video level label.
     include_video_id: `True` means include video id (string) in the input to
@@ -70,6 +71,7 @@ class DataConfig(cfg.DataConfig):
   feature_sources: Tuple[str, ...] = ('feature', 'feature')
   feature_dtypes: Tuple[str, ...] = ('uint8', 'uint8')
   feature_from_bytes: Tuple[bool, ...] = (True, True)
+  label_field: str = 'labels'
   segment_size: int = 1
   segment_labels: bool = False
   include_video_id: bool = False
...
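To illustrate what the new knob buys, here is a minimal, self-contained sketch of configuring it. The stand-in class below reproduces only the fields relevant to label reading from the patched `DataConfig` above, not its real import path:

```python
import dataclasses

# Minimal stand-in for the patched DataConfig above; only the fields
# relevant to label reading are reproduced here.
@dataclasses.dataclass
class DataConfig:
  label_field: str = 'labels'   # context feature holding video-level labels
  segment_labels: bool = False  # when True, segment-level fields are used instead

# Default behavior is unchanged: labels come from the "labels" context feature.
default_params = DataConfig()

# With this change, data whose labels live under another key can be read
# without modifying the decoder.
custom_params = DataConfig(label_field='clip/label/index')
```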
@@ -248,7 +248,9 @@ def MakeExampleWithFloatFeatures(
   seq_example = tf.train.SequenceExample()
   seq_example.context.feature["id"].bytes_list.value[:] = [b"id001"]
-  seq_example.context.feature["labels"].int64_list.value[:] = [1, 2, 3, 4]
+  seq_example.context.feature["clip/label/index"].int64_list.value[:] = [
+      1, 2, 3, 4
+  ]
   seq_example.context.feature["segment_labels"].int64_list.value[:] = (
       [4] * num_segment)
   seq_example.context.feature["segment_start_times"].int64_list.value[:] = [
...
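For reference, a runnable sketch of the kind of `tf.SequenceExample` the updated test helper builds, with video-level labels stored under the custom `clip/label/index` context key:

```python
import tensorflow as tf

# Build a SequenceExample whose video-level labels live under a custom
# context key, mirroring MakeExampleWithFloatFeatures above.
seq_example = tf.train.SequenceExample()
seq_example.context.feature['id'].bytes_list.value[:] = [b'id001']
seq_example.context.feature['clip/label/index'].int64_list.value[:] = [1, 2, 3, 4]

serialized = seq_example.SerializeToString()
print(seq_example.context.feature['clip/label/index'].int64_list.value)  # [1, 2, 3, 4]
```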
@@ -251,6 +251,7 @@ class Decoder(decoder.Decoder):
     self._feature_dtypes = input_params.feature_dtypes
     self._feature_from_bytes = input_params.feature_from_bytes
     self._include_video_id = input_params.include_video_id
+    self._label_field = input_params.label_field
     assert len(self._feature_names) == len(self._feature_sources), (
         "length of feature_names (={}) != length of feature_sources (={})".format(
@@ -270,7 +271,8 @@ class Decoder(decoder.Decoder):
           "segment_scores": tf.io.VarLenFeature(tf.float32)
       })
     else:
-      self._context_features.update({"labels": tf.io.VarLenFeature(tf.int64)})
+      self._context_features.update(
+          {self._label_field: tf.io.VarLenFeature(tf.int64)})

     for i, name in enumerate(self._feature_names):
       if self._feature_from_bytes[i]:
@@ -308,6 +310,8 @@ class Decoder(decoder.Decoder):
       else:
         if isinstance(decoded_tensor[name], tf.SparseTensor):
           decoded_tensor[name] = tf.sparse.to_dense(decoded_tensor[name])
+    if not self._segment_labels:
+      decoded_tensor["labels"] = decoded_tensor[self._label_field]
     return decoded_tensor
...
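Condensed, the decoder change amounts to parsing the configurable context key and re-exposing the result under the fixed `labels` key that downstream code expects. A self-contained sketch of that flow (the real `Decoder` handles many more features; names here follow the diff):

```python
import tensorflow as tf

LABEL_FIELD = 'clip/label/index'  # stands in for input_params.label_field

context_features = {
    'id': tf.io.FixedLenFeature([], tf.string),
    LABEL_FIELD: tf.io.VarLenFeature(tf.int64),
}

def decode(serialized):
  context, _ = tf.io.parse_single_sequence_example(
      serialized, context_features=context_features)
  labels = tf.sparse.to_dense(context[LABEL_FIELD])
  # Downstream code keeps reading "labels", so the decoded value is
  # re-exposed under that fixed key, as in the diff above.
  return {'id': context['id'], 'labels': labels}
```

Feeding the serialized example from the previous sketch through `decode` yields `labels == [1, 2, 3, 4]`.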
...@@ -16,6 +16,7 @@ import os ...@@ -16,6 +16,7 @@ import os
from absl import logging from absl import logging
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np
import tensorflow as tf import tensorflow as tf
from official.core import input_reader from official.core import input_reader
...@@ -161,17 +162,25 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase): ...@@ -161,17 +162,25 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
else: else:
self.assertCountEqual(['video_matrix', 'labels', 'num_frames'], self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
example.keys()) example.keys())
batch_size = params.global_batch_size
# Check tensor values.
expected_context = examples[0].context.feature[ expected_context = examples[0].context.feature[
'VIDEO_EMBEDDING/context_feature/floats'].float_list.value 'VIDEO_EMBEDDING/context_feature/floats'].float_list.value
expected_feature = examples[0].feature_lists.feature_list[ expected_feature = examples[0].feature_lists.feature_list[
'FEATURE/feature/floats'].feature[0].float_list.value 'FEATURE/feature/floats'].feature[0].float_list.value
expected_labels = examples[0].context.feature[
params.label_field].int64_list.value
self.assertAllEqual( self.assertAllEqual(
expected_feature, expected_feature,
example['video_matrix'][0, 0, params.feature_sizes[0]:]) example['video_matrix'][0, 0, params.feature_sizes[0]:])
self.assertAllEqual( self.assertAllEqual(
expected_context, expected_context,
example['video_matrix'][0, 0, :params.feature_sizes[0]]) example['video_matrix'][0, 0, :params.feature_sizes[0]])
self.assertAllEqual(
np.nonzero(example['labels'][0, :].numpy())[0], expected_labels)
# Check tensor shape.
batch_size = params.global_batch_size
self.assertEqual( self.assertEqual(
example['video_matrix'].shape.as_list(), example['video_matrix'].shape.as_list(),
[batch_size, params.max_frames, sum(params.feature_sizes)]) [batch_size, params.max_frames, sum(params.feature_sizes)])
......
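The new assertion recovers class indices from the batched multi-hot label tensor; a tiny numpy illustration of that `np.nonzero` round-trip (the 8-class width is arbitrary):

```python
import numpy as np

# One batch row of a multi-hot label tensor with classes {1, 2, 3, 4} set,
# matching the labels written into the test example above.
labels_row = np.array([0, 1, 1, 1, 1, 0, 0, 0])

# np.nonzero returns a tuple of index arrays; [0] selects the first axis.
assert np.nonzero(labels_row)[0].tolist() == [1, 2, 3, 4]
```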