Commit b0ccdb11 authored by Shixin Luo's avatar Shixin Luo
Browse files

resolve conflict with master

parents e61588cd 1611a8c5
...@@ -29,7 +29,7 @@ from absl import app ...@@ -29,7 +29,7 @@ from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import distribute_utils
from official.modeling import performance from official.modeling import performance
from official.nlp.transformer import compute_bleu from official.nlp.transformer import compute_bleu
from official.nlp.transformer import data_pipeline from official.nlp.transformer import data_pipeline
...@@ -40,7 +40,6 @@ from official.nlp.transformer import transformer ...@@ -40,7 +40,6 @@ from official.nlp.transformer import transformer
from official.nlp.transformer import translate from official.nlp.transformer import translate
from official.nlp.transformer.utils import tokenizer from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
INF = int(1e9) INF = int(1e9)
...@@ -160,8 +159,9 @@ class TransformerTask(object): ...@@ -160,8 +159,9 @@ class TransformerTask(object):
params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
params["steps_between_evals"] = flags_obj.steps_between_evals params["steps_between_evals"] = flags_obj.steps_between_evals
params["enable_checkpointing"] = flags_obj.enable_checkpointing params["enable_checkpointing"] = flags_obj.enable_checkpointing
params["save_weights_only"] = flags_obj.save_weights_only
self.distribution_strategy = distribution_utils.get_distribution_strategy( self.distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=num_gpus, num_gpus=num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg, all_reduce_alg=flags_obj.all_reduce_alg,
...@@ -197,7 +197,7 @@ class TransformerTask(object): ...@@ -197,7 +197,7 @@ class TransformerTask(object):
keras_utils.set_session_config(enable_xla=flags_obj.enable_xla) keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)
_ensure_dir(flags_obj.model_dir) _ensure_dir(flags_obj.model_dir)
with distribution_utils.get_strategy_scope(self.distribution_strategy): with distribute_utils.get_strategy_scope(self.distribution_strategy):
model = transformer.create_model(params, is_train=True) model = transformer.create_model(params, is_train=True)
opt = self._create_optimizer() opt = self._create_optimizer()
...@@ -376,7 +376,7 @@ class TransformerTask(object): ...@@ -376,7 +376,7 @@ class TransformerTask(object):
# We only want to create the model under DS scope for TPU case. # We only want to create the model under DS scope for TPU case.
# When 'distribution_strategy' is None, a no-op DummyContextManager will # When 'distribution_strategy' is None, a no-op DummyContextManager will
# be used. # be used.
with distribution_utils.get_strategy_scope(distribution_strategy): with distribute_utils.get_strategy_scope(distribution_strategy):
if not self.predict_model: if not self.predict_model:
self.predict_model = transformer.create_model(self.params, False) self.predict_model = transformer.create_model(self.params, False)
self._load_weights_if_possible( self._load_weights_if_possible(
......
...@@ -14,11 +14,6 @@ ...@@ -14,11 +14,6 @@
# ============================================================================== # ==============================================================================
"""XLNet classification finetuning runner in tf2.0.""" """XLNet classification finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools import functools
# Import libraries # Import libraries
from absl import app from absl import app
...@@ -28,13 +23,13 @@ from absl import logging ...@@ -28,13 +23,13 @@ from absl import logging
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
# pylint: disable=unused-import # pylint: disable=unused-import
from official.common import distribute_utils
from official.nlp.xlnet import common_flags from official.nlp.xlnet import common_flags
from official.nlp.xlnet import data_utils from official.nlp.xlnet import data_utils
from official.nlp.xlnet import optimization from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import tpu_lib
flags.DEFINE_integer("n_class", default=2, help="Number of classes.") flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
flags.DEFINE_string( flags.DEFINE_string(
...@@ -135,14 +130,9 @@ def get_metric_fn(): ...@@ -135,14 +130,9 @@ def get_metric_fn():
def main(unused_argv): def main(unused_argv):
del unused_argv del unused_argv
if FLAGS.strategy_type == "mirror": strategy = distribute_utils.get_distribution_strategy(
strategy = tf.distribute.MirroredStrategy() distribution_strategy=FLAGS.strategy_type,
elif FLAGS.strategy_type == "tpu": tpu_address=FLAGS.tpu)
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
if strategy: if strategy:
logging.info("***** Number of cores used : %d", logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync) strategy.num_replicas_in_sync)
......
...@@ -12,12 +12,7 @@ ...@@ -12,12 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""XLNet classification finetuning runner in tf2.0.""" """XLNet pretraining runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools import functools
import os import os
...@@ -28,13 +23,13 @@ from absl import flags ...@@ -28,13 +23,13 @@ from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
# pylint: disable=unused-import # pylint: disable=unused-import
from official.common import distribute_utils
from official.nlp.xlnet import common_flags from official.nlp.xlnet import common_flags
from official.nlp.xlnet import data_utils from official.nlp.xlnet import data_utils
from official.nlp.xlnet import optimization from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import tpu_lib
flags.DEFINE_integer( flags.DEFINE_integer(
"num_predict", "num_predict",
...@@ -77,17 +72,11 @@ def get_pretrainxlnet_model(model_config, run_config): ...@@ -77,17 +72,11 @@ def get_pretrainxlnet_model(model_config, run_config):
def main(unused_argv): def main(unused_argv):
del unused_argv del unused_argv
num_hosts = 1 num_hosts = 1
if FLAGS.strategy_type == "mirror": strategy = distribute_utils.get_distribution_strategy(
strategy = tf.distribute.MirroredStrategy() distribution_strategy=FLAGS.strategy_type,
elif FLAGS.strategy_type == "tpu": tpu_address=FLAGS.tpu)
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu) if FLAGS.strategy_type == "tpu":
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver) num_hosts = strategy.extended.num_hosts
topology = FLAGS.tpu_topology.split("x")
total_num_core = 2 * int(topology[0]) * int(topology[1])
num_hosts = total_num_core // FLAGS.num_core_per_host
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
if strategy: if strategy:
logging.info("***** Number of cores used : %d", logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync) strategy.num_replicas_in_sync)
......
...@@ -14,11 +14,6 @@ ...@@ -14,11 +14,6 @@
# ============================================================================== # ==============================================================================
"""XLNet SQUAD finetuning runner in tf2.0.""" """XLNet SQUAD finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools import functools
import json import json
import os import os
...@@ -32,6 +27,7 @@ from absl import logging ...@@ -32,6 +27,7 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
# pylint: disable=unused-import # pylint: disable=unused-import
import sentencepiece as spm import sentencepiece as spm
from official.common import distribute_utils
from official.nlp.xlnet import common_flags from official.nlp.xlnet import common_flags
from official.nlp.xlnet import data_utils from official.nlp.xlnet import data_utils
from official.nlp.xlnet import optimization from official.nlp.xlnet import optimization
...@@ -39,7 +35,6 @@ from official.nlp.xlnet import squad_utils ...@@ -39,7 +35,6 @@ from official.nlp.xlnet import squad_utils
from official.nlp.xlnet import training_utils from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import tpu_lib
flags.DEFINE_string( flags.DEFINE_string(
"test_feature_path", default=None, help="Path to feature of test set.") "test_feature_path", default=None, help="Path to feature of test set.")
...@@ -217,14 +212,9 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top): ...@@ -217,14 +212,9 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
def main(unused_argv): def main(unused_argv):
del unused_argv del unused_argv
if FLAGS.strategy_type == "mirror": strategy = distribute_utils.get_distribution_strategy(
strategy = tf.distribute.MirroredStrategy() distribution_strategy=FLAGS.strategy_type,
elif FLAGS.strategy_type == "tpu": tpu_address=FLAGS.tpu)
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
if strategy: if strategy:
logging.info("***** Number of cores used : %d", logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync) strategy.num_replicas_in_sync)
......
...@@ -21,20 +21,17 @@ from __future__ import print_function ...@@ -21,20 +21,17 @@ from __future__ import print_function
import json import json
import os import os
# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags from absl import flags
from absl import logging from absl import logging
import numpy as np
import tensorflow as tf import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.common import distribute_utils
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import data_pipeline from official.recommendation import data_pipeline
from official.recommendation import data_preprocessing from official.recommendation import data_preprocessing
from official.recommendation import movielens from official.recommendation import movielens
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -142,7 +139,7 @@ def get_v1_distribution_strategy(params): ...@@ -142,7 +139,7 @@ def get_v1_distribution_strategy(params):
tpu_cluster_resolver, steps_per_run=100) tpu_cluster_resolver, steps_per_run=100)
else: else:
distribution = distribution_utils.get_distribution_strategy( distribution = distribute_utils.get_distribution_strategy(
num_gpus=params["num_gpus"]) num_gpus=params["num_gpus"])
return distribution return distribution
......
...@@ -33,13 +33,13 @@ from absl import logging ...@@ -33,13 +33,13 @@ from absl import logging
import tensorflow.compat.v2 as tf import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.common import distribute_utils
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import movielens from official.recommendation import movielens
from official.recommendation import ncf_common from official.recommendation import ncf_common
from official.recommendation import ncf_input_pipeline from official.recommendation import ncf_input_pipeline
from official.recommendation import neumf_model from official.recommendation import neumf_model
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
...@@ -225,7 +225,7 @@ def run_ncf(_): ...@@ -225,7 +225,7 @@ def run_ncf(_):
loss_scale=flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic")) loss_scale=flags_core.get_loss_scale(FLAGS, default_for_fp16="dynamic"))
tf.keras.mixed_precision.experimental.set_policy(policy) tf.keras.mixed_precision.experimental.set_policy(policy)
strategy = distribution_utils.get_distribution_strategy( strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy, distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus, num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu) tpu_address=FLAGS.tpu)
...@@ -271,7 +271,7 @@ def run_ncf(_): ...@@ -271,7 +271,7 @@ def run_ncf(_):
params, producer, input_meta_data, strategy)) params, producer, input_meta_data, strategy))
steps_per_epoch = None if generate_input_online else num_train_steps steps_per_epoch = None if generate_input_online else num_train_steps
with distribution_utils.get_strategy_scope(strategy): with distribute_utils.get_strategy_scope(strategy):
keras_model = _get_keras_model(params) keras_model = _get_keras_model(params)
optimizer = tf.keras.optimizers.Adam( optimizer = tf.keras.optimizers.Adam(
learning_rate=params["learning_rate"], learning_rate=params["learning_rate"],
......
...@@ -17,7 +17,7 @@ gin-config ...@@ -17,7 +17,7 @@ gin-config
tf_slim>=1.1.0 tf_slim>=1.1.0
Cython Cython
matplotlib matplotlib
pyyaml pyyaml>=5.1
# CV related dependencies # CV related dependencies
opencv-python-headless opencv-python-headless
Pillow Pillow
......
...@@ -13,197 +13,5 @@ ...@@ -13,197 +13,5 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Helper functions for running models in a distributed setting.""" """Helper functions for running models in a distributed setting."""
# pylint: disable=wildcard-import
from __future__ import absolute_import from official.common.distribute_utils import *
from __future__ import division
from __future__ import print_function
import json
import os
import random
import string
from absl import logging
import tensorflow.compat.v2 as tf
from official.utils.misc import tpu_lib
def _collective_communication(all_reduce_alg):
  """Map an all-reduce algorithm name to a CollectiveCommunication option.

  Args:
    all_reduce_alg: a string specifying which collective communication to
      pick, or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"].
  """
  options = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      "ring": tf.distribute.experimental.CollectiveCommunication.RING,
      "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL,
  }
  try:
    return options[all_reduce_alg]
  except KeyError:
    raise ValueError(
        "When used with `multi_worker_mirrored`, valid values for "
        "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
            all_reduce_alg))
def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Build a CrossDeviceOps for MirroredStrategy from an algorithm name.

  Args:
    all_reduce_alg: a string specifying which cross device op to pick, or
      None.
    num_packs: an integer specifying number of packs for the cross device op.

  Returns:
    tf.distribute.CrossDeviceOps object or None.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
  """
  # None means "let MirroredStrategy pick its own default cross-device op".
  if all_reduce_alg is None:
    return None
  op_classes = {
      "nccl": tf.distribute.NcclAllReduce,
      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce,
  }
  if all_reduce_alg not in op_classes:
    raise ValueError(
        "When used with `mirrored`, valid values for all_reduce_alg are "
        "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
            all_reduce_alg))
  return op_classes[all_reduce_alg](num_packs=num_packs)
def get_distribution_strategy(distribution_strategy="mirrored",
                              num_gpus=0,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are "off", "one_device", "mirrored",
      "parameter_server", "multi_worker_mirrored", and "tpu" -- case
      insensitive. "off" means not to use Distribution Strategy; "tpu" means to
      use TPUStrategy using `tpu_address`.
    num_gpus: Number of GPUs to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl". If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents TPU to connect to. Must not be
      None if `distribution_strategy` is set to `tpu`.

  Returns:
    tf.distribute.DistibutionStrategy object.

  Raises:
    ValueError: if `distribution_strategy` is "off" or "one_device" and
      `num_gpus` is larger than 1; or `num_gpus` is negative or if
      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
  """
  if num_gpus < 0:
    raise ValueError("`num_gpus` can not be negative.")

  # Strategy names are matched case-insensitively.
  distribution_strategy = distribution_strategy.lower()
  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError("When {} GPUs are specified, distribution_strategy "
                       "flag cannot be set to `off`.".format(num_gpus))
    # "off" deliberately returns None; callers pair this with
    # get_strategy_scope, which substitutes a no-op context manager.
    return None

  if distribution_strategy == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
    cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
    return tf.distribute.experimental.TPUStrategy(cluster_resolver)

  if distribution_strategy == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if distribution_strategy == "one_device":
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` can not be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy == "mirrored":
    # With zero GPUs, MirroredStrategy still works by mirroring onto the CPU.
    if num_gpus == 0:
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if distribution_strategy == "parameter_server":
    return tf.distribute.experimental.ParameterServerStrategy()

  raise ValueError("Unrecognized Distribution Strategy: %r" %
                   distribution_strategy)
def configure_cluster(worker_hosts=None, task_index=-1):
  """Set multi-worker cluster spec in TF_CONFIG environment variable.

  If TF_CONFIG is already present in the environment it is left untouched and
  only the worker count is derived from it; otherwise a new TF_CONFIG is
  written from `worker_hosts`.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of this worker within `worker_hosts`; must be >= 0 when
      more than one worker is given.

  Returns:
    Number of workers in the cluster.
  """
  existing_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if existing_config:
    # Respect a pre-set TF_CONFIG: count chief + worker entries.
    cluster = existing_config["cluster"]
    return len(cluster.get("chief", [])) + len(cluster.get("worker", []))

  if not worker_hosts:
    # No cluster information at all: single-worker (local) setup.
    return 1

  host_list = worker_hosts.split(",")
  worker_count = len(host_list)
  if worker_count > 1 and task_index < 0:
    raise ValueError("Must specify task_index when number of workers > 1")
  if worker_count == 1:
    task_index = 0
  os.environ["TF_CONFIG"] = json.dumps({
      "cluster": {
          "worker": host_list
      },
      "task": {
          "type": "worker",
          "index": task_index
      }
  })
  return worker_count
def get_strategy_scope(strategy):
  """Return `strategy.scope()` when a strategy is given, else a no-op context.

  Args:
    strategy: a tf.distribute.Strategy instance, or a falsy value (e.g. None
      when distribution is "off").

  Returns:
    The strategy's scope context manager, or a DummyContextManager when
    `strategy` is falsy.
  """
  return strategy.scope() if strategy else DummyContextManager()
class DummyContextManager(object):
  """No-op context manager used in place of a strategy scope."""

  def __enter__(self):
    return None

  def __exit__(self, *args):
    # Returning None (falsy) lets any exception propagate unchanged.
    return None
...@@ -30,9 +30,13 @@ from tensorflow.python.eager import monitoring ...@@ -30,9 +30,13 @@ from tensorflow.python.eager import monitoring
global_batch_size_gauge = monitoring.IntGauge( global_batch_size_gauge = monitoring.IntGauge(
'/tensorflow/training/global_batch_size', 'TF training global batch size') '/tensorflow/training/global_batch_size', 'TF training global batch size')
first_batch_start_time = monitoring.IntGauge( first_batch_time_gauge = monitoring.IntGauge(
'/tensorflow/training/first_batch_start', '/tensorflow/training/first_batch',
'TF training start time (unix epoch time in us.') 'TF training start/end time for first batch (unix epoch time in us.',
'type')
first_batch_start_time = first_batch_time_gauge.get_cell('start')
first_batch_end_time = first_batch_time_gauge.get_cell('end')
class BatchTimestamp(object): class BatchTimestamp(object):
...@@ -121,8 +125,8 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -121,8 +125,8 @@ class TimeHistory(tf.keras.callbacks.Callback):
def on_batch_begin(self, batch, logs=None): def on_batch_begin(self, batch, logs=None):
if not self.start_time: if not self.start_time:
self.start_time = time.time() self.start_time = time.time()
if not first_batch_start_time.get_cell().value(): if not first_batch_start_time.value():
first_batch_start_time.get_cell().set(int(self.start_time * 1000000)) first_batch_start_time.set(int(self.start_time * 1000000))
# Record the timestamp of the first global step # Record the timestamp of the first global step
if not self.timestamp_log: if not self.timestamp_log:
...@@ -131,6 +135,8 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -131,6 +135,8 @@ class TimeHistory(tf.keras.callbacks.Callback):
def on_batch_end(self, batch, logs=None): def on_batch_end(self, batch, logs=None):
"""Records elapse time of the batch and calculates examples per second.""" """Records elapse time of the batch and calculates examples per second."""
if not first_batch_end_time.value():
first_batch_end_time.set(int(time.time() * 1000000))
self.steps_in_epoch = batch + 1 self.steps_in_epoch = batch + 1
steps_since_last_log = self.global_steps - self.last_log_step steps_since_last_log = self.global_steps - self.last_log_step
if steps_since_last_log >= self.log_steps: if steps_since_last_log >= self.log_steps:
......
# Benchmarks run on the same instance; change the eval batch size to fit on a 4x4 TPU.
task:
validation_data:
global_batch_size: 32
trainer:
validation_interval: 1560
validation_steps: 156
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests classification_input.py."""
import io
# Import libraries
from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.dataloaders import classification_input
def _encode_image(image_array, fmt):
  """Serialize a numpy image array into encoded bytes of the given format."""
  buffer = io.BytesIO()
  Image.fromarray(image_array).save(buffer, format=fmt)
  return buffer.getvalue()
class DecoderTest(tf.test.TestCase, parameterized.TestCase):
  """Tests that classification_input.Decoder parses serialized tf.Examples."""

  @parameterized.parameters(
      (100, 100, 0), (100, 100, 1), (100, 100, 2),
  )
  def test_decoder(self, image_height, image_width, num_instances):
    """Round-trips a synthetic JPEG example through the decoder."""
    # NOTE(review): num_instances is not used anywhere in this test body —
    # presumably kept for parameter symmetry with other tests; confirm.
    decoder = classification_input.Decoder()
    # Random RGB image encoded as JPEG bytes.
    image = _encode_image(
        np.uint8(np.random.rand(image_height, image_width, 3) * 255),
        fmt='JPEG')
    label = 2
    serialized_example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded': (tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image]))),
                'image/class/label': (
                    tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[label]))),
            })).SerializeToString()
    decoded_tensors = decoder.decode(tf.convert_to_tensor(serialized_example))
    results = tf.nest.map_structure(lambda x: x.numpy(), decoded_tensors)
    # Exactly the two features that were fed in should come back out.
    self.assertCountEqual(
        ['image/encoded', 'image/class/label'], results.keys())
    self.assertEqual(label, results['image/class/label'])
class ParserTest(parameterized.TestCase, tf.test.TestCase):
  """Tests classification_input.Parser over output sizes and dtypes."""

  @parameterized.parameters(
      ([224, 224, 3], 'float32', True),
      ([224, 224, 3], 'float16', True),
      ([224, 224, 3], 'float32', False),
      ([224, 224, 3], 'float16', False),
      ([512, 640, 3], 'float32', True),
      ([512, 640, 3], 'float16', True),
      ([512, 640, 3], 'float32', False),
      ([512, 640, 3], 'float16', False),
      ([640, 640, 3], 'float32', True),
      ([640, 640, 3], 'bfloat16', True),
      ([640, 640, 3], 'float32', False),
      ([640, 640, 3], 'bfloat16', False),
  )
  def test_parser(self, output_size, dtype, is_training):
    """Reads a batch and checks image/label shapes and the image dtype."""
    # NOTE(review): the `is_training` test parameter is never forwarded —
    # DataConfig hard-codes is_training=True and parse_fn reads
    # params.is_training below; confirm whether the eval path is meant to be
    # exercised here.
    params = cfg.DataConfig(
        input_path='imagenet-2012-tfrecord/train*',
        global_batch_size=2,
        is_training=True,
        examples_consume=4)
    decoder = classification_input.Decoder()
    parser = classification_input.Parser(
        output_size=output_size[:2],
        num_classes=1001,
        aug_rand_hflip=False,
        dtype=dtype)
    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read()
    images, labels = next(iter(dataset))
    # Images come back batched at the requested spatial size; labels are
    # one scalar per batch element.
    self.assertAllEqual(images.numpy().shape,
                        [params.global_batch_size] + output_size)
    self.assertAllEqual(labels.numpy().shape, [params.global_batch_size])
    if dtype == 'float32':
      self.assertAllEqual(images.dtype, tf.float32)
    elif dtype == 'float16':
      self.assertAllEqual(images.dtype, tf.float16)
    elif dtype == 'bfloat16':
      self.assertAllEqual(images.dtype, tf.bfloat16)
# Run the test suite when invoked directly.
if __name__ == '__main__':
  tf.test.main()
...@@ -23,11 +23,11 @@ from official.vision.beta.ops import preprocess_ops ...@@ -23,11 +23,11 @@ from official.vision.beta.ops import preprocess_ops
def process_source_id(source_id): def process_source_id(source_id):
"""Processes source_id to the right format.""" """Processes source_id to the right format."""
if source_id.dtype == tf.string: if source_id.dtype == tf.string:
source_id = tf.cast(tf.strings.to_number(source_id), tf.int32) source_id = tf.cast(tf.strings.to_number(source_id), tf.int64)
with tf.control_dependencies([source_id]): with tf.control_dependencies([source_id]):
source_id = tf.cond( source_id = tf.cond(
pred=tf.equal(tf.size(input=source_id), 0), pred=tf.equal(tf.size(input=source_id), 0),
true_fn=lambda: tf.cast(tf.constant(-1), tf.int32), true_fn=lambda: tf.cast(tf.constant(-1), tf.int64),
false_fn=lambda: tf.identity(source_id)) false_fn=lambda: tf.identity(source_id))
return source_id return source_id
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for coco_evaluator."""
import io
import os
# Import libraries
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
import six
import tensorflow as tf
from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.evaluation import coco_utils
_COCO_JSON_FILE = '/placer/prod/home/snaggletooth/test/data/coco/instances_val2017.json'
_SAVED_COCO_JSON_FILE = 'tmp.json'
def get_groundtruth_annotations(image_id, coco, include_mask=False):
  """Build a batched groundtruth dict for one image from a COCO dataset.

  Args:
    image_id: id of the image to look up in `coco`.
    coco: a pycocotools COCO object holding annotations.
    include_mask: whether to add PNG-encoded instance masks under 'masks'.

  Returns:
    A dict with keys 'boxes', 'classes', 'is_crowds', 'areas' (plus 'masks'
    when include_mask is True), and scalar metadata 'source_id', 'height',
    'width', 'num_detections' — every value wrapped with a leading batch
    dimension of 1. Returns None when the image has no annotations.
  """
  anns = coco.loadAnns(coco.getAnnIds([image_id]))
  if not anns:
    return None
  image = coco.loadImgs([image_id])[0]
  groundtruths = {
      'boxes': [],
      'classes': [],
      'is_crowds': [],
      'areas': [],
  }
  if include_mask:
    groundtruths['masks'] = []
  for ann in anns:
    # Creates detections from groundtruths.
    # Converts [x, y, w, h] to [y1, x1, y2, x2] box format.
    box = [ann['bbox'][1],
           ann['bbox'][0],
           (ann['bbox'][1] + ann['bbox'][3]),
           (ann['bbox'][0] + ann['bbox'][2])]
    # Creates groundtruths.
    groundtruths['boxes'].append(box)
    groundtruths['classes'].append(ann['category_id'])
    groundtruths['is_crowds'].append(ann['iscrowd'])
    groundtruths['areas'].append(ann['area'])
    if include_mask:
      # Rasterize the annotation and re-encode it as PNG bytes.
      mask_img = Image.fromarray(coco.annToMask(ann).astype(np.uint8))
      with io.BytesIO() as stream:
        mask_img.save(stream, format='PNG')
        mask_bytes = stream.getvalue()
      groundtruths['masks'].append(mask_bytes)
  # Stack per-annotation lists into arrays of shape [num_annotations, ...].
  for key, val in groundtruths.items():
    groundtruths[key] = np.stack(val, axis=0)
  groundtruths['source_id'] = image['id']
  groundtruths['height'] = image['height']
  groundtruths['width'] = image['width']
  groundtruths['num_detections'] = len(anns)
  # Prepend a batch dimension of 1 to every entry (including the scalars).
  for k, v in six.iteritems(groundtruths):
    groundtruths[k] = np.expand_dims(v, axis=0)
  return groundtruths
def get_predictions(image_id, coco, include_mask=False):
  """Creates a batched prediction dict from one image's groundtruth.

  Args:
    image_id: the COCO image id to look up.
    coco: a pycocotools `COCO` object holding the annotations.
    include_mask: whether to include binary instance masks under the
      'detection_masks' key.

  Returns:
    A dict of numpy arrays with a leading batch dimension of 1, or None if
    the image has no annotations.
  """
  annotations = coco.loadAnns(coco.getAnnIds([image_id]))
  if not annotations:
    return None
  img = coco.loadImgs([image_id])[0]
  boxes = []
  classes = []
  scores = []
  masks = []
  for annotation in annotations:
    # Converts [x, y, w, h] to the denormalized [y1, x1, y2, x2] format.
    x, y, w, h = annotation['bbox']
    boxes.append([y, x, y + h, x + w])
    classes.append(annotation['category_id'])
    # Groundtruth-derived detections get perfect confidence.
    scores.append(1)
    if include_mask:
      masks.append(coco.annToMask(annotation))
  predictions = {
      'detection_boxes': boxes,
      'detection_classes': classes,
      'detection_scores': scores,
  }
  if include_mask:
    predictions['detection_masks'] = masks
  for key in predictions:
    predictions[key] = np.expand_dims(np.stack(predictions[key], axis=0),
                                      axis=0)
  predictions['source_id'] = np.array([img['id']])
  predictions['num_detections'] = np.array([len(annotations)])
  predictions['image_info'] = np.array(
      [[[img['height'], img['width']],
        [img['height'], img['width']],
        [1, 1],
        [0, 0]]], dtype=np.float32)
  return predictions
def get_fake_predictions(image_id, coco, include_mask=False):
  """Builds random predictions for one image, for evaluator smoke tests.

  Args:
    image_id: the COCO image id to look up.
    coco: a pycocotools `COCO` object holding the annotations.
    include_mask: whether to include (all-zero) instance masks under the
      'detection_masks' key.

  Returns:
    A dict of numpy arrays with a leading batch dimension of 1, or None if
    the image has no annotations.
  """
  anns = coco.loadAnns(coco.getAnnIds([image_id]))
  if not anns:
    return None
  label_id_max = max(ann['category_id'] for ann in anns)
  image = coco.loadImgs([image_id])[0]
  num_detections = 100
  width = image['width']
  height = image['height']
  # Box corners straddle the image center, so every box is non-degenerate.
  xmin = np.random.randint(
      low=0, high=int(width / 2), size=(1, num_detections))
  xmax = np.random.randint(
      low=int(width / 2), high=width, size=(1, num_detections))
  ymin = np.random.randint(
      low=0, high=int(height / 2), size=(1, num_detections))
  ymax = np.random.randint(
      low=int(height / 2), high=height, size=(1, num_detections))
  predictions = {
      'detection_boxes': np.stack([ymin, xmin, ymax, xmax], axis=-1),
      'detection_classes': np.random.randint(
          low=0, high=(label_id_max + 1), size=(1, num_detections)),
      'detection_scores': np.random.random(size=(1, num_detections)),
  }
  if include_mask:
    # NOTE(review): randint(1, ...) samples from [0, 1), so these masks are
    # all zeros. Presumably acceptable for a smoke test -- confirm intent.
    predictions['detection_masks'] = np.random.randint(
        1, size=(1, num_detections, height, width))
  predictions['source_id'] = np.array([image['id']])
  predictions['num_detections'] = np.array([num_detections])
  predictions['image_info'] = np.array(
      [[[height, width],
        [height, width],
        [1, 1],
        [0, 0]]], dtype=np.float32)
  return predictions
class DummyGroundtruthGenerator(object):
  """Callable that yields groundtruths for a single image.

  Matches the generator interface consumed by
  coco_utils.generate_annotation_file.
  """

  def __init__(self, include_mask, image_id, coco):
    # Stash the lookup arguments; annotation resolution is deferred until
    # the generator is actually invoked.
    self._include_mask = include_mask
    self._image_id = image_id
    self._coco = coco

  def __call__(self):
    annotations = get_groundtruth_annotations(
        self._image_id, self._coco, self._include_mask)
    yield annotations
class COCOEvaluatorTest(parameterized.TestCase, absltest.TestCase):
  """Checks COCOEvaluator parity across groundtruth delivery modes."""

  def setUp(self):
    super(COCOEvaluatorTest, self).setUp()
    tmp_dir = self.create_tempdir()
    # Location for the groundtruth JSON regenerated during the test.
    self._saved_coco_json_file = os.path.join(tmp_dir.full_path,
                                              _SAVED_COCO_JSON_FILE)

  def tearDown(self):
    super(COCOEvaluatorTest, self).tearDown()

  def _pick_random_image(self, coco):
    """Returns (dataset index, image id) for a randomly chosen image."""
    index = np.random.randint(len(coco.dataset['images']))
    return index, coco.dataset['images'][index]['id']

  def _build_predictions(self, coco, image_id, include_mask,
                         use_fake_predictions):
    """Builds prediction tensors for image_id, or None if it has no anns."""
    if use_fake_predictions:
      raw = get_fake_predictions(image_id, coco, include_mask=include_mask)
    else:
      raw = get_predictions(image_id, coco, include_mask=include_mask)
    if not raw:
      return None
    return tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x) if x is not None else None, raw)

  def _evaluate(self, annotation_file, include_mask, groundtruths,
                predictions):
    """Runs one COCOEvaluator pass and returns its metrics dict."""
    evaluator = coco_evaluator.COCOEvaluator(
        annotation_file=annotation_file, include_mask=include_mask)
    evaluator.update_state(groundtruths=groundtruths, predictions=predictions)
    return evaluator.result()

  @parameterized.parameters(
      (False, False), (False, True), (True, False), (True, True))
  def testEval(self, include_mask, use_fake_predictions):
    coco = COCO(annotation_file=_COCO_JSON_FILE)
    index, image_id = self._pick_random_image(coco)
    predictions = self._build_predictions(
        coco, image_id, include_mask, use_fake_predictions)
    if predictions is None:
      logging.info('Empty predictions for index=%d', index)
      return
    # Metrics computed against the full annotation file...
    results_w_json = self._evaluate(
        _COCO_JSON_FILE, include_mask, None, predictions)
    # ...must match metrics computed against a regenerated single-image
    # annotation file.
    dummy_generator = DummyGroundtruthGenerator(
        include_mask=include_mask, image_id=image_id, coco=coco)
    coco_utils.generate_annotation_file(dummy_generator,
                                        self._saved_coco_json_file)
    results_no_json = self._evaluate(
        self._saved_coco_json_file, include_mask, None, predictions)
    for metric, value in results_w_json.items():
      self.assertEqual(value, results_no_json[metric])

  @parameterized.parameters(
      (False, False), (False, True), (True, False), (True, True))
  def testEvalOnTheFly(self, include_mask, use_fake_predictions):
    coco = COCO(annotation_file=_COCO_JSON_FILE)
    index, image_id = self._pick_random_image(coco)
    predictions = self._build_predictions(
        coco, image_id, include_mask, use_fake_predictions)
    if predictions is None:
      logging.info('Empty predictions for index=%d', index)
      return
    # Metrics computed against the annotation file...
    results_w_json = self._evaluate(
        _COCO_JSON_FILE, include_mask, None, predictions)
    # ...must match metrics computed from groundtruth tensors passed
    # directly to update_state.
    groundtruths = get_groundtruth_annotations(image_id, coco, include_mask)
    groundtruths = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x) if x is not None else None,
        groundtruths)
    results_no_json = self._evaluate(
        None, include_mask, groundtruths, predictions)
    for metric, value in results_w_json.items():
      self.assertEqual(value, results_no_json[metric])
# Standard absltest entry point.
if __name__ == '__main__':
  absltest.main()
...@@ -19,6 +19,7 @@ import math ...@@ -19,6 +19,7 @@ import math
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.modeling.layers import nn_layers from official.vision.beta.modeling.layers import nn_layers
...@@ -174,7 +175,7 @@ class EfficientNet(tf.keras.Model): ...@@ -174,7 +175,7 @@ class EfficientNet(tf.keras.Model):
x = self._block_group( x = self._block_group(
inputs=x, specs=specs, name='block_group_{}'.format(i)) inputs=x, specs=specs, name='block_group_{}'.format(i))
if specs.is_output: if specs.is_output:
endpoints[endpoint_level] = x endpoints[str(endpoint_level)] = x
endpoint_level += 1 endpoint_level += 1
# Build output specs for downstream tasks. # Build output specs for downstream tasks.
...@@ -194,7 +195,7 @@ class EfficientNet(tf.keras.Model): ...@@ -194,7 +195,7 @@ class EfficientNet(tf.keras.Model):
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x) x)
endpoints[endpoint_level] = tf_utils.get_activation(activation)(x) endpoints[str(endpoint_level)] = tf_utils.get_activation(activation)(x)
super(EfficientNet, self).__init__( super(EfficientNet, self).__init__(
inputs=inputs, outputs=endpoints, **kwargs) inputs=inputs, outputs=endpoints, **kwargs)
...@@ -275,3 +276,27 @@ class EfficientNet(tf.keras.Model): ...@@ -275,3 +276,27 @@ class EfficientNet(tf.keras.Model):
def output_specs(self): def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output.""" """A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs return self._output_specs
@factory.register_backbone_builder('efficientnet')
def build_efficientnet(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'efficientnet', (f'Inconsistent backbone type '
f'{backbone_type}')
return EfficientNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
se_ratio=backbone_cfg.se_ratio,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
...@@ -35,13 +35,13 @@ class EfficientNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -35,13 +35,13 @@ class EfficientNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints = network(inputs) endpoints = network(inputs)
self.assertAllEqual([1, input_size / 2**2, input_size / 2**2, 24], self.assertAllEqual([1, input_size / 2**2, input_size / 2**2, 24],
endpoints[2].shape.as_list()) endpoints['2'].shape.as_list())
self.assertAllEqual([1, input_size / 2**3, input_size / 2**3, 40], self.assertAllEqual([1, input_size / 2**3, input_size / 2**3, 40],
endpoints[3].shape.as_list()) endpoints['3'].shape.as_list())
self.assertAllEqual([1, input_size / 2**4, input_size / 2**4, 112], self.assertAllEqual([1, input_size / 2**4, input_size / 2**4, 112],
endpoints[4].shape.as_list()) endpoints['4'].shape.as_list())
self.assertAllEqual([1, input_size / 2**5, input_size / 2**5, 320], self.assertAllEqual([1, input_size / 2**5, input_size / 2**5, 320],
endpoints[5].shape.as_list()) endpoints['5'].shape.as_list())
@parameterized.parameters('b0', 'b3', 'b6') @parameterized.parameters('b0', 'b3', 'b6')
def test_network_scaling(self, model_id): def test_network_scaling(self, model_id):
......
...@@ -13,100 +13,76 @@ ...@@ -13,100 +13,76 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""factory method.""" """Backbone registers and factory method.
One can register a new backbone model with the following two steps:
1 Import the factory and register the build in the backbone file.
2 Import the backbone class and add a build in __init__.py.
```
# my_backbone.py
from modeling.backbones import factory
class MyBackbone():
...
@factory.register_backbone_builder('my_backbone')
def build_my_backbone():
return MyBackbone()
# backbones/__init__.py adds import
from modeling.backbones.my_backbone import MyBackbone
```
If one wants the MyBackbone class to be used only by a specific binary,
then don't import the backbone module in backbones/__init__.py; instead,
import it in the place that uses it.
"""
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.vision.beta.modeling import backbones from official.core import registry
from official.vision.beta.modeling.backbones import spinenet
def build_backbone(input_specs: tf.keras.layers.InputSpec, _REGISTERED_BACKBONE_CLS = {}
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds backbone from a config. def register_backbone_builder(key: str):
"""Decorates a builder of backbone class.
The builder should be a Callable (a class or a function).
This decorator supports registration of backbone builder as follows:
```
class MyBackbone(tf.keras.Model):
pass
@register_backbone_builder('mybackbone')
def builder(input_specs, config, l2_reg):
return MyBackbone(...)
# Builds a MyBackbone object.
my_backbone = build_backbone_3d(input_specs, config, l2_reg)
```
Args: Args:
input_specs: tf.keras.layers.InputSpec. key: the key to look up the builder.
model_config: a OneOfConfig. Model config.
l2_regularizer: tf.keras.regularizers.Regularizer instance. Default to None.
Returns: Returns:
tf.keras.Model instance of the backbone. A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
""" """
backbone_type = model_config.backbone.type return registry.register(_REGISTERED_BACKBONE_CLS, key)
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
def build_backbone(input_specs: tf.keras.layers.InputSpec,
if backbone_type == 'resnet': model_config,
backbone = backbones.ResNet( l2_regularizer: tf.keras.regularizers.Regularizer = None):
model_id=backbone_cfg.model_id, """Builds backbone from a config.
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
elif backbone_type == 'efficientnet':
backbone = backbones.EfficientNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
se_ratio=backbone_cfg.se_ratio,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
elif backbone_type == 'spinenet':
model_id = backbone_cfg.model_id
if model_id not in spinenet.SCALING_MAP:
raise ValueError(
'SpineNet-{} is not a valid architecture.'.format(model_id))
scaling_params = spinenet.SCALING_MAP[model_id]
backbone = backbones.SpineNet(
input_specs=input_specs,
min_level=model_config.min_level,
max_level=model_config.max_level,
endpoints_num_filters=scaling_params['endpoints_num_filters'],
resample_alpha=scaling_params['resample_alpha'],
block_repeats=scaling_params['block_repeats'],
filter_size_scale=scaling_params['filter_size_scale'],
kernel_regularizer=l2_regularizer,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
elif backbone_type == 'revnet':
backbone = backbones.RevNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
elif backbone_type == 'mobilenet':
backbone = backbones.MobileNet(
model_id=backbone_cfg.model_id,
width_multiplier=backbone_cfg.width_multiplier,
input_specs=input_specs,
stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
else:
raise ValueError('Backbone {!r} not implement'.format(backbone_type))
return backbone
def build_backbone_3d(input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
"""Builds 3d backbone from a config.
Args: Args:
input_specs: tf.keras.layers.InputSpec. input_specs: tf.keras.layers.InputSpec.
...@@ -116,32 +92,7 @@ def build_backbone_3d(input_specs: tf.keras.layers.InputSpec, ...@@ -116,32 +92,7 @@ def build_backbone_3d(input_specs: tf.keras.layers.InputSpec,
Returns: Returns:
tf.keras.Model instance of the backbone. tf.keras.Model instance of the backbone.
""" """
backbone_type = model_config.backbone.type backbone_builder = registry.lookup(_REGISTERED_BACKBONE_CLS,
backbone_cfg = model_config.backbone.get() model_config.backbone.type)
norm_activation_config = model_config.norm_activation
return backbone_builder(input_specs, model_config, l2_regularizer)
# Flatten configs before passing to the backbone.
temporal_strides = []
temporal_kernel_sizes = []
use_self_gating = []
for block_spec in backbone_cfg.block_specs:
temporal_strides.append(block_spec.temporal_strides)
temporal_kernel_sizes.append(block_spec.temporal_kernel_sizes)
use_self_gating.append(block_spec.use_self_gating)
if backbone_type == 'resnet_3d':
backbone = backbones.ResNet3D(
model_id=backbone_cfg.model_id,
temporal_strides=temporal_strides,
temporal_kernel_sizes=temporal_kernel_sizes,
use_self_gating=use_self_gating,
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
else:
raise ValueError('Backbone {!r} not implement'.format(backbone_type))
return backbone
...@@ -18,6 +18,7 @@ from typing import Text, Optional, Dict ...@@ -18,6 +18,7 @@ from typing import Text, Optional, Dict
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.modeling.layers import nn_layers from official.vision.beta.modeling.layers import nn_layers
...@@ -591,3 +592,25 @@ class MobileNet(tf.keras.Model): ...@@ -591,3 +592,25 @@ class MobileNet(tf.keras.Model):
def output_specs(self): def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output.""" """A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs return self._output_specs
@factory.register_backbone_builder('mobilenet')
def build_mobilenet(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds MobileNet 3d backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'mobilenet', (f'Inconsistent backbone type '
f'{backbone_type}')
return MobileNet(
model_id=backbone_cfg.model_id,
width_multiplier=backbone_cfg.width_multiplier,
input_specs=input_specs,
stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
...@@ -22,6 +22,7 @@ Residual networks (ResNets) were proposed in: ...@@ -22,6 +22,7 @@ Residual networks (ResNets) were proposed in:
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks from official.vision.beta.modeling.layers import nn_blocks
layers = tf.keras.layers layers = tf.keras.layers
...@@ -152,7 +153,7 @@ class ResNet(tf.keras.Model): ...@@ -152,7 +153,7 @@ class ResNet(tf.keras.Model):
block_fn=block_fn, block_fn=block_fn,
block_repeats=spec[2], block_repeats=spec[2],
name='block_group_l{}'.format(i + 2)) name='block_group_l{}'.format(i + 2))
endpoints[i + 2] = x endpoints[str(i + 2)] = x
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints} self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
...@@ -229,3 +230,25 @@ class ResNet(tf.keras.Model): ...@@ -229,3 +230,25 @@ class ResNet(tf.keras.Model):
def output_specs(self): def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output.""" """A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs return self._output_specs
@factory.register_backbone_builder('resnet')
def build_resnet(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'resnet', (f'Inconsistent backbone type '
f'{backbone_type}')
return ResNet(
model_id=backbone_cfg.model_id,
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
...@@ -18,6 +18,7 @@ from typing import List, Tuple ...@@ -18,6 +18,7 @@ from typing import List, Tuple
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_blocks_3d from official.vision.beta.modeling.layers import nn_blocks_3d
layers = tf.keras.layers layers = tf.keras.layers
...@@ -259,3 +260,37 @@ class ResNet3D(tf.keras.Model): ...@@ -259,3 +260,37 @@ class ResNet3D(tf.keras.Model):
def output_specs(self): def output_specs(self):
"""A dict of {level: TensorShape} pairs for the model output.""" """A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs return self._output_specs
@factory.register_backbone_builder('resnet_3d')
def build_resnet3d(
input_specs: tf.keras.layers.InputSpec,
model_config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds ResNet 3d backbone from a config."""
backbone_type = model_config.backbone.type
backbone_cfg = model_config.backbone.get()
norm_activation_config = model_config.norm_activation
assert backbone_type == 'resnet_3d', (f'Inconsistent backbone type '
f'{backbone_type}')
# Flatten configs before passing to the backbone.
temporal_strides = []
temporal_kernel_sizes = []
use_self_gating = []
for block_spec in backbone_cfg.block_specs:
temporal_strides.append(block_spec.temporal_strides)
temporal_kernel_sizes.append(block_spec.temporal_kernel_sizes)
use_self_gating.append(block_spec.use_self_gating)
return ResNet3D(
model_id=backbone_cfg.model_id,
temporal_strides=temporal_strides,
temporal_kernel_sizes=temporal_kernel_sizes,
use_self_gating=use_self_gating,
input_specs=input_specs,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
...@@ -54,16 +54,16 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -54,16 +54,16 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual( self.assertAllEqual(
[1, input_size / 2**2, input_size / 2**2, 64 * endpoint_filter_scale], [1, input_size / 2**2, input_size / 2**2, 64 * endpoint_filter_scale],
endpoints[2].shape.as_list()) endpoints['2'].shape.as_list())
self.assertAllEqual( self.assertAllEqual(
[1, input_size / 2**3, input_size / 2**3, 128 * endpoint_filter_scale], [1, input_size / 2**3, input_size / 2**3, 128 * endpoint_filter_scale],
endpoints[3].shape.as_list()) endpoints['3'].shape.as_list())
self.assertAllEqual( self.assertAllEqual(
[1, input_size / 2**4, input_size / 2**4, 256 * endpoint_filter_scale], [1, input_size / 2**4, input_size / 2**4, 256 * endpoint_filter_scale],
endpoints[4].shape.as_list()) endpoints['4'].shape.as_list())
self.assertAllEqual( self.assertAllEqual(
[1, input_size / 2**5, input_size / 2**5, 512 * endpoint_filter_scale], [1, input_size / 2**5, input_size / 2**5, 512 * endpoint_filter_scale],
endpoints[5].shape.as_list()) endpoints['5'].shape.as_list())
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment