Commit f58534eb authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Makes video id optional for input reader.

Adds unit tests for the dataloader.
Adds segment level training pipeline test.

PiperOrigin-RevId: 445277327
parent 2f9266ac
...@@ -35,13 +35,37 @@ YT8M_VAL_PATH = 'gs://youtube8m-ml/3/frame/validate/validate*.tfrecord' ...@@ -35,13 +35,37 @@ YT8M_VAL_PATH = 'gs://youtube8m-ml/3/frame/validate/validate*.tfrecord'
@dataclasses.dataclass @dataclasses.dataclass
class DataConfig(cfg.DataConfig): class DataConfig(cfg.DataConfig):
"""The base configuration for building datasets.""" """The base configuration for building datasets.
Attributes:
name: Dataset name.
split: dataset split, 'train' or 'valid'.
feature_sizes: shape(length) of each feature specified in the feature_names.
feature_names: names of the features in the tf.SequenceExample.
segment_size: Number of frames in each segment.
segment_labels: Use segment level label. Default: False, video level label.
include_video_id: `True` means include video id (string) in the input to
the model.
temporal_stride: Not used. Need to deprecated.
max_frames: Maxim Number of frames in a input example. It is used to crop
the input in the temporal dimension.
num_frames: Number of frames in a single input example.
num_classes: Number of classes to classify. Assuming it is a classification
task.
num_devices: Not used. To be deprecated.
input_path: The path to the input.
is_training: Whether this data is used for training or not.
num_examples: Number of examples in the dataset. It is used to compute the
steps for train or eval. set the value to `-1` to make the experiment run
until the end of dataset.
"""
name: Optional[str] = 'yt8m' name: Optional[str] = 'yt8m'
split: Optional[str] = None split: Optional[str] = None
feature_sizes: Tuple[int, ...] = (1024, 128) feature_sizes: Tuple[int, ...] = (1024, 128)
feature_names: Tuple[str, ...] = ('rgb', 'audio') feature_names: Tuple[str, ...] = ('rgb', 'audio')
segment_size: int = 1 segment_size: int = 1
segment_labels: bool = False segment_labels: bool = False
include_video_id: bool = False
temporal_stride: int = 1 temporal_stride: int = 1
max_frames: int = 300 max_frames: int = 300
num_frames: int = 300 # set smaller to allow random sample (Parser) num_frames: int = 300 # set smaller to allow random sample (Parser)
...@@ -49,7 +73,6 @@ class DataConfig(cfg.DataConfig): ...@@ -49,7 +73,6 @@ class DataConfig(cfg.DataConfig):
num_devices: int = 1 num_devices: int = 1
input_path: str = '' input_path: str = ''
is_training: bool = True is_training: bool = True
random_seed: int = 123
num_examples: int = -1 num_examples: int = -1
......
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
"""Contains a collection of util functions for training and evaluating.""" """Contains a collection of util functions for training and evaluating."""
from absl import logging from absl import logging
import numpy import numpy as np
import tensorflow as tf import tensorflow as tf
from official.vision.dataloaders import tfexample_utils
def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2): def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
...@@ -113,7 +114,7 @@ def AddEpochSummary(summary_writer, ...@@ -113,7 +114,7 @@ def AddEpochSummary(summary_writer,
avg_loss = epoch_info_dict["avg_loss"] avg_loss = epoch_info_dict["avg_loss"]
aps = epoch_info_dict["aps"] aps = epoch_info_dict["aps"]
gap = epoch_info_dict["gap"] gap = epoch_info_dict["gap"]
mean_ap = numpy.mean(aps) mean_ap = np.mean(aps)
summary_writer.add_summary( summary_writer.add_summary(
MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one), MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one),
...@@ -213,3 +214,26 @@ def CombineGradients(tower_grads): ...@@ -213,3 +214,26 @@ def CombineGradients(tower_grads):
)) ))
return final_grads return final_grads
def MakeYt8mExample(num_segment: int = 5) -> tf.train.SequenceExample:
"""Generate fake data for unit tests."""
rgb = np.random.randint(low=256, size=1024, dtype=np.uint8)
audio = np.random.randint(low=256, size=128, dtype=np.uint8)
seq_example = tf.train.SequenceExample()
seq_example.context.feature["id"].bytes_list.value[:] = [b"id001"]
seq_example.context.feature["labels"].int64_list.value[:] = [1, 2, 3, 4]
seq_example.context.feature["segment_labels"].int64_list.value[:] = (
[4] * num_segment)
seq_example.context.feature["segment_start_times"].int64_list.value[:] = [
i * 5 for i in range(num_segment)
]
seq_example.context.feature["segment_scores"].float_list.value[:] = (
[0.] * num_segment)
tfexample_utils.put_bytes_list_to_feature(
seq_example, rgb.tobytes(), key="rgb", repeat_num=120)
tfexample_utils.put_bytes_list_to_feature(
seq_example, audio.tobytes(), key="audio", repeat_num=120)
return seq_example
...@@ -81,13 +81,14 @@ def _process_segment_and_label(video_matrix, num_frames, contexts, ...@@ -81,13 +81,14 @@ def _process_segment_and_label(video_matrix, num_frames, contexts,
num_frames: Number of frames per subclip. num_frames: Number of frames per subclip.
contexts: context information extracted from decoder contexts: context information extracted from decoder
segment_labels: if we read segment labels instead. segment_labels: if we read segment labels instead.
segment_size: the segment_size used for reading segments. segment_size: the segment_size used for reading segments. Segment length.
num_classes: a positive integer for the number of classes. num_classes: a positive integer for the number of classes.
Returns: Returns:
output: dictionary containing batch information output: dictionary containing batch information
""" """
# Partition frame-level feature matrix to segment-level feature matrix. # Partition frame-level feature matrix to segment-level feature matrix.
batch_video_ids = None
if segment_labels: if segment_labels:
start_times = contexts["segment_start_times"].values start_times = contexts["segment_start_times"].values
# Here we assume all the segments that started at the same start time has # Here we assume all the segments that started at the same start time has
...@@ -101,8 +102,9 @@ def _process_segment_and_label(video_matrix, num_frames, contexts, ...@@ -101,8 +102,9 @@ def _process_segment_and_label(video_matrix, num_frames, contexts,
batch_video_matrix = tf.gather_nd(video_matrix, batch_video_matrix = tf.gather_nd(video_matrix,
tf.expand_dims(range_mtx, axis=-1)) tf.expand_dims(range_mtx, axis=-1))
num_segment = tf.shape(batch_video_matrix)[0] num_segment = tf.shape(batch_video_matrix)[0]
batch_video_ids = tf.reshape( if "id" in contexts:
tf.tile([contexts["id"]], [num_segment]), (num_segment,)) batch_video_ids = tf.reshape(
tf.tile([contexts["id"]], [num_segment]), (num_segment,))
batch_frames = tf.reshape( batch_frames = tf.reshape(
tf.tile([segment_size], [num_segment]), (num_segment,)) tf.tile([segment_size], [num_segment]), (num_segment,))
batch_frames = tf.cast(tf.expand_dims(batch_frames, 1), tf.float32) batch_frames = tf.cast(tf.expand_dims(batch_frames, 1), tf.float32)
...@@ -134,18 +136,20 @@ def _process_segment_and_label(video_matrix, num_frames, contexts, ...@@ -134,18 +136,20 @@ def _process_segment_and_label(video_matrix, num_frames, contexts,
sparse_labels, default_value=False, validate_indices=False) sparse_labels, default_value=False, validate_indices=False)
# convert to batch format. # convert to batch format.
batch_video_ids = tf.expand_dims(contexts["id"], 0) if "id" in contexts:
batch_video_ids = tf.expand_dims(contexts["id"], 0)
batch_video_matrix = tf.expand_dims(video_matrix, 0) batch_video_matrix = tf.expand_dims(video_matrix, 0)
batch_labels = tf.expand_dims(labels, 0) batch_labels = tf.expand_dims(labels, 0)
batch_frames = tf.expand_dims(num_frames, 0) batch_frames = tf.expand_dims(num_frames, 0)
batch_label_weights = None batch_label_weights = None
output_dict = { output_dict = {
"video_ids": batch_video_ids,
"video_matrix": batch_video_matrix, "video_matrix": batch_video_matrix,
"labels": batch_labels, "labels": batch_labels,
"num_frames": batch_frames, "num_frames": batch_frames,
} }
if batch_video_ids is not None:
output_dict["video_ids"] = batch_video_ids
if batch_label_weights is not None: if batch_label_weights is not None:
output_dict["label_weights"] = batch_label_weights output_dict["label_weights"] = batch_label_weights
...@@ -280,12 +284,10 @@ class Parser(parser.Parser): ...@@ -280,12 +284,10 @@ class Parser(parser.Parser):
self._num_classes = input_params.num_classes self._num_classes = input_params.num_classes
self._segment_size = input_params.segment_size self._segment_size = input_params.segment_size
self._segment_labels = input_params.segment_labels self._segment_labels = input_params.segment_labels
self._include_video_id = input_params.include_video_id
self._feature_names = input_params.feature_names self._feature_names = input_params.feature_names
self._feature_sizes = input_params.feature_sizes self._feature_sizes = input_params.feature_sizes
self.stride = input_params.temporal_stride
self._max_frames = input_params.max_frames self._max_frames = input_params.max_frames
self._num_frames = input_params.num_frames
self._seed = input_params.random_seed
self._max_quantized_value = max_quantized_value self._max_quantized_value = max_quantized_value
self._min_quantized_value = min_quantized_value self._min_quantized_value = min_quantized_value
...@@ -295,6 +297,8 @@ class Parser(parser.Parser): ...@@ -295,6 +297,8 @@ class Parser(parser.Parser):
self.video_matrix, self.num_frames = _concat_features( self.video_matrix, self.num_frames = _concat_features(
decoded_tensors["features"], self._feature_names, self._feature_sizes, decoded_tensors["features"], self._feature_names, self._feature_sizes,
self._max_frames, self._max_quantized_value, self._min_quantized_value) self._max_frames, self._max_quantized_value, self._min_quantized_value)
if not self._include_video_id:
del decoded_tensors["contexts"]["id"]
output_dict = _process_segment_and_label(self.video_matrix, self.num_frames, output_dict = _process_segment_and_label(self.video_matrix, self.num_frames,
decoded_tensors["contexts"], decoded_tensors["contexts"],
self._segment_labels, self._segment_labels,
...@@ -308,6 +312,8 @@ class Parser(parser.Parser): ...@@ -308,6 +312,8 @@ class Parser(parser.Parser):
self.video_matrix, self.num_frames = _concat_features( self.video_matrix, self.num_frames = _concat_features(
decoded_tensors["features"], self._feature_names, self._feature_sizes, decoded_tensors["features"], self._feature_names, self._feature_sizes,
self._max_frames, self._max_quantized_value, self._min_quantized_value) self._max_frames, self._max_quantized_value, self._min_quantized_value)
if not self._include_video_id:
del decoded_tensors["contexts"]["id"]
output_dict = _process_segment_and_label(self.video_matrix, self.num_frames, output_dict = _process_segment_and_label(self.video_matrix, self.num_frames,
decoded_tensors["contexts"], decoded_tensors["contexts"],
self._segment_labels, self._segment_labels,
...@@ -347,7 +353,7 @@ class PostBatchProcessor(): ...@@ -347,7 +353,7 @@ class PostBatchProcessor():
def post_fn(self, batched_tensors): def post_fn(self, batched_tensors):
"""Processes batched Tensors.""" """Processes batched Tensors."""
video_ids = batched_tensors["video_ids"] video_ids = batched_tensors.get("video_ids", None)
video_matrix = batched_tensors["video_matrix"] video_matrix = batched_tensors["video_matrix"]
labels = batched_tensors["labels"] labels = batched_tensors["labels"]
num_frames = batched_tensors["num_frames"] num_frames = batched_tensors["num_frames"]
...@@ -356,7 +362,8 @@ class PostBatchProcessor(): ...@@ -356,7 +362,8 @@ class PostBatchProcessor():
if self.segment_labels: if self.segment_labels:
# [batch x num_segment x segment_size x num_features] # [batch x num_segment x segment_size x num_features]
# -> [batch * num_segment x segment_size x num_features] # -> [batch * num_segment x segment_size x num_features]
video_ids = tf.reshape(video_ids, [-1]) if video_ids is not None:
video_ids = tf.reshape(video_ids, [-1])
video_matrix = tf.reshape(video_matrix, [-1, self.segment_size, 1152]) video_matrix = tf.reshape(video_matrix, [-1, self.segment_size, 1152])
labels = tf.reshape(labels, [-1, self.num_classes]) labels = tf.reshape(labels, [-1, self.num_classes])
num_frames = tf.reshape(num_frames, [-1, 1]) num_frames = tf.reshape(num_frames, [-1, 1])
...@@ -369,11 +376,12 @@ class PostBatchProcessor(): ...@@ -369,11 +376,12 @@ class PostBatchProcessor():
labels = tf.squeeze(labels) labels = tf.squeeze(labels)
batched_tensors = { batched_tensors = {
"video_ids": video_ids,
"video_matrix": video_matrix, "video_matrix": video_matrix,
"labels": labels, "labels": labels,
"num_frames": num_frames, "num_frames": num_frames,
} }
if video_ids is not None:
batched_tensors["video_ids"] = video_ids
if label_weights is not None: if label_weights is not None:
batched_tensors["label_weights"] = label_weights batched_tensors["label_weights"] = label_weights
...@@ -388,6 +396,7 @@ class TransformBatcher(): ...@@ -388,6 +396,7 @@ class TransformBatcher():
self._segment_labels = input_params.segment_labels self._segment_labels = input_params.segment_labels
self._global_batch_size = input_params.global_batch_size self._global_batch_size = input_params.global_batch_size
self._is_training = input_params.is_training self._is_training = input_params.is_training
self._include_video_id = input_params.include_video_id
def batch_fn(self, dataset, input_context): def batch_fn(self, dataset, input_context):
"""Add padding when segment_labels is true.""" """Add padding when segment_labels is true."""
...@@ -398,19 +407,20 @@ class TransformBatcher(): ...@@ -398,19 +407,20 @@ class TransformBatcher():
else: else:
# add padding # add padding
pad_shapes = { pad_shapes = {
"video_ids": [None],
"video_matrix": [None, None, None], "video_matrix": [None, None, None],
"labels": [None, None], "labels": [None, None],
"num_frames": [None, None], "num_frames": [None, None],
"label_weights": [None, None] "label_weights": [None, None]
} }
pad_values = { pad_values = {
"video_ids": None,
"video_matrix": 0.0, "video_matrix": 0.0,
"labels": -1.0, "labels": -1.0,
"num_frames": 0.0, "num_frames": 0.0,
"label_weights": 0.0 "label_weights": 0.0
} }
if self._include_video_id:
pad_shapes["video_ids"] = [None]
pad_values["video_ids"] = None
dataset = dataset.padded_batch( dataset = dataset.padded_batch(
per_replica_batch_size, per_replica_batch_size,
padded_shapes=pad_shapes, padded_shapes=pad_shapes,
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl import logging
from absl.testing import parameterized
import tensorflow as tf
from official.core import input_reader
from official.projects.yt8m.configs import yt8m as yt8m_configs
from official.projects.yt8m.dataloaders import utils
from official.projects.yt8m.dataloaders import yt8m_input
from official.vision.dataloaders import tfexample_utils
class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
def setUp(self):
super().setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
data_dir = os.path.join(self.get_temp_dir(), 'data')
tf.io.gfile.makedirs(data_dir)
self.data_path = os.path.join(data_dir, 'data.tfrecord')
self.num_segment = 6
examples = [utils.MakeYt8mExample(self.num_segment) for _ in range(8)]
tfexample_utils.dump_to_tfrecord(self.data_path, tf_examples=examples)
def create_input_reader(self, params):
decoder = yt8m_input.Decoder(input_params=params)
decoder_fn = decoder.decode
parser = yt8m_input.Parser(input_params=params)
parser_fn = parser.parse_fn(params.is_training)
postprocess = yt8m_input.PostBatchProcessor(input_params=params)
postprocess_fn = postprocess.post_fn
transform_batch = yt8m_input.TransformBatcher(input_params=params)
batch_fn = transform_batch.batch_fn
return input_reader.InputReader(
params,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn=decoder_fn,
parser_fn=parser_fn,
postprocess_fn=postprocess_fn,
transform_and_batch_fn=batch_fn)
@parameterized.parameters((True,), (False,))
def test_read_video_level_input(self, include_video_id):
params = yt8m_configs.yt8m(is_training=False)
params.global_batch_size = 4
params.segment_labels = False
params.input_path = self.data_path
params.include_video_id = include_video_id
reader = self.create_input_reader(params)
dataset = reader.read()
iterator = iter(dataset)
example = next(iterator)
for k, v in example.items():
logging.info('DEBUG read example %r %r %r', k, v.shape, type(v))
if include_video_id:
self.assertCountEqual(
['video_matrix', 'labels', 'num_frames', 'video_ids'], example.keys())
else:
self.assertCountEqual(['video_matrix', 'labels', 'num_frames'],
example.keys())
batch_size = params.global_batch_size
self.assertEqual(
example['video_matrix'].shape.as_list(),
[batch_size, params.max_frames, sum(params.feature_sizes)])
self.assertEqual(example['labels'].shape.as_list(),
[batch_size, params.num_classes])
self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
if include_video_id:
self.assertEqual(example['video_ids'].shape.as_list(), [batch_size, 1])
@parameterized.parameters((True,), (False,))
def test_read_segement_level_input(self, include_video_id):
params = yt8m_configs.yt8m(is_training=False)
params.global_batch_size = 4
params.segment_labels = True
params.input_path = self.data_path
params.include_video_id = include_video_id
reader = self.create_input_reader(params)
dataset = reader.read()
iterator = iter(dataset)
example = next(iterator)
for k, v in example.items():
logging.info('DEBUG read example %r %r %r', k, v.shape, type(v))
if include_video_id:
self.assertCountEqual([
'video_matrix', 'labels', 'num_frames', 'label_weights', 'video_ids'
], example.keys())
else:
self.assertCountEqual(
['video_matrix', 'labels', 'num_frames', 'label_weights'],
example.keys())
batch_size = params.global_batch_size * self.num_segment
self.assertEqual(
example['video_matrix'].shape.as_list(),
[batch_size, params.segment_size, sum(params.feature_sizes)])
self.assertEqual(example['labels'].shape.as_list(),
[batch_size, params.num_classes])
self.assertEqual(example['num_frames'].shape.as_list(), [batch_size, 1])
self.assertEqual(example['label_weights'].shape.as_list(),
[batch_size, params.num_classes])
if include_video_id:
self.assertEqual(example['video_ids'].shape.as_list(), [batch_size])
if __name__ == '__main__':
tf.test.main()
...@@ -27,7 +27,6 @@ task: ...@@ -27,7 +27,6 @@ task:
num_devices: 1 num_devices: 1
input_path: 'gs://youtube8m-ml/2/frame/train/train*.tfrecord' input_path: 'gs://youtube8m-ml/2/frame/train/train*.tfrecord'
is_training: true is_training: true
random_seed: 123
validation_data: validation_data:
name: 'yt8m' name: 'yt8m'
split: 'train' split: 'train'
...@@ -46,7 +45,6 @@ task: ...@@ -46,7 +45,6 @@ task:
num_devices: 1 num_devices: 1
input_path: 'gs://youtube8m-ml/3/frame/validate/validate*.tfrecord' input_path: 'gs://youtube8m-ml/3/frame/validate/validate*.tfrecord'
is_training: false is_training: false
random_seed: 123
losses: losses:
name: 'binary_crossentropy' name: 'binary_crossentropy'
from_logits: false from_logits: false
......
...@@ -38,9 +38,10 @@ class DbofModel(tf.keras.Model): ...@@ -38,9 +38,10 @@ class DbofModel(tf.keras.Model):
def __init__( def __init__(
self, self,
params: yt8m_cfg.DbofModel, params: yt8m_cfg.DbofModel,
num_frames=30, num_frames: int = 30,
num_classes=3862, num_classes: int = 3862,
input_specs=layers.InputSpec(shape=[None, None, 1152]), input_specs: layers.InputSpec = layers.InputSpec(
shape=[None, None, 1152]),
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
activation: str = "relu", activation: str = "relu",
use_sync_bn: bool = False, use_sync_bn: bool = False,
......
...@@ -12,49 +12,37 @@ ...@@ -12,49 +12,37 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
import os import os
from absl import flags from absl import flags
from absl.testing import flagsaver from absl.testing import flagsaver
import numpy as np from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.projects.yt8m import train as train_lib from official.projects.yt8m import train as train_lib
from official.projects.yt8m.dataloaders import utils
from official.vision.dataloaders import tfexample_utils from official.vision.dataloaders import tfexample_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def make_yt8m_example(): class TrainTest(parameterized.TestCase, tf.test.TestCase):
rgb = np.random.randint(low=256, size=1024, dtype=np.uint8)
audio = np.random.randint(low=256, size=128, dtype=np.uint8)
seq_example = tf.train.SequenceExample()
seq_example.context.feature['id'].bytes_list.value[:] = [b'id001']
seq_example.context.feature['labels'].int64_list.value[:] = [1, 2, 3, 4]
tfexample_utils.put_bytes_list_to_feature(
seq_example, rgb.tobytes(), key='rgb', repeat_num=120)
tfexample_utils.put_bytes_list_to_feature(
seq_example, audio.tobytes(), key='audio', repeat_num=120)
return seq_example
class TrainTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(TrainTest, self).setUp() super().setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir') self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
tf.io.gfile.makedirs(self._model_dir) tf.io.gfile.makedirs(self._model_dir)
data_dir = os.path.join(self.get_temp_dir(), 'data') data_dir = os.path.join(self.get_temp_dir(), 'data')
tf.io.gfile.makedirs(data_dir) tf.io.gfile.makedirs(data_dir)
self._data_path = os.path.join(data_dir, 'data.tfrecord') self._data_path = os.path.join(data_dir, 'data.tfrecord')
examples = [make_yt8m_example() for _ in range(8)] examples = [utils.MakeYt8mExample() for _ in range(8)]
tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples) tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
def test_run(self): @parameterized.named_parameters(
dict(testcase_name='segment', use_segment_level_labels=True),
dict(testcase_name='video', use_segment_level_labels=False))
def test_train_and_eval(self, use_segment_level_labels):
saved_flag_values = flagsaver.save_flag_values() saved_flag_values = flagsaver.save_flag_values()
train_lib.tfm_flags.define_flags() train_lib.tfm_flags.define_flags()
FLAGS.mode = 'train' FLAGS.mode = 'train'
...@@ -87,6 +75,7 @@ class TrainTest(tf.test.TestCase): ...@@ -87,6 +75,7 @@ class TrainTest(tf.test.TestCase):
}, },
'validation_data': { 'validation_data': {
'input_path': self._data_path, 'input_path': self._data_path,
'segment_labels': use_segment_level_labels,
'global_batch_size': 4, 'global_batch_size': 4,
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment