Commit d4f9d872 authored by Chaochao Yan's avatar Chaochao Yan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 476987858
parent eb5d69ae
...@@ -178,7 +178,7 @@ def make_yt8m_example(num_segment: int = 5) -> tf.train.SequenceExample: ...@@ -178,7 +178,7 @@ def make_yt8m_example(num_segment: int = 5) -> tf.train.SequenceExample:
i * 5 for i in range(num_segment) i * 5 for i in range(num_segment)
] ]
seq_example.context.feature["segment_scores"].float_list.value[:] = ( seq_example.context.feature["segment_scores"].float_list.value[:] = (
[0.] * num_segment) [0.5] * num_segment)
tfexample_utils.put_bytes_list_to_feature( tfexample_utils.put_bytes_list_to_feature(
seq_example, rgb.tobytes(), key="rgb", repeat_num=120) seq_example, rgb.tobytes(), key="rgb", repeat_num=120)
tfexample_utils.put_bytes_list_to_feature( tfexample_utils.put_bytes_list_to_feature(
......
...@@ -295,12 +295,10 @@ class Decoder(decoder.Decoder): ...@@ -295,12 +295,10 @@ class Decoder(decoder.Decoder):
def decode(self, def decode(self,
serialized_example: tf.train.SequenceExample) -> Dict[str, Any]: serialized_example: tf.train.SequenceExample) -> Dict[str, Any]:
"""Parses a single tf.train.SequenceExample into video and label tensors.""" """Parses a single tf.train.SequenceExample into video and label tensors."""
contexts, features = tf.io.parse_single_sequence_example( contexts, features = tf.io.parse_single_sequence_example(
serialized_example, serialized_example,
context_features=self._context_features, context_features=self._context_features,
sequence_features=self._sequence_features) sequence_features=self._sequence_features)
decoded_tensor = {**contexts, **features} decoded_tensor = {**contexts, **features}
for i, name in enumerate(self._feature_names): for i, name in enumerate(self._feature_names):
# Convert the VarLen feature to dense tensor. # Convert the VarLen feature to dense tensor.
...@@ -330,6 +328,7 @@ class Parser(parser.Parser): ...@@ -330,6 +328,7 @@ class Parser(parser.Parser):
min_quantized_value=-2, min_quantized_value=-2,
): ):
self._num_classes = input_params.num_classes self._num_classes = input_params.num_classes
self._label_field = input_params.label_field
self._segment_size = input_params.segment_size self._segment_size = input_params.segment_size
self._segment_labels = input_params.segment_labels self._segment_labels = input_params.segment_labels
self._include_video_id = input_params.include_video_id self._include_video_id = input_params.include_video_id
...@@ -377,6 +376,8 @@ class Parser(parser.Parser): ...@@ -377,6 +376,8 @@ class Parser(parser.Parser):
Returns: Returns:
output: dictionary containing batch information output: dictionary containing batch information
""" """
if self._label_field and not self._segment_labels:
contexts["labels"] = contexts[self._label_field]
output_dict = _process_segment_and_label(video_matrix, num_frames, contexts, output_dict = _process_segment_and_label(video_matrix, num_frames, contexts,
self._segment_labels, self._segment_labels,
self._segment_size, self._segment_size,
......
...@@ -80,11 +80,14 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase): ...@@ -80,11 +80,14 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
self.assertCountEqual(['video_matrix', 'labels', 'num_frames'], self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
example.keys()) example.keys())
batch_size = params.global_batch_size batch_size = params.global_batch_size
self.assertEqual( self.assertEqual(example['video_matrix'].shape.as_list(),
example['video_matrix'].shape.as_list(), [batch_size, params.max_frames,
[batch_size, params.max_frames, sum(params.feature_sizes)]) sum(params.feature_sizes)])
self.assertEqual(example['labels'].shape.as_list(), self.assertEqual(example['labels'].shape.as_list(),
[batch_size, params.num_classes]) [batch_size, params.num_classes])
# Check non empty labels.
self.assertGreater(np.nonzero(example['labels'][0].numpy())[0].shape[0], 0)
self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1]) self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
if include_video_id: if include_video_id:
self.assertEqual(example['video_ids'].shape.as_list(), [batch_size, 1]) self.assertEqual(example['video_ids'].shape.as_list(), [batch_size, 1])
...@@ -115,9 +118,11 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase): ...@@ -115,9 +118,11 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
batch_size = params.global_batch_size * self.num_segment batch_size = params.global_batch_size * self.num_segment
self.assertEqual( self.assertEqual(
example['video_matrix'].shape.as_list(), example['video_matrix'].shape.as_list(),
[batch_size, params.segment_size, sum(params.feature_sizes)]) [batch_size, params.segment_size,
sum(params.feature_sizes)])
self.assertEqual(example['labels'].shape.as_list(), self.assertEqual(example['labels'].shape.as_list(),
[batch_size, params.num_classes]) [batch_size, params.num_classes])
self.assertGreater(np.nonzero(example['labels'][0].numpy())[0].shape[0], 0)
self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1]) self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
self.assertEqual(example['label_weights'].shape.as_list(), self.assertEqual(example['label_weights'].shape.as_list(),
[batch_size, params.num_classes]) [batch_size, params.num_classes])
...@@ -147,6 +152,7 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase): ...@@ -147,6 +152,7 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
params.feature_dtypes = ('float32', 'float32') params.feature_dtypes = ('float32', 'float32')
params.feature_sizes = (256, 2048) params.feature_sizes = (256, 2048)
params.feature_from_bytes = (False, False) params.feature_from_bytes = (False, False)
params.label_field = 'clip/label/index'
params.include_video_id = include_video_id params.include_video_id = include_video_id
reader = self.create_input_reader(params) reader = self.create_input_reader(params)
...@@ -171,20 +177,19 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase): ...@@ -171,20 +177,19 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
'FEATURE/feature/floats'].feature[0].float_list.value 'FEATURE/feature/floats'].feature[0].float_list.value
expected_labels = examples[0].context.feature[ expected_labels = examples[0].context.feature[
params.label_field].int64_list.value params.label_field].int64_list.value
self.assertAllEqual( self.assertAllEqual(expected_feature,
expected_feature,
example['video_matrix'][0, 0, params.feature_sizes[0]:]) example['video_matrix'][0, 0, params.feature_sizes[0]:])
self.assertAllEqual( self.assertAllEqual(expected_context,
expected_context,
example['video_matrix'][0, 0, :params.feature_sizes[0]]) example['video_matrix'][0, 0, :params.feature_sizes[0]])
self.assertAllEqual( self.assertAllEqual(
np.nonzero(example['labels'][0, :].numpy())[0], expected_labels) np.nonzero(example['labels'][0, :].numpy())[0], expected_labels)
self.assertGreater(np.nonzero(example['labels'][0].numpy())[0].shape[0], 0)
# Check tensor shape. # Check tensor shape.
batch_size = params.global_batch_size batch_size = params.global_batch_size
self.assertEqual( self.assertEqual(example['video_matrix'].shape.as_list(),
example['video_matrix'].shape.as_list(), [batch_size, params.max_frames,
[batch_size, params.max_frames, sum(params.feature_sizes)]) sum(params.feature_sizes)])
self.assertEqual(example['labels'].shape.as_list(), self.assertEqual(example['labels'].shape.as_list(),
[batch_size, params.num_classes]) [batch_size, params.num_classes])
self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1]) self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment