Unverified Commit 55bf4b80 authored by Hongkun Yu, committed by GitHub

Merge branch 'master' into absl

parents 15e0057f 2416dd9c
@@ -59,13 +59,15 @@ class StringHelperTest(tf.test.TestCase):

   def test_split_string_to_tokens(self):
     text = "test? testing 123."
-    tokens = tokenizer._split_string_to_tokens(text)
+    tokens = tokenizer._split_string_to_tokens(text,
+                                               tokenizer._ALPHANUMERIC_CHAR_SET)
     self.assertEqual(["test", "? ", "testing", "123", "."], tokens)

   def test_join_tokens_to_string(self):
     tokens = ["test", "? ", "testing", "123", "."]
-    s = tokenizer._join_tokens_to_string(tokens)
+    s = tokenizer._join_tokens_to_string(tokens,
+                                         tokenizer._ALPHANUMERIC_CHAR_SET)
     self.assertEqual("test? testing 123.", s)

   def test_escape_token(self):
@@ -79,8 +81,7 @@ class StringHelperTest(tf.test.TestCase):
     escaped_token = u"Underline: \\u, Backslash: \\\\, Unicode: \\52;"
     unescaped_token = tokenizer._unescape_token(escaped_token)
-    self.assertEqual(
-        "Underline: _, Backslash: \\, Unicode: 4", unescaped_token)
+    self.assertEqual("Underline: _, Backslash: \\, Unicode: 4", unescaped_token)

   def test_list_to_index_dict(self):
     lst = ["test", "strings"]
@@ -93,8 +94,8 @@ class StringHelperTest(tf.test.TestCase):
     subtoken_dict = {"a": 0, "b": 1, "c": 2, "ab": 3}
     max_subtoken_length = 2
-    subtokens = tokenizer._split_token_to_subtokens(
-        token, subtoken_dict, max_subtoken_length)
+    subtokens = tokenizer._split_token_to_subtokens(token, subtoken_dict,
+                                                    max_subtoken_length)
     self.assertEqual(["ab", "c"], subtokens)

   def test_generate_alphabet_dict(self):
@@ -124,12 +125,28 @@ class StringHelperTest(tf.test.TestCase):
     self.assertIsInstance(subtoken_counts, collections.defaultdict)
     self.assertDictEqual(
-        {"a": 5, "b": 5, "c": 5, "_": 5, "ab": 5, "bc": 5, "c_": 5,
-         "abc": 5, "bc_": 5, "abc_": 5}, subtoken_counts)
+        {
+            "a": 5,
+            "b": 5,
+            "c": 5,
+            "_": 5,
+            "ab": 5,
+            "bc": 5,
+            "c_": 5,
+            "abc": 5,
+            "bc_": 5,
+            "abc_": 5
+        }, subtoken_counts)

   def test_filter_and_bucket_subtokens(self):
-    subtoken_counts = collections.defaultdict(
-        int, {"a": 2, "b": 4, "c": 1, "ab": 6, "ac": 3, "abbc": 5})
+    subtoken_counts = collections.defaultdict(int, {
+        "a": 2,
+        "b": 4,
+        "c": 1,
+        "ab": 6,
+        "ac": 3,
+        "abbc": 5
+    })
     min_count = 3
     subtoken_buckets = tokenizer._filter_and_bucket_subtokens(
@@ -142,8 +159,12 @@ class StringHelperTest(tf.test.TestCase):
     self.assertEqual(set(["abbc"]), subtoken_buckets[4])

   def test_gen_new_subtoken_list(self):
-    subtoken_counts = collections.defaultdict(
-        int, {"translate": 10, "t": 40, "tr": 16, "tra": 12})
+    subtoken_counts = collections.defaultdict(int, {
+        "translate": 10,
+        "t": 40,
+        "tr": 16,
+        "tra": 12
+    })
     min_count = 5
     alphabet = set("translate")
     reserved_tokens = ["reserved", "tokens"]
@@ -167,8 +188,9 @@ class StringHelperTest(tf.test.TestCase):
     num_iterations = 1
     reserved_tokens = ["reserved", "tokens"]
-    vocab_list = tokenizer._generate_subtokens(
-        token_counts, alphabet, min_count, num_iterations, reserved_tokens)
+    vocab_list = tokenizer._generate_subtokens(token_counts, alphabet,
+                                               min_count, num_iterations,
+                                               reserved_tokens)

     # Check that reserved tokens are at the front of the list
     self.assertEqual(vocab_list[:2], reserved_tokens)
...
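The `_split_token_to_subtokens` test above exercises greedy longest-match splitting against a subtoken vocabulary. A minimal standalone sketch of that behavior (a hypothetical helper for illustration, not the module's actual implementation):

```python
def greedy_split(token, subtoken_dict, max_subtoken_length):
  """Greedily splits `token` into the longest subtokens found in the dict."""
  subtokens = []
  start = 0
  while start < len(token):
    # Try the longest candidate first, shrinking until a known subtoken fits.
    for end in range(min(len(token), start + max_subtoken_length), start, -1):
      candidate = token[start:end]
      if candidate in subtoken_dict:
        subtokens.append(candidate)
        start = end
        break
    else:
      raise ValueError("Token substring not found in subtoken vocabulary.")
  return subtokens

# Mirrors the expectation in test_split_token_to_subtokens:
assert greedy_split("abc", {"a": 0, "b": 1, "c": 2, "ab": 3}, 2) == ["ab", "c"]
```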
@@ -87,7 +87,7 @@ def run_evaluation(strategy,
   @tf.function
   def _run_evaluation(test_iterator):
     """Runs validation steps."""
-    logits, labels, masks = strategy.experimental_run_v2(
+    logits, labels, masks = strategy.run(
         _test_step_fn, args=(next(test_iterator),))
     return logits, labels, masks
...
@@ -130,7 +130,7 @@ def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
   @tf.function
   def _run_evaluation(test_iterator):
     """Runs validation steps."""
-    res, unique_ids = strategy.experimental_run_v2(
+    res, unique_ids = strategy.run(
         _test_step_fn, args=(next(test_iterator),))
     return res, unique_ids
...
@@ -222,16 +222,16 @@ def train(
       return mems

     if input_meta_data["mem_len"] > 0:
-      mem = strategy.experimental_run_v2(cache_fn)
+      mem = strategy.run(cache_fn)
       for _ in tf.range(steps):
-        mem = strategy.experimental_run_v2(
+        mem = strategy.run(
             _replicated_step, args=(
                 next(iterator),
                 mem,
             ))
     else:
       for _ in tf.range(steps):
-        strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
+        strategy.run(_replicated_step, args=(next(iterator),))

   if not run_eagerly:
     train_steps = tf.function(train_steps)
...
@@ -13,30 +13,76 @@
 # limitations under the License.
 # ==============================================================================
 """Sets up TensorFlow Official Models."""
+import datetime
+import os
+import sys

 from setuptools import find_packages
 from setuptools import setup

+version = '2.2.0'
+
+project_name = 'tf-models-official'
+
+long_description = """The TensorFlow official models are a collection of
+models that use TensorFlow's high-level APIs.
+They are intended to be well-maintained, tested, and kept up to date with the
+latest TensorFlow API. They should also be reasonably optimized for fast
+performance while still being easy to read."""
+
+if '--project_name' in sys.argv:
+  project_name_idx = sys.argv.index('--project_name')
+  project_name = sys.argv[project_name_idx + 1]
+  sys.argv.remove('--project_name')
+  sys.argv.pop(project_name_idx)
+
+
+def _get_requirements():
+  """Parses requirements.txt file."""
+  install_requires_tmp = []
+  dependency_links_tmp = []
+  with open(
+      os.path.join(os.path.dirname(__file__), '../requirements.txt'), 'r') as f:
+    for line in f:
+      package_name = line.strip()
+      if package_name.startswith('-e '):
+        dependency_links_tmp.append(package_name[3:].strip())
+      else:
+        install_requires_tmp.append(package_name)
+  return install_requires_tmp, dependency_links_tmp
+
+install_requires, dependency_links = _get_requirements()
+
+if project_name == 'tf-models-nightly':
+  version += '.dev' + datetime.datetime.now().strftime('%Y%m%d')
+  install_requires.append('tf-nightly')
+else:
+  install_requires.append('tensorflow>=2.1.0')
+
+print('install_requires: ', install_requires)
+print('dependency_links: ', dependency_links)
+
 setup(
-    name='tf-models-official',
-    version='0.0.3.dev1',
+    name=project_name,
+    version=version,
     description='TensorFlow Official Models',
+    long_description=long_description,
     author='Google Inc.',
     author_email='no-reply@google.com',
     url='https://github.com/tensorflow/models',
     license='Apache 2.0',
-    packages=find_packages(exclude=["research*", "tutorials*", "samples*"]),
+    packages=find_packages(exclude=[
+        'research*',
+        'tutorials*',
+        'samples*',
+        'official.r1*',
+        'official.pip_package*',
+        'official.benchmark*',
+    ]),
     exclude_package_data={
-        '': [
-            '*_test.py',
-        ],
+        '': ['*_test.py',],
     },
-    install_requires=[
-        'six',
-    ],
-    extras_require={
-        'tensorflow': ['tensorflow>=2.0.0'],
-        'tensorflow_gpu': ['tensorflow-gpu>=2.0.0'],
-        'tensorflow-hub': ['tensorflow-hub>=0.6.0'],
-    },
+    install_requires=install_requires,
+    dependency_links=dependency_links,
     python_requires='>=3.6',
 )
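The `_get_requirements` helper above treats editable (`-e`) entries as dependency links and everything else as a regular requirement. A small, self-contained sketch of that parsing rule on an in-memory list (the input lines are hypothetical, not the repository's actual requirements.txt):

```python
def split_requirements(lines):
  """Separates pip requirements from '-e'-style dependency links."""
  install_requires, dependency_links = [], []
  for line in lines:
    entry = line.strip()
    if not entry:
      continue
    if entry.startswith('-e '):
      dependency_links.append(entry[3:].strip())
    else:
      install_requires.append(entry)
  return install_requires, dependency_links

reqs, links = split_requirements(
    ['six', '-e git+https://example.com/pkg.git#egg=pkg'])
assert reqs == ['six'] and links == ['git+https://example.com/pkg.git#egg=pkg']
```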
@@ -3,6 +3,12 @@
 The R1 folder contains legacy model implementations and models that will not
 be updated to TensorFlow 2.x. They do not have solid performance tracking.

+**Note: models will be removed from the master branch by 2020/06.**
+
+After removal, you can still access these legacy models in previously
+released tags, e.g. [v2.1.0](https://github.com/tensorflow/models/releases/tag/v2.1.0).
+
 ## Legacy model implementations

 The Transformer and MNIST implementations use pure TF 1.x TF-Estimator.
...
@@ -38,8 +38,6 @@ def create_model(data_format):
   Network structure is equivalent to:
   https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
-  and
-  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

   But uses the tf.keras API.
...
@@ -69,7 +69,7 @@ def build_estimator(model_dir, model_type, model_column_fn, inter_op, intra_op):
       model_dir=model_dir,
       feature_columns=deep_columns,
       hidden_units=hidden_units,
-      optimizer=tf.train.AdamOptimizer(),
+      optimizer=tf.compat.v1.train.AdamOptimizer(),
       activation_fn=tf.nn.sigmoid,
       dropout=0.3,
       loss_reduction=tf.losses.Reduction.MEAN)
...
@@ -405,7 +405,7 @@ def run_ncf_custom_training(params,
       optimizer.apply_gradients(grads)
       return loss

-    per_replica_losses = strategy.experimental_run_v2(
+    per_replica_losses = strategy.run(
         step_fn, args=(next(train_iterator),))
     mean_loss = strategy.reduce(
         tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
@@ -425,7 +425,7 @@ def run_ncf_custom_training(params,
       return hr_sum, hr_count

     per_replica_hr_sum, per_replica_hr_count = (
-        strategy.experimental_run_v2(
+        strategy.run(
             step_fn, args=(next(eval_iterator),)))
     hr_sum = strategy.reduce(
         tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
...
+six
 google-api-python-client>=1.6.7
 google-cloud-bigquery>=0.31.0
 kaggle>=1.3.9
...
@@ -78,9 +78,10 @@ class Controller(object):
       eval_summary_dir: The directory to write eval summaries. If None, it will
         be set to `summary_dir`.
       eval_steps: Number of steps to run evaluation.
-      eval_interval: Step interval for evaluation. If None, will skip
-        evaluation. Note that evaluation only happens outside the training loop,
-        which the loop iteration is specify by `steps_per_loop` parameter.
+      eval_interval: Step interval for evaluation. If None, will skip evaluation
+        in the middle of training. Note that evaluation only happens outside the
+        training loop, whose iterations are specified by the `steps_per_loop`
+        parameter.

     Raises:
       ValueError: If both `train_fn` and `eval_fn` are None.
@@ -111,13 +112,12 @@ class Controller(object):
     self.train_fn = train_fn
     self.eval_fn = eval_fn
     self.global_step = global_step
+    self.checkpoint_manager = checkpoint_manager

-    self.train_steps = train_steps
-    self.steps_per_loop = steps_per_loop
-    self.summary_dir = summary_dir or checkpoint_manager.directory
-    self.checkpoint_manager = checkpoint_manager
-
-    self.summary_interval = summary_interval
+    if self.train_fn is not None:
+      self.train_steps = train_steps
+      self.steps_per_loop = steps_per_loop
+      self.summary_dir = summary_dir or checkpoint_manager.directory
+
+      self.summary_interval = summary_interval
     summary_writer = tf.summary.create_file_writer(
@@ -129,17 +129,24 @@ class Controller(object):
         tf.summary.scalar,
         global_step=self.global_step,
         summary_interval=self.summary_interval)
-    if self.global_step:
-      tf.summary.experimental.set_step(self.global_step)

-    self.eval_summary_dir = eval_summary_dir or self.summary_dir
-    eval_summary_writer = tf.summary.create_file_writer(self.eval_summary_dir)
-    self.eval_summary_manager = utils.SummaryManager(
-        eval_summary_writer, tf.summary.scalar, global_step=self.global_step)
-    self.eval_steps = eval_steps
-    self.eval_interval = eval_interval
+    if self.eval_fn is not None:
+      eval_summary_dir = eval_summary_dir or self.summary_dir
+      eval_summary_writer = tf.summary.create_file_writer(
+          eval_summary_dir) if eval_summary_dir else None
+      self.eval_summary_manager = utils.SummaryManager(
+          eval_summary_writer, tf.summary.scalar, global_step=self.global_step)
+
+      self.eval_steps = eval_steps
+      self.eval_interval = eval_interval
+
+      # Create and initialize the interval triggers.
+      self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
+                                                self.global_step.numpy())
+
+    if self.global_step:
+      tf.summary.experimental.set_step(self.global_step)

     # Restore Model if needed.
     if self.checkpoint_manager is not None:
       model_restored = self._restore_model()
@@ -150,10 +157,6 @@ class Controller(object):
             checkpoint_number=self.global_step)
         logging.info("Saved checkpoins in %s", ckpt_path)

-    # Create and initialize the interval triggers.
-    self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
-                                              self.global_step.numpy())
-
   def _restore_model(self, checkpoint_path=None):
     """Restore or initialize the model.
@@ -186,11 +189,12 @@ class Controller(object):
       self._log_info(info)

     self.eval_summary_manager.write_summaries(eval_outputs)
+    self.eval_summary_manager.flush()

   def _maybe_save_checkpoints(self, current_step, force_trigger=False):
     if self.checkpoint_manager.checkpoint_interval:
       ckpt_path = self.checkpoint_manager.save(
-          checkpoint_number=current_step, check_interval=force_trigger)
+          checkpoint_number=current_step, check_interval=not force_trigger)
       if ckpt_path is not None:
         logging.info("Saved checkpoins in %s", ckpt_path)
@@ -265,6 +269,7 @@ class Controller(object):
       self._maybe_evaluate(current_step)

     self.summary_manager.write_summaries(train_outputs, always_write=True)
+    self.summary_manager.flush()
     self._maybe_save_checkpoints(current_step, force_trigger=True)
     if evaluate:
       self._maybe_evaluate(current_step, force_trigger=True)
...
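The `check_interval=not force_trigger` fix above matters because `tf.train.CheckpointManager.save` only enforces `checkpoint_interval` when `check_interval=True`; a forced save at the end of training should bypass that check rather than be subject to it. A hedged sketch of the intended behavior, assuming TF 2.1+ where `CheckpointManager` accepts `checkpoint_interval` (paths and names are illustrative):

```python
import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, directory="/tmp/ctl_ckpts", max_to_keep=3,
    step_counter=step, checkpoint_interval=10)

def maybe_save(current_step, force_trigger=False):
  # Periodic saves respect checkpoint_interval; a forced save at the end of
  # training passes check_interval=False so it is never skipped.
  return manager.save(checkpoint_number=current_step,
                      check_interval=not force_trigger)

maybe_save(current_step=5)                       # periodic save, interval-checked
maybe_save(current_step=10, force_trigger=True)  # final save, never skipped
```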
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.staging.training.controller."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.staging.training import controller
from official.staging.training import standard_runnable
def all_strategy_combinations():
"""Gets combinations of distribution strategies."""
return combinations.combine(
strategy=[
strategy_combinations.one_device_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode="eager",
)
def create_model():
x = tf.keras.layers.Input(shape=(3,), name="input")
y = tf.keras.layers.Dense(4, name="dense")(x)
model = tf.keras.Model(x, y)
return model
def summaries_with_matching_keyword(keyword, summary_dir):
"""Yields summary protos matching given keyword from event file."""
event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
if event.summary is not None:
for value in event.summary.value:
if keyword in value.tag:
tf.compat.v1.logging.error(event)
yield event.summary
def check_eventfile_for_keyword(keyword, summary_dir):
"""Checks event files for the keyword."""
return any(summaries_with_matching_keyword(keyword, summary_dir))
def dataset_fn(ctx):
del ctx
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10, drop_remainder=True)
return dataset
class TestRunnable(standard_runnable.StandardTrainable,
standard_runnable.StandardEvaluable):
"""Implements the training and evaluation APIs for the test model."""
def __init__(self):
standard_runnable.StandardTrainable.__init__(self)
standard_runnable.StandardEvaluable.__init__(self)
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
self.optimizer = tf.keras.optimizers.RMSprop()
self.global_step = self.optimizer.iterations
self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)
def build_train_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
def train_step(self, iterator):
def _replicated_step(inputs):
"""Replicated training step."""
inputs, targets = inputs
with tf.GradientTape() as tape:
outputs = self.model(inputs)
loss = tf.math.reduce_sum(outputs - targets)
grads = tape.gradient(loss, self.model.variables)
self.optimizer.apply_gradients(zip(grads, self.model.variables))
self.train_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
def train_loop_end(self):
return {
"loss": self.train_loss.result(),
}
def build_eval_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
def eval_begin(self):
self.eval_loss.reset_states()
def eval_step(self, iterator):
def _replicated_step(inputs):
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss = tf.math.reduce_sum(outputs - targets)
self.eval_loss.update_state(loss)
self.strategy.run(_replicated_step, args=(next(iterator),))
def eval_end(self):
return {
"eval_loss": self.eval_loss.result(),
}
class ControllerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ControllerTest, self).setUp()
self.model_dir = self.get_temp_dir()
@combinations.generate(all_strategy_combinations())
def test_train_and_evaluate(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()
checkpoint = tf.train.Checkpoint(
model=test_runnable.model, optimizer=test_runnable.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
strategy=strategy,
train_fn=test_runnable.train,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
eval_steps=2,
eval_interval=5)
test_controller.train(evaluate=True)
# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
@combinations.generate(all_strategy_combinations())
def test_train_only(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()
checkpoint = tf.train.Checkpoint(
model=test_runnable.model, optimizer=test_runnable.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
strategy=strategy,
train_fn=test_runnable.train,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
)
test_controller.train(evaluate=False)
# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
# Only train summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
@combinations.generate(all_strategy_combinations())
def test_evaluate_only(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()
checkpoint = tf.train.Checkpoint(model=test_runnable.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step)
test_controller = controller.Controller(
strategy=strategy,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
eval_steps=2,
eval_interval=5)
test_controller.evaluate()
# Only eval summaries are written
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
if __name__ == "__main__":
tf.test.main()
@@ -39,7 +39,7 @@ class AbstractTrainable(tf.Module):
   python callbacks. This is necessary for getting good performance in TPU
   training, as the overhead for launching a multi worker tf.function may be
   large in Eager mode. It is usually encouraged to create a host training loop
-  (e.g. using a `tf.range` wrapping `strategy.experimental_run_v2` inside a
+  (e.g. using a `tf.range` wrapping `strategy.run` inside a
   `tf.function`) in the TPU case. For the cases that don't require a host
   training loop to achieve peak performance, users can just implement a simple
   python loop to drive each step.
...
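The docstring above describes the host training loop that these files migrate to `strategy.run`. A minimal sketch of that pattern, assuming TF 2.2+ and using the default strategy for brevity (the step function and data are placeholders):

```python
import numpy as np
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # e.g. a TPUStrategy in the TPU case

dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((8, 3), dtype=np.float32)).batch(4)
dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))

def replicated_step(features):
  # Per-replica work goes here (forward pass, loss, gradient update, ...).
  return tf.reduce_sum(features)

@tf.function
def train_loop(iterator, steps):
  """Host loop: `tf.range` keeps the whole loop inside one tf.function."""
  for _ in tf.range(steps):
    # strategy.run replaces the deprecated strategy.experimental_run_v2.
    strategy.run(replicated_step, args=(next(iterator),))

train_loop(dist_iterator, tf.constant(2))
```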
@@ -87,7 +87,7 @@ class StandardTrainable(runnable.AbstractTrainable):
     What a "step" consists of is up to the implementer. If using distribution
     strategies, the call to this method should take place in the "cross-replica
     context" for generality, to allow e.g. multiple iterator dequeues and calls
-    to `strategy.experimental_run_v2`.
+    to `strategy.run`.

     Args:
       iterator: A tf.nest-compatible structure of tf.data Iterator or
@@ -163,7 +163,7 @@ class StandardEvaluable(runnable.AbstractEvaluable):
     What a "step" consists of is up to the implementer. If using distribution
     strategies, the call to this method should take place in the "cross-replica
     context" for generality, to allow e.g. multiple iterator dequeues and calls
-    to `strategy.experimental_run_v2`.
+    to `strategy.run`.

     Args:
       iterator: A tf.nest-compatible structure of tf.data Iterator or
...
@@ -193,6 +193,11 @@ class SummaryManager(object):
     """Returns the underlying summary writer."""
     return self._summary_writer

+  def flush(self):
+    """Flush the underlying summary writer."""
+    if self._enabled:
+      tf.summary.flush(self._summary_writer)
+
   def write_summaries(self, items, always_write=True):
     """Write a bulk of summaries.
...
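The new `flush()` helper above simply forwards to `tf.summary.flush` on the managed writer, so that buffered events reach disk at the end of a training or eval loop. A standalone sketch of the underlying calls (the writer path and values are illustrative):

```python
import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/summaries/train")
step = tf.Variable(0, dtype=tf.int64)

with writer.as_default():
  tf.summary.scalar("loss", 0.25, step=step)

# Force any buffered events to disk, e.g. before the process exits.
tf.summary.flush(writer)
```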
@@ -48,15 +48,26 @@ class PerfZeroBenchmark(tf.test.Benchmark):
       flag_methods: Set of flag methods to run during setup.
       tpu: (optional) TPU name to use in a TPU benchmark.
     """
-    if not output_dir:
-      output_dir = '/tmp'
-    self.output_dir = output_dir
+    if os.getenv('BENCHMARK_OUTPUT_DIR'):
+      self.output_dir = os.getenv('BENCHMARK_OUTPUT_DIR')
+    elif output_dir:
+      self.output_dir = output_dir
+    else:
+      self.output_dir = '/tmp'
     self.default_flags = default_flags or {}
     self.flag_methods = flag_methods or {}

-    if tpu:
+    if os.getenv('BENCHMARK_TPU'):
+      resolved_tpu = os.getenv('BENCHMARK_TPU')
+    elif tpu:
+      resolved_tpu = tpu
+    else:
+      resolved_tpu = None
+
+    if resolved_tpu:
       # TPU models are expected to accept a --tpu=name flag. PerfZero creates
       # the TPU at runtime and passes the TPU's name to this flag.
-      self.default_flags['tpu'] = tpu
+      self.default_flags['tpu'] = resolved_tpu

   def _get_model_dir(self, folder_name):
     """Returns directory to store info, e.g. saved model and event log."""
...
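The constructor change above lets PerfZero override the output directory and TPU name through environment variables while keeping the constructor arguments as fallbacks. The resolution order can be expressed as a small helper (hypothetical, mirroring the logic above):

```python
import os

def resolve(env_var, explicit, default=None):
  """Environment variable wins, then the explicit argument, then the default."""
  return os.getenv(env_var) or explicit or default

output_dir = resolve('BENCHMARK_OUTPUT_DIR', explicit=None, default='/tmp')
tpu = resolve('BENCHMARK_TPU', explicit='my-tpu')  # the --tpu flag gets this value
```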
@@ -80,12 +80,11 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
         all_losses = loss_fn(labels, outputs)
         losses = {}
         for k, v in all_losses.items():
-          v = tf.reduce_mean(v) / strategy.num_replicas_in_sync
-          losses[k] = v
-        loss = losses['total_loss']
+          losses[k] = tf.reduce_mean(v)
+        per_replica_loss = losses['total_loss'] / strategy.num_replicas_in_sync
         _update_state(labels, outputs)
-        grads = tape.gradient(loss, trainable_variables)
+        grads = tape.gradient(per_replica_loss, trainable_variables)
         optimizer.apply_gradients(zip(grads, trainable_variables))
         return losses
@@ -119,7 +118,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
       return labels, prediction_outputs

-    labels, outputs = strategy.experimental_run_v2(
+    labels, outputs = strategy.run(
         _test_step_fn, args=(
             next(iterator),
             eval_steps,
...
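The change above keeps the reported losses unscaled and only divides the total loss by `num_replicas_in_sync` before computing gradients; when per-replica gradients are summed across replicas, this recovers the gradient of the global mean loss. A hedged, toy-sized sketch of that scaling rule (variable and values are illustrative, shown with the default strategy):

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()
w = tf.Variable(2.0)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

def replicated_step(x):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(w * x))                   # unscaled, for reporting
    per_replica_loss = loss / strategy.num_replicas_in_sync   # scaled, for gradients
  grads = tape.gradient(per_replica_loss, [w])
  optimizer.apply_gradients(zip(grads, [w]))
  return loss

# Summing the scaled per-replica gradients across replicas reproduces the
# gradient of the global mean loss.
strategy.run(replicated_step, args=(tf.constant([1.0, 2.0]),))
```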
@@ -21,8 +21,6 @@ from __future__ import print_function
 import abc
 import functools
 import re

-from absl import logging
 import tensorflow.compat.v2 as tf

 from official.vision.detection.modeling import checkpoint_utils
 from official.vision.detection.modeling import learning_rates
@@ -60,11 +58,10 @@ class OptimizerFactory(object):
 def _make_filter_trainable_variables_fn(frozen_variable_prefix):
-  """Creates a function for filtering trainable variables.
-  """
+  """Creates a function for filtering trainable variables."""

   def _filter_trainable_variables(variables):
-    """Filters trainable variables
+    """Filters trainable variables.

     Args:
       variables: a list of tf.Variable to be filtered.
@@ -141,8 +138,7 @@ class Model(object):
     return self._optimizer_fn(self._learning_rate)

   def make_filter_trainable_variables_fn(self):
-    """Creates a function for filtering trainable variables.
-    """
+    """Creates a function for filtering trainable variables."""
     return _make_filter_trainable_variables_fn(self._frozen_variable_prefix)

   def weight_decay_loss(self, trainable_variables):
@@ -151,8 +147,6 @@ class Model(object):
         if self._regularization_var_regex is None
         or re.match(self._regularization_var_regex, v.name)
     ]
-    logging.info('Regularization Variables: %s',
-                 [v.name for v in reg_variables])

     return self._l2_weight_decay * tf.add_n(
         [tf.nn.l2_loss(v) for v in reg_variables])
...
@@ -15,6 +15,7 @@
 """Factory to build detection model."""

+from official.vision.detection.modeling import maskrcnn_model
 from official.vision.detection.modeling import retinanet_model
@@ -22,6 +23,8 @@ def model_generator(params):
   """Model function generator."""

   if params.type == 'retinanet':
     model_fn = retinanet_model.RetinanetModel(params)
+  elif params.type == 'mask_rcnn':
+    model_fn = maskrcnn_model.MaskrcnnModel(params)
   else:
     raise ValueError('Model %s is not supported.' % params.type)
...
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl import logging
 import tensorflow.compat.v2 as tf
@@ -89,6 +90,8 @@ class RpnScoreLoss(object):
   def __init__(self, params):
     self._rpn_batch_size_per_im = params.rpn_batch_size_per_im
+    self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
+        reduction=tf.keras.losses.Reduction.SUM, from_logits=True)

   def __call__(self, score_outputs, labels):
     """Computes total RPN detection loss.
@@ -129,16 +132,15 @@ class RpnScoreLoss(object):
     with tf.name_scope('rpn_score_loss'):
       mask = tf.math.logical_or(tf.math.equal(score_targets, 1),
                                 tf.math.equal(score_targets, 0))
-      score_targets = tf.math.maximum(score_targets, tf.zeros_like(score_targets))
-      # RPN score loss is sum over all except ignored samples.
-      # Keep the compat.v1 loss because Keras does not have a
-      # sigmoid_cross_entropy substitution yet.
-      # TODO(b/143720144): replace this loss.
-      score_loss = tf.compat.v1.losses.sigmoid_cross_entropy(
-          score_targets,
-          score_outputs,
-          weights=mask,
-          reduction=tf.compat.v1.losses.Reduction.SUM)
+      score_targets = tf.math.maximum(score_targets,
+                                      tf.zeros_like(score_targets))
+
+      score_targets = tf.expand_dims(score_targets, axis=-1)
+      score_outputs = tf.expand_dims(score_outputs, axis=-1)
+      score_loss = self._binary_crossentropy(
+          score_targets, score_outputs, sample_weight=mask)
       score_loss /= normalizer
       return score_loss
@@ -147,7 +149,10 @@ class RpnBoxLoss(object):
   """Region Proposal Network box regression loss function."""

   def __init__(self, params):
-    self._delta = params.huber_loss_delta
+    logging.info('RpnBoxLoss huber_loss_delta %s', params.huber_loss_delta)
+    # The delta is typically around the mean value of regression target.
+    # For instance, the regression targets of a 512x512 input with 6 anchors
+    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
     self._huber_loss = tf.keras.losses.Huber(
         delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
@@ -171,35 +176,32 @@ class RpnBoxLoss(object):
     box_losses = []
     for level in levels:
-      box_losses.append(
-          self._rpn_box_loss(
-              box_outputs[level], labels[level], delta=self._delta))
+      box_losses.append(self._rpn_box_loss(box_outputs[level], labels[level]))

     # Sum per level losses to total loss.
     return tf.add_n(box_losses)

-  def _rpn_box_loss(self, box_outputs, box_targets, normalizer=1.0, delta=1./9):
+  def _rpn_box_loss(self, box_outputs, box_targets, normalizer=1.0):
     """Computes box regression loss."""
-    # The delta is typically around the mean value of regression target.
-    # For instance, the regression targets of a 512x512 input with 6 anchors
-    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
     with tf.name_scope('rpn_box_loss'):
-      mask = tf.math.not_equal(box_targets, 0.0)
-      # The loss is normalized by the sum of non-zero weights before additional
-      # normalizer provided by the function caller.
-      box_loss = tf.compat.v1.losses.huber_loss(
-          box_targets,
-          box_outputs,
-          weights=mask,
-          delta=delta,
-          reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
-      box_loss /= normalizer
+      mask = tf.cast(tf.not_equal(box_targets, 0.0), dtype=tf.float32)
+      box_targets = tf.expand_dims(box_targets, axis=-1)
+      box_outputs = tf.expand_dims(box_outputs, axis=-1)
+      box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
+      # The loss is normalized by the sum of non-zero weights and the additional
+      # normalizer provided by the function caller. Using + 0.01 here to avoid
+      # division by zero.
+      box_loss /= normalizer * (tf.reduce_sum(mask) + 0.01)
       return box_loss

 class FastrcnnClassLoss(object):
   """Fast R-CNN classification loss function."""

+  def __init__(self):
+    self._categorical_crossentropy = tf.keras.losses.CategoricalCrossentropy(
+        reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
+
   def __call__(self, class_outputs, class_targets):
     """Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
@@ -218,24 +220,19 @@ class FastrcnnClassLoss(object):
       a scalar tensor representing total class loss.
     """
     with tf.name_scope('fast_rcnn_loss'):
-      _, _, num_classes = class_outputs.get_shape().as_list()
+      batch_size, num_boxes, num_classes = class_outputs.get_shape().as_list()
       class_targets = tf.cast(class_targets, dtype=tf.int32)
       class_targets_one_hot = tf.one_hot(class_targets, num_classes)
-      return self._fast_rcnn_class_loss(class_outputs, class_targets_one_hot)
+      return self._fast_rcnn_class_loss(class_outputs, class_targets_one_hot,
+                                        normalizer=batch_size * num_boxes / 2.0)

   def _fast_rcnn_class_loss(self, class_outputs, class_targets_one_hot,
-                            normalizer=1.0):
+                            normalizer):
     """Computes classification loss."""
     with tf.name_scope('fast_rcnn_class_loss'):
-      # The loss is normalized by the sum of non-zero weights before additional
-      # normalizer provided by the function caller.
-      # Keep the compat.v1 loss because Keras does not have a
-      # softmax_cross_entropy substitution yet.
-      # TODO(b/143720144): replace this loss.
-      class_loss = tf.compat.v1.losses.softmax_cross_entropy(
-          class_targets_one_hot,
-          class_outputs,
-          reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
+      class_loss = self._categorical_crossentropy(class_targets_one_hot,
+                                                  class_outputs)
       class_loss /= normalizer
       return class_loss
@@ -244,7 +241,12 @@ class FastrcnnBoxLoss(object):
   """Fast R-CNN box regression loss function."""

   def __init__(self, params):
-    self._delta = params.huber_loss_delta
+    logging.info('FastrcnnBoxLoss huber_loss_delta %s', params.huber_loss_delta)
+    # The delta is typically around the mean value of regression target.
+    # For instance, the regression targets of a 512x512 input with 6 anchors
+    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
+    self._huber_loss = tf.keras.losses.Huber(
+        delta=params.huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)

   def __call__(self, box_outputs, class_targets, box_targets):
     """Computes the box loss (Fast-RCNN branch) of Mask-RCNN.
@@ -296,36 +298,32 @@ class FastrcnnBoxLoss(object):
                    dtype=box_outputs.dtype), tf.reshape(box_outputs, [-1, 4]))
       box_outputs = tf.reshape(box_outputs, [batch_size, -1, 4])

-      return self._fast_rcnn_box_loss(box_outputs, box_targets, class_targets,
-                                      delta=self._delta)
+      return self._fast_rcnn_box_loss(box_outputs, box_targets, class_targets)

   def _fast_rcnn_box_loss(self, box_outputs, box_targets, class_targets,
-                          normalizer=1.0, delta=1.):
+                          normalizer=1.0):
     """Computes box regression loss."""
-    # The delta is typically around the mean value of regression target.
-    # For instance, the regression targets of a 512x512 input with 6 anchors
-    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
     with tf.name_scope('fast_rcnn_box_loss'):
       mask = tf.tile(tf.expand_dims(tf.greater(class_targets, 0), axis=2),
                      [1, 1, 4])
-      # The loss is normalized by the sum of non-zero weights before additional
-      # normalizer provided by the function caller.
-      # Keep the compat.v1 loss because Keras does not have a
-      # Reduction.SUM_BY_NONZERO_WEIGHTS substitution yet.
-      # TODO(b/143720144): replace this loss.
-      box_loss = tf.compat.v1.losses.huber_loss(
-          box_targets,
-          box_outputs,
-          weights=mask,
-          delta=delta,
-          reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
-      box_loss /= normalizer
+      mask = tf.cast(mask, dtype=tf.float32)
+      box_targets = tf.expand_dims(box_targets, axis=-1)
+      box_outputs = tf.expand_dims(box_outputs, axis=-1)
+      box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
+      # The loss is normalized by the number of ones in mask, the additional
+      # normalizer provided by the user, and + 0.01 to avoid division by zero.
+      box_loss /= normalizer * (tf.reduce_sum(mask) + 0.01)
       return box_loss

 class MaskrcnnLoss(object):
   """Mask R-CNN instance segmentation mask loss function."""

+  def __init__(self):
+    self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
+        reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
+
   def __call__(self, mask_outputs, mask_targets, select_class_targets):
     """Computes the mask loss of Mask-RCNN.
@@ -358,11 +356,16 @@ class MaskrcnnLoss(object):
           tf.reshape(tf.greater(select_class_targets, 0),
                      [batch_size, num_masks, 1, 1]),
           [1, 1, mask_height, mask_width])
-      return tf.compat.v1.losses.sigmoid_cross_entropy(
-          mask_targets,
-          mask_outputs,
-          weights=weights,
-          reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
+      weights = tf.cast(weights, dtype=tf.float32)
+
+      mask_targets = tf.expand_dims(mask_targets, axis=-1)
+      mask_outputs = tf.expand_dims(mask_outputs, axis=-1)
+      mask_loss = self._binary_crossentropy(mask_targets, mask_outputs,
+                                            sample_weight=weights)
+      # The loss is normalized by the number of 1's in weights and
+      # + 0.01 is used to avoid division by zero.
+      return mask_loss / (tf.reduce_sum(weights) + 0.01)

 class RetinanetClassLoss(object):
...
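The losses above all follow the same migration pattern: a Keras loss constructed with `Reduction.SUM` and `from_logits=True`, a 0/1 `sample_weight` mask, and manual normalization by the mask sum plus a small constant to avoid dividing by zero. A minimal sketch of that pattern, shown here with the Huber loss (shapes, values, and delta are illustrative):

```python
import tensorflow as tf

huber = tf.keras.losses.Huber(
    delta=0.1, reduction=tf.keras.losses.Reduction.SUM)

def masked_huber_loss(targets, outputs, normalizer=1.0):
  """Sums the element-wise Huber loss over valid entries, then normalizes."""
  mask = tf.cast(tf.not_equal(targets, 0.0), tf.float32)
  # Keras losses reduce over the last axis, so add a trailing axis of size 1
  # to keep the loss element-wise before the SUM reduction.
  targets = tf.expand_dims(targets, axis=-1)
  outputs = tf.expand_dims(outputs, axis=-1)
  loss = huber(targets, outputs, sample_weight=mask)
  # + 0.01 guards against an all-zero mask.
  return loss / (normalizer * (tf.reduce_sum(mask) + 0.01))

targets = tf.constant([[0.2, 0.0], [0.1, 0.3]])
outputs = tf.constant([[0.25, 0.4], [0.0, 0.3]])
print(masked_huber_loss(targets, outputs).numpy())
```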