Commit 2d353306 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Moves the YT8M folder.

PiperOrigin-RevId: 402314322
parent 6a47721e
...@@ -14,4 +14,4 @@ ...@@ -14,4 +14,4 @@
"""Configs package definition.""" """Configs package definition."""
from official.vision.beta.projects.yt8m.configs import yt8m from official.projects.yt8m.configs import yt8m
...@@ -26,10 +26,10 @@ ...@@ -26,10 +26,10 @@
from typing import Dict from typing import Dict
import tensorflow as tf import tensorflow as tf
from official.projects.yt8m.dataloaders import utils
from official.vision.beta.configs import video_classification as exp_cfg from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import decoder from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser from official.vision.beta.dataloaders import parser
from official.vision.beta.projects.yt8m.dataloaders import utils
def resize_axis(tensor, axis, new_size, fill_value=0): def resize_axis(tensor, axis, new_size, fill_value=0):
......
...@@ -15,10 +15,8 @@ ...@@ -15,10 +15,8 @@
"""Provides functions to help with evaluating models.""" """Provides functions to help with evaluating models."""
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.vision.beta.projects.yt8m.eval_utils import \ from official.projects.yt8m.eval_utils import average_precision_calculator as ap_calculator
average_precision_calculator as ap_calculator from official.projects.yt8m.eval_utils import mean_average_precision_calculator as map_calculator
from official.vision.beta.projects.yt8m.eval_utils import \
mean_average_precision_calculator as map_calculator
def flatten(l): def flatten(l):
......
...@@ -37,8 +37,7 @@ aps = calculator.peek_map_at_n() ...@@ -37,8 +37,7 @@ aps = calculator.peek_map_at_n()
``` ```
""" """
from official.vision.beta.projects.yt8m.eval_utils import \ from official.projects.yt8m.eval_utils import average_precision_calculator
average_precision_calculator
class MeanAveragePrecisionCalculator(object): class MeanAveragePrecisionCalculator(object):
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
"""Contains model definitions.""" """Contains model definitions."""
from typing import Optional, Dict, Any from typing import Any, Dict, Optional
import tensorflow as tf import tensorflow as tf
from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as utils from official.projects.yt8m.modeling import yt8m_model_utils as utils
layers = tf.keras.layers layers = tf.keras.layers
......
...@@ -17,9 +17,9 @@ from typing import Optional ...@@ -17,9 +17,9 @@ from typing import Optional
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.projects.yt8m.configs import yt8m as yt8m_cfg from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.vision.beta.projects.yt8m.modeling import yt8m_agg_models from official.projects.yt8m.modeling import yt8m_agg_models
from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as utils from official.projects.yt8m.modeling import yt8m_model_utils as utils
layers = tf.keras.layers layers = tf.keras.layers
......
...@@ -18,8 +18,8 @@ from absl.testing import parameterized ...@@ -18,8 +18,8 @@ from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.vision.beta.projects.yt8m.configs import yt8m as yt8m_cfg from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.vision.beta.projects.yt8m.modeling import yt8m_model from official.projects.yt8m.modeling import yt8m_model
class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase): class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains a collection of util functions for model construction.""" """Contains a collection of util functions for model construction."""
from typing import Dict, Optional, Union, Any from typing import Any, Dict, Optional, Union
import tensorflow as tf import tensorflow as tf
......
...@@ -13,4 +13,4 @@ ...@@ -13,4 +13,4 @@
# limitations under the License. # limitations under the License.
"""Tasks package definition.""" """Tasks package definition."""
from official.vision.beta.projects.yt8m.tasks import yt8m_task from official.projects.yt8m.tasks import yt8m_task
...@@ -20,11 +20,11 @@ from official.core import base_task ...@@ -20,11 +20,11 @@ from official.core import base_task
from official.core import input_reader from official.core import input_reader
from official.core import task_factory from official.core import task_factory
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.projects.yt8m.configs import yt8m as yt8m_cfg from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.vision.beta.projects.yt8m.dataloaders import yt8m_input from official.projects.yt8m.dataloaders import yt8m_input
from official.vision.beta.projects.yt8m.eval_utils import eval_util from official.projects.yt8m.eval_utils import eval_util
from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as utils from official.projects.yt8m.modeling import yt8m_model_utils as utils
from official.vision.beta.projects.yt8m.modeling.yt8m_model import DbofModel from official.projects.yt8m.modeling.yt8m_model import DbofModel
@task_factory.register_task_cls(yt8m_cfg.YT8MTask) @task_factory.register_task_cls(yt8m_cfg.YT8MTask)
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
from absl import app from absl import app
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
from official.vision.beta import train
# pylint: disable=unused-import # pylint: disable=unused-import
from official.vision.beta.projects.yt8m.configs import yt8m from official.projects.yt8m.configs import yt8m
from official.vision.beta.projects.yt8m.tasks import yt8m_task from official.projects.yt8m.tasks import yt8m_task
# pylint: enable=unused-import # pylint: enable=unused-import
from official.vision.beta import train
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
import json
import os
from absl import flags
from absl.testing import flagsaver
import numpy as np
import tensorflow as tf
from official.projects.yt8m import train as train_lib
from official.vision.beta.dataloaders import tfexample_utils
FLAGS = flags.FLAGS
def make_yt8m_example():
rgb = np.random.randint(low=256, size=1024, dtype=np.uint8)
audio = np.random.randint(low=256, size=128, dtype=np.uint8)
seq_example = tf.train.SequenceExample()
seq_example.context.feature['id'].bytes_list.value[:] = [b'id001']
seq_example.context.feature['labels'].int64_list.value[:] = [1, 2, 3, 4]
tfexample_utils.put_bytes_list_to_feature(
seq_example, rgb.tobytes(), key='rgb', repeat_num=120)
tfexample_utils.put_bytes_list_to_feature(
seq_example, audio.tobytes(), key='audio', repeat_num=120)
return seq_example
class TrainTest(tf.test.TestCase):
def setUp(self):
super(TrainTest, self).setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
data_dir = os.path.join(self.get_temp_dir(), 'data')
tf.io.gfile.makedirs(data_dir)
self._data_path = os.path.join(data_dir, 'data.tfrecord')
examples = [make_yt8m_example() for _ in range(8)]
tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
def test_run(self):
saved_flag_values = flagsaver.save_flag_values()
train_lib.tfm_flags.define_flags()
FLAGS.mode = 'train'
FLAGS.model_dir = self._model_dir
FLAGS.experiment = 'yt8m_experiment'
FLAGS.tpu = ''
params_override = json.dumps({
'runtime': {
'distribution_strategy': 'mirrored',
'mixed_precision_dtype': 'float32',
},
'trainer': {
'train_steps': 1,
'validation_steps': 1,
},
'task': {
'model': {
'cluster_size': 16,
'hidden_size': 16,
'use_context_gate_cluster_layer': True,
'agg_model': {
'use_input_context_gate': True,
'use_output_context_gate': True,
},
},
'train_data': {
'input_path': self._data_path,
'global_batch_size': 4,
},
'validation_data': {
'input_path': self._data_path,
'global_batch_size': 4,
}
}
})
FLAGS.params_override = params_override
train_lib.train.main('unused_args')
FLAGS.mode = 'eval'
with train_lib.train.gin.unlock_config():
train_lib.train.main('unused_args')
flagsaver.restore_flag_values(saved_flag_values)
if __name__ == '__main__':
tf.config.set_soft_device_placement(True)
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment