Commit d4f9d872 authored by Chaochao Yan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 476987858
parent eb5d69ae
@@ -178,7 +178,7 @@ def make_yt8m_example(num_segment: int = 5) -> tf.train.SequenceExample:
       i * 5 for i in range(num_segment)
   ]
   seq_example.context.feature["segment_scores"].float_list.value[:] = (
-      [0.] * num_segment)
+      [0.5] * num_segment)
   tfexample_utils.put_bytes_list_to_feature(
       seq_example, rgb.tobytes(), key="rgb", repeat_num=120)
   tfexample_utils.put_bytes_list_to_feature(
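The fixture change above swaps zero segment scores for 0.5 so that the labels derived from them are non-empty. As a minimal standalone sketch (not the repo's tfexample_utils helper), this is how a context float feature on a tf.train.SequenceExample is populated:

import tensorflow as tf

num_segment = 5
seq_example = tf.train.SequenceExample()
# Repeated proto fields support slice assignment; non-zero scores (0.5)
# are what make the parsed labels non-empty in the updated tests below.
seq_example.context.feature["segment_scores"].float_list.value[:] = (
    [0.5] * num_segment)
print(seq_example.context.feature["segment_scores"].float_list.value)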
@@ -295,12 +295,10 @@ class Decoder(decoder.Decoder):
   def decode(self,
              serialized_example: tf.train.SequenceExample) -> Dict[str, Any]:
     """Parses a single tf.train.SequenceExample into video and label tensors."""
     contexts, features = tf.io.parse_single_sequence_example(
         serialized_example,
         context_features=self._context_features,
         sequence_features=self._sequence_features)
     decoded_tensor = {**contexts, **features}
     for i, name in enumerate(self._feature_names):
       # Convert the VarLen feature to dense tensor.
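For reference, a self-contained sketch of the decode pattern in this hunk: parse one serialized SequenceExample and densify a VarLen feature. The feature names here ("labels", "rgb") are illustrative placeholders, not the module's exact schema:

import tensorflow as tf

# Build a tiny serialized SequenceExample to parse.
example = tf.train.SequenceExample()
example.context.feature["labels"].int64_list.value[:] = [1, 3]
example.feature_lists.feature_list["rgb"].feature.add().float_list.value[:] = (
    [0.1, 0.2])
serialized_example = example.SerializeToString()

contexts, features = tf.io.parse_single_sequence_example(
    serialized_example,
    context_features={"labels": tf.io.VarLenFeature(tf.int64)},
    sequence_features={"rgb": tf.io.VarLenFeature(tf.float32)})
decoded_tensor = {**contexts, **features}
# VarLen features parse as SparseTensor; densify as the loop above does.
decoded_tensor["rgb"] = tf.sparse.to_dense(decoded_tensor["rgb"])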
@@ -330,6 +328,7 @@ class Parser(parser.Parser):
       min_quantized_value=-2,
   ):
     self._num_classes = input_params.num_classes
+    self._label_field = input_params.label_field
     self._segment_size = input_params.segment_size
     self._segment_labels = input_params.segment_labels
     self._include_video_id = input_params.include_video_id
@@ -377,6 +376,8 @@ class Parser(parser.Parser):
     Returns:
       output: dictionary containing batch information
     """
+    if self._label_field and not self._segment_labels:
+      contexts["labels"] = contexts[self._label_field]
     output_dict = _process_segment_and_label(video_matrix, num_frames, contexts,
                                              self._segment_labels,
                                              self._segment_size,
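The two added lines give the parser a simple remapping: when a label_field is configured and segment labels are disabled, the context tensor stored under that field is re-exposed under the canonical "labels" key before _process_segment_and_label runs. A minimal sketch of that behavior, with hypothetical values:

import tensorflow as tf

label_field = 'clip/label/index'   # same field the updated test configures
segment_labels = False
contexts = {'clip/label/index': tf.constant([3, 7], dtype=tf.int64)}
if label_field and not segment_labels:
  contexts['labels'] = contexts[label_field]
# contexts['labels'] now aliases the configured field's label indices.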
@@ -80,11 +80,14 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
     self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
                           example.keys())
     batch_size = params.global_batch_size
-    self.assertEqual(
-        example['video_matrix'].shape.as_list(),
-        [batch_size, params.max_frames, sum(params.feature_sizes)])
+    self.assertEqual(example['video_matrix'].shape.as_list(),
+                     [batch_size, params.max_frames,
+                      sum(params.feature_sizes)])
     self.assertEqual(example['labels'].shape.as_list(),
                      [batch_size, params.num_classes])
+    # Check non-empty labels.
+    self.assertGreater(np.nonzero(example['labels'][0].numpy())[0].shape[0], 0)
     self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
     if include_video_id:
       self.assertEqual(example['video_ids'].shape.as_list(), [batch_size, 1])
@@ -115,9 +118,11 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
     batch_size = params.global_batch_size * self.num_segment
     self.assertEqual(
         example['video_matrix'].shape.as_list(),
-        [batch_size, params.segment_size, sum(params.feature_sizes)])
+        [batch_size, params.segment_size,
+         sum(params.feature_sizes)])
     self.assertEqual(example['labels'].shape.as_list(),
                      [batch_size, params.num_classes])
+    self.assertGreater(np.nonzero(example['labels'][0].numpy())[0].shape[0], 0)
     self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
     self.assertEqual(example['label_weights'].shape.as_list(),
                      [batch_size, params.num_classes])
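Note the batch-size arithmetic this segment test relies on: with segment labels enabled, every video expands into num_segment rows, so the parsed batch holds global_batch_size * num_segment examples. A toy check with hypothetical numbers:

global_batch_size = 2
num_segment = 5
batch_size = global_batch_size * num_segment
assert batch_size == 10  # each of the 2 videos contributes 5 segment rows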
@@ -147,6 +152,7 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
     params.feature_dtypes = ('float32', 'float32')
     params.feature_sizes = (256, 2048)
     params.feature_from_bytes = (False, False)
+    params.label_field = 'clip/label/index'
     params.include_video_id = include_video_id
     reader = self.create_input_reader(params)
@@ -171,20 +177,19 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
         'FEATURE/feature/floats'].feature[0].float_list.value
     expected_labels = examples[0].context.feature[
         params.label_field].int64_list.value
-    self.assertAllEqual(
-        expected_feature,
-        example['video_matrix'][0, 0, params.feature_sizes[0]:])
-    self.assertAllEqual(
-        expected_context,
-        example['video_matrix'][0, 0, :params.feature_sizes[0]])
+    self.assertAllEqual(expected_feature,
+                        example['video_matrix'][0, 0, params.feature_sizes[0]:])
+    self.assertAllEqual(expected_context,
+                        example['video_matrix'][0, 0, :params.feature_sizes[0]])
     self.assertAllEqual(
         np.nonzero(example['labels'][0, :].numpy())[0], expected_labels)
+    self.assertGreater(np.nonzero(example['labels'][0].numpy())[0].shape[0], 0)
     # Check tensor shape.
     batch_size = params.global_batch_size
-    self.assertEqual(
-        example['video_matrix'].shape.as_list(),
-        [batch_size, params.max_frames, sum(params.feature_sizes)])
+    self.assertEqual(example['video_matrix'].shape.as_list(),
+                     [batch_size, params.max_frames,
+                      sum(params.feature_sizes)])
     self.assertEqual(example['labels'].shape.as_list(),
                      [batch_size, params.num_classes])
     self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
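The assertGreater lines added across these tests all use the same idiom: a multi-hot label row is non-empty exactly when np.nonzero finds at least one set index. A standalone sketch with made-up data:

import numpy as np

labels = np.zeros((1, 10), dtype=np.float32)
labels[0, [3, 7]] = 1.0               # classes 3 and 7 are "on"
nonzero = np.nonzero(labels[0])[0]    # -> array([3, 7])
assert nonzero.shape[0] > 0           # the row carries at least one label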