Commit ca431476 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Moves the YT8M folder.

PiperOrigin-RevId: 402314322
parent c303344e
......@@ -14,4 +14,4 @@
"""Configs package definition."""
from official.vision.beta.projects.yt8m.configs import yt8m
from official.projects.yt8m.configs import yt8m
......@@ -26,10 +26,10 @@
from typing import Dict
import tensorflow as tf
from official.projects.yt8m.dataloaders import utils
from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser
from official.vision.beta.projects.yt8m.dataloaders import utils
def resize_axis(tensor, axis, new_size, fill_value=0):
......
......@@ -15,10 +15,8 @@
"""Provides functions to help with evaluating models."""
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.yt8m.eval_utils import \
average_precision_calculator as ap_calculator
from official.vision.beta.projects.yt8m.eval_utils import \
mean_average_precision_calculator as map_calculator
from official.projects.yt8m.eval_utils import average_precision_calculator as ap_calculator
from official.projects.yt8m.eval_utils import mean_average_precision_calculator as map_calculator
def flatten(l):
......
......@@ -37,8 +37,7 @@ aps = calculator.peek_map_at_n()
```
"""
from official.vision.beta.projects.yt8m.eval_utils import \
average_precision_calculator
from official.projects.yt8m.eval_utils import average_precision_calculator
class MeanAveragePrecisionCalculator(object):
......
......@@ -13,10 +13,10 @@
# limitations under the License.
"""Contains model definitions."""
from typing import Optional, Dict, Any
from typing import Any, Dict, Optional
import tensorflow as tf
from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as utils
from official.projects.yt8m.modeling import yt8m_model_utils as utils
layers = tf.keras.layers
......
......@@ -17,9 +17,9 @@ from typing import Optional
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.projects.yt8m.configs import yt8m as yt8m_cfg
from official.vision.beta.projects.yt8m.modeling import yt8m_agg_models
from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as utils
from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.modeling import yt8m_agg_models
from official.projects.yt8m.modeling import yt8m_model_utils as utils
layers = tf.keras.layers
......
......@@ -18,8 +18,8 @@ from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.yt8m.configs import yt8m as yt8m_cfg
from official.vision.beta.projects.yt8m.modeling import yt8m_model
from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.modeling import yt8m_model
class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
......
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Contains a collection of util functions for model construction."""
from typing import Dict, Optional, Union, Any
from typing import Any, Dict, Optional, Union
import tensorflow as tf
......
......@@ -13,4 +13,4 @@
# limitations under the License.
"""Tasks package definition."""
from official.vision.beta.projects.yt8m.tasks import yt8m_task
from official.projects.yt8m.tasks import yt8m_task
......@@ -20,11 +20,11 @@ from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.beta.projects.yt8m.configs import yt8m as yt8m_cfg
from official.vision.beta.projects.yt8m.dataloaders import yt8m_input
from official.vision.beta.projects.yt8m.eval_utils import eval_util
from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as utils
from official.vision.beta.projects.yt8m.modeling.yt8m_model import DbofModel
from official.projects.yt8m.configs import yt8m as yt8m_cfg
from official.projects.yt8m.dataloaders import yt8m_input
from official.projects.yt8m.eval_utils import eval_util
from official.projects.yt8m.modeling import yt8m_model_utils as utils
from official.projects.yt8m.modeling.yt8m_model import DbofModel
@task_factory.register_task_cls(yt8m_cfg.YT8MTask)
......
......@@ -17,11 +17,11 @@
from absl import app
from official.common import flags as tfm_flags
from official.vision.beta import train
# pylint: disable=unused-import
from official.vision.beta.projects.yt8m.configs import yt8m
from official.vision.beta.projects.yt8m.tasks import yt8m_task
from official.projects.yt8m.configs import yt8m
from official.projects.yt8m.tasks import yt8m_task
# pylint: enable=unused-import
from official.vision.beta import train
if __name__ == '__main__':
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
import json
import os
from absl import flags
from absl.testing import flagsaver
import numpy as np
import tensorflow as tf
from official.projects.yt8m import train as train_lib
from official.vision.beta.dataloaders import tfexample_utils
FLAGS = flags.FLAGS
def make_yt8m_example():
rgb = np.random.randint(low=256, size=1024, dtype=np.uint8)
audio = np.random.randint(low=256, size=128, dtype=np.uint8)
seq_example = tf.train.SequenceExample()
seq_example.context.feature['id'].bytes_list.value[:] = [b'id001']
seq_example.context.feature['labels'].int64_list.value[:] = [1, 2, 3, 4]
tfexample_utils.put_bytes_list_to_feature(
seq_example, rgb.tobytes(), key='rgb', repeat_num=120)
tfexample_utils.put_bytes_list_to_feature(
seq_example, audio.tobytes(), key='audio', repeat_num=120)
return seq_example
class TrainTest(tf.test.TestCase):
def setUp(self):
super(TrainTest, self).setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
data_dir = os.path.join(self.get_temp_dir(), 'data')
tf.io.gfile.makedirs(data_dir)
self._data_path = os.path.join(data_dir, 'data.tfrecord')
examples = [make_yt8m_example() for _ in range(8)]
tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
def test_run(self):
saved_flag_values = flagsaver.save_flag_values()
train_lib.tfm_flags.define_flags()
FLAGS.mode = 'train'
FLAGS.model_dir = self._model_dir
FLAGS.experiment = 'yt8m_experiment'
FLAGS.tpu = ''
params_override = json.dumps({
'runtime': {
'distribution_strategy': 'mirrored',
'mixed_precision_dtype': 'float32',
},
'trainer': {
'train_steps': 1,
'validation_steps': 1,
},
'task': {
'model': {
'cluster_size': 16,
'hidden_size': 16,
'use_context_gate_cluster_layer': True,
'agg_model': {
'use_input_context_gate': True,
'use_output_context_gate': True,
},
},
'train_data': {
'input_path': self._data_path,
'global_batch_size': 4,
},
'validation_data': {
'input_path': self._data_path,
'global_batch_size': 4,
}
}
})
FLAGS.params_override = params_override
train_lib.train.main('unused_args')
FLAGS.mode = 'eval'
with train_lib.train.gin.unlock_config():
train_lib.train.main('unused_args')
flagsaver.restore_flag_values(saved_flag_values)
if __name__ == '__main__':
tf.config.set_soft_device_placement(True)
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment