Commit 1fd7aaaf authored by Yeqing Li, committed by A. Unique TensorFlower

Makes the name of the label field read from tf.SequenceExample **configurable** (see the usage sketch below).

PiperOrigin-RevId: 446008537
parent 43d232e5
......@@ -45,6 +45,7 @@ class DataConfig(cfg.DataConfig):
feature_sources: whether each feature is read from 'context' or 'features'.
feature_dtypes: dtype of decoded feature.
feature_from_bytes: decode feature from bytes or as dtype list.
label_field: name of the label field to read from tf.SequenceExample.
segment_size: Number of frames in each segment.
segment_labels: Use segment level label. Default: False, video level label.
include_video_id: `True` means include video id (string) in the input to
......@@ -70,6 +71,7 @@ class DataConfig(cfg.DataConfig):
feature_sources: Tuple[str, ...] = ('feature', 'feature')
feature_dtypes: Tuple[str, ...] = ('uint8', 'uint8')
feature_from_bytes: Tuple[bool, ...] = (True, True)
label_field: str = 'labels'
segment_size: int = 1
segment_labels: bool = False
include_video_id: bool = False
......
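
With the new field in place, a pipeline whose video-level labels live under a different context key can point the reader at it directly. A minimal sketch, assuming `DataConfig` is constructed as defined in the hunk above; the field name 'clip/label/index' is borrowed from the test fixture further down and stands in for any key present in the SequenceExample context:

# Illustrative usage, not part of the commit: construct the config with a
# non-default label field name. 'clip/label/index' matches the test fixture.
params = DataConfig(
    label_field='clip/label/index',  # context key holding the video-level labels
    segment_labels=False,            # video-level labels, the default mode
)
# Downstream, the Decoder remaps the configured key onto 'labels':
#   decoded_tensor['labels'] = decoded_tensor[params.label_field]
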
......@@ -248,7 +248,9 @@ def MakeExampleWithFloatFeatures(
seq_example = tf.train.SequenceExample()
seq_example.context.feature["id"].bytes_list.value[:] = [b"id001"]
seq_example.context.feature["labels"].int64_list.value[:] = [1, 2, 3, 4]
seq_example.context.feature["clip/label/index"].int64_list.value[:] = [
1, 2, 3, 4
]
seq_example.context.feature["segment_labels"].int64_list.value[:] = (
[4] * num_segment)
seq_example.context.feature["segment_start_times"].int64_list.value[:] = [
......
......@@ -251,6 +251,7 @@ class Decoder(decoder.Decoder):
self._feature_dtypes = input_params.feature_dtypes
self._feature_from_bytes = input_params.feature_from_bytes
self._include_video_id = input_params.include_video_id
self._label_field = input_params.label_field
assert len(self._feature_names) == len(self._feature_sources), (
"length of feature_names (={}) != length of feature_sizes (={})".format(
......@@ -270,7 +271,8 @@ class Decoder(decoder.Decoder):
"segment_scores": tf.io.VarLenFeature(tf.float32)
})
else:
self._context_features.update({"labels": tf.io.VarLenFeature(tf.int64)})
self._context_features.update(
{self._label_field: tf.io.VarLenFeature(tf.int64)})
for i, name in enumerate(self._feature_names):
if self._feature_from_bytes[i]:
......@@ -308,6 +310,8 @@ class Decoder(decoder.Decoder):
else:
if isinstance(decoded_tensor[name], tf.SparseTensor):
decoded_tensor[name] = tf.sparse.to_dense(decoded_tensor[name])
if not self._segment_labels:
decoded_tensor["labels"] = decoded_tensor[self._label_field]
return decoded_tensor
......
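
For reference, a self-contained sketch of what the decoder change amounts to, assuming a serialized tf.train.SequenceExample shaped like the test fixture above; `decode_labels` is an illustrative helper, not the project's actual decode method:

import tensorflow as tf

def decode_labels(serialized_example, label_field='clip/label/index'):
  """Parses the configured context key and exposes it as 'labels'."""
  context_features = {
      'id': tf.io.FixedLenFeature([], tf.string),
      label_field: tf.io.VarLenFeature(tf.int64),
  }
  contexts, _ = tf.io.parse_single_sequence_example(
      serialized_example, context_features=context_features)
  labels = tf.sparse.to_dense(contexts[label_field])
  # Mirror the diff: the configured field is surfaced under the generic key.
  return {'labels': labels}
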
......@@ -16,6 +16,7 @@ import os
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.core import input_reader
......@@ -161,17 +162,25 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
else:
self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
example.keys())
# Check tensor values.
expected_context = examples[0].context.feature[
'VIDEO_EMBEDDING/context_feature/floats'].float_list.value
expected_feature = examples[0].feature_lists.feature_list[
'FEATURE/feature/floats'].feature[0].float_list.value
expected_labels = examples[0].context.feature[
params.label_field].int64_list.value
self.assertAllEqual(
expected_feature,
example['video_matrix'][0, 0, params.feature_sizes[0]:])
self.assertAllEqual(
expected_context,
example['video_matrix'][0, 0, :params.feature_sizes[0]])
self.assertAllEqual(
np.nonzero(example['labels'][0, :].numpy())[0], expected_labels)
# Check tensor shape.
batch_size = params.global_batch_size
self.assertEqual(
example['video_matrix'].shape.as_list(),
[batch_size, params.max_frames, sum(params.feature_sizes)])
......