Commit 5d38628c authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 457890217
parent cfcbb6cb
...@@ -88,6 +88,7 @@ class DataConfig(cfg.DataConfig): ...@@ -88,6 +88,7 @@ class DataConfig(cfg.DataConfig):
def yt8m(is_training): def yt8m(is_training):
"""YT8M dataset configs.""" """YT8M dataset configs."""
# pylint: disable=unexpected-keyword-arg
return DataConfig( return DataConfig(
num_frames=30, num_frames=30,
temporal_stride=1, temporal_stride=1,
...@@ -95,8 +96,10 @@ def yt8m(is_training): ...@@ -95,8 +96,10 @@ def yt8m(is_training):
segment_size=5, segment_size=5,
is_training=is_training, is_training=is_training,
split='train' if is_training else 'valid', split='train' if is_training else 'valid',
drop_remainder=is_training, # pytype: disable=wrong-keyword-args
num_examples=YT8M_TRAIN_EXAMPLES if is_training else YT8M_VAL_EXAMPLES, num_examples=YT8M_TRAIN_EXAMPLES if is_training else YT8M_VAL_EXAMPLES,
input_path=YT8M_TRAIN_PATH if is_training else YT8M_VAL_PATH) input_path=YT8M_TRAIN_PATH if is_training else YT8M_VAL_PATH)
# pylint: enable=unexpected-keyword-arg
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
back into a range between min_quantized_value and max_quantized_value. back into a range between min_quantized_value and max_quantized_value.
link for details: https://research.google.com/youtube8m/download.html link for details: https://research.google.com/youtube8m/download.html
""" """
from typing import Dict from typing import Dict
import tensorflow as tf import tensorflow as tf
...@@ -424,6 +423,7 @@ class PostBatchProcessor(): ...@@ -424,6 +423,7 @@ class PostBatchProcessor():
[-1, self.num_classes]) [-1, self.num_classes])
else: else:
# NOTE(b/237445211): Must provide axis argument to tf.squeeze.
video_matrix = tf.squeeze(video_matrix, axis=1) video_matrix = tf.squeeze(video_matrix, axis=1)
labels = tf.squeeze(labels, axis=1) labels = tf.squeeze(labels, axis=1)
...@@ -449,13 +449,15 @@ class TransformBatcher(): ...@@ -449,13 +449,15 @@ class TransformBatcher():
self._global_batch_size = input_params.global_batch_size self._global_batch_size = input_params.global_batch_size
self._is_training = input_params.is_training self._is_training = input_params.is_training
self._include_video_id = input_params.include_video_id self._include_video_id = input_params.include_video_id
self._drop_remainder = input_params.drop_remainder
def batch_fn(self, dataset, input_context): def batch_fn(self, dataset, input_context):
"""Add padding when segment_labels is true.""" """Add padding when segment_labels is true."""
per_replica_batch_size = input_context.get_per_replica_batch_size( per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size self._global_batch_size) if input_context else self._global_batch_size
if not self._segment_labels: if not self._segment_labels:
dataset = dataset.batch(per_replica_batch_size, drop_remainder=True) dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder)
else: else:
# add padding # add padding
pad_shapes = { pad_shapes = {
...@@ -476,6 +478,6 @@ class TransformBatcher(): ...@@ -476,6 +478,6 @@ class TransformBatcher():
dataset = dataset.padded_batch( dataset = dataset.padded_batch(
per_replica_batch_size, per_replica_batch_size,
padded_shapes=pad_shapes, padded_shapes=pad_shapes,
drop_remainder=True, drop_remainder=self._drop_remainder,
padding_values=pad_values) padding_values=pad_values)
return dataset return dataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment