Commit afd5579f authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

Merge remote-tracking branch 'upstream/master' into context_tf2

parents dcd96e02 567bd18d
This diff is collapsed.
...@@ -14,33 +14,19 @@ ...@@ -14,33 +14,19 @@
# ============================================================================== # ==============================================================================
"""Test beam search helper methods.""" """Test beam search helper methods."""
import tensorflow.compat.v1 as tf import tensorflow as tf
from official.nlp.transformer import beam_search_v1 as beam_search from official.nlp.modeling.ops import beam_search
class BeamSearchHelperTests(tf.test.TestCase): class BeamSearchHelperTests(tf.test.TestCase):
def setUp(self):
super(BeamSearchHelperTests, self).setUp()
tf.compat.v1.disable_eager_execution()
def test_expand_to_beam_size(self): def test_expand_to_beam_size(self):
x = tf.ones([7, 4, 2, 5]) x = tf.ones([7, 4, 2, 5])
x = beam_search._expand_to_beam_size(x, 3) x = beam_search._expand_to_beam_size(x, 3)
with self.session() as sess: shape = tf.shape(x)
shape = sess.run(tf.shape(x))
self.assertAllEqual([7, 3, 4, 2, 5], shape) self.assertAllEqual([7, 3, 4, 2, 5], shape)
def test_shape_list(self):
y = tf.compat.v1.placeholder(dtype=tf.int32, shape=[])
x = tf.ones([7, y, 2, 5])
shape = beam_search._shape_list(x)
self.assertIsInstance(shape[0], int)
self.assertIsInstance(shape[1], tf.Tensor)
self.assertIsInstance(shape[2], int)
self.assertIsInstance(shape[3], int)
def test_get_shape_keep_last_dim(self): def test_get_shape_keep_last_dim(self):
y = tf.constant(4.0) y = tf.constant(4.0)
x = tf.ones([7, tf.cast(tf.sqrt(y), tf.int32), 2, 5]) x = tf.ones([7, tf.cast(tf.sqrt(y), tf.int32), 2, 5])
...@@ -51,16 +37,12 @@ class BeamSearchHelperTests(tf.test.TestCase): ...@@ -51,16 +37,12 @@ class BeamSearchHelperTests(tf.test.TestCase):
def test_flatten_beam_dim(self): def test_flatten_beam_dim(self):
x = tf.ones([7, 4, 2, 5]) x = tf.ones([7, 4, 2, 5])
x = beam_search._flatten_beam_dim(x) x = beam_search._flatten_beam_dim(x)
with self.session() as sess: self.assertAllEqual([28, 2, 5], tf.shape(x))
shape = sess.run(tf.shape(x))
self.assertAllEqual([28, 2, 5], shape)
def test_unflatten_beam_dim(self): def test_unflatten_beam_dim(self):
x = tf.ones([28, 2, 5]) x = tf.ones([28, 2, 5])
x = beam_search._unflatten_beam_dim(x, 7, 4) x = beam_search._unflatten_beam_dim(x, 7, 4)
with self.session() as sess: self.assertAllEqual([7, 4, 2, 5], tf.shape(x))
shape = sess.run(tf.shape(x))
self.assertAllEqual([7, 4, 2, 5], shape)
def test_gather_beams(self): def test_gather_beams(self):
x = tf.reshape(tf.range(24), [2, 3, 4]) x = tf.reshape(tf.range(24), [2, 3, 4])
...@@ -73,9 +55,6 @@ class BeamSearchHelperTests(tf.test.TestCase): ...@@ -73,9 +55,6 @@ class BeamSearchHelperTests(tf.test.TestCase):
# [20 21 22 23]]] # [20 21 22 23]]]
y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2) y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2)
with self.session() as sess:
y = sess.run(y)
self.assertAllEqual([[[4, 5, 6, 7], self.assertAllEqual([[[4, 5, 6, 7],
[8, 9, 10, 11]], [8, 9, 10, 11]],
[[12, 13, 14, 15], [[12, 13, 14, 15],
...@@ -87,9 +66,6 @@ class BeamSearchHelperTests(tf.test.TestCase): ...@@ -87,9 +66,6 @@ class BeamSearchHelperTests(tf.test.TestCase):
x_scores = [[0, 1, 1], [1, 0, 1]] x_scores = [[0, 1, 1], [1, 0, 1]]
y = beam_search._gather_topk_beams(x, x_scores, 2, 2) y = beam_search._gather_topk_beams(x, x_scores, 2, 2)
with self.session() as sess:
y = sess.run(y)
self.assertAllEqual([[[4, 5, 6, 7], self.assertAllEqual([[[4, 5, 6, 7],
[8, 9, 10, 11]], [8, 9, 10, 11]],
[[12, 13, 14, 15], [[12, 13, 14, 15],
......
...@@ -31,7 +31,7 @@ from official.nlp.modeling.layers import multi_channel_attention ...@@ -31,7 +31,7 @@ from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.nhnet import configs from official.nlp.nhnet import configs
from official.nlp.nhnet import decoder from official.nlp.nhnet import decoder
from official.nlp.nhnet import utils from official.nlp.nhnet import utils
from official.nlp.transformer import beam_search from official.nlp.modeling.ops import beam_search
def embedding_linear(embedding_matrix, x): def embedding_linear(embedding_matrix, x):
......
...@@ -40,8 +40,9 @@ class MaskedLMConfig(cfg.TaskConfig): ...@@ -40,8 +40,9 @@ class MaskedLMConfig(cfg.TaskConfig):
class MaskedLMTask(base_task.Task): class MaskedLMTask(base_task.Task):
"""Mock task object for testing.""" """Mock task object for testing."""
def build_model(self): def build_model(self, params=None):
return bert.instantiate_pretrainer_from_cfg(self.task_config.model) params = params or self.task_config.model
return bert.instantiate_pretrainer_from_cfg(params)
def build_losses(self, def build_losses(self,
labels, labels,
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search in TF v2."""
import tensorflow as tf
from official.nlp.transformer import beam_search_v1 as v1
_StateKeys = v1._StateKeys # pylint: disable=protected-access
class SequenceBeamSearchV2(v1.SequenceBeamSearch):
  """Implementation of beam search loop in v2."""

  def search(self, initial_ids, initial_cache):
    """Beam search for sequences with highest scores.

    Args:
      initial_ids: starting ids for each batch item; presumably an int32
        tensor of shape [batch_size] (matches the `sequence_beam_search`
        docstring) — TODO confirm against v1.SequenceBeamSearch.
      initial_cache: a nested dictionary of starting decoder state tensors.

    Returns:
      A (finished_seq, finished_scores) tuple. For batch items with no
      finished sequence, the alive sequences/log-probs are returned instead.
    """
    # Build the loop state (alive/finished sequences, scores, flags) and the
    # matching shape invariants via the v1 base-class helper.
    state, state_shapes = self._create_initial_state(initial_ids, initial_cache)
    # Run the decode loop. stop_gradient is mapped over every state tensor so
    # no gradient flows back through the search loop.
    # NOTE(review): parallel_iterations=1 forces strictly sequential steps —
    # presumably to bound memory / keep step order deterministic; confirm.
    finished_state = tf.nest.map_structure(
        tf.stop_gradient,
        tf.while_loop(self._continue_search,
                      self._search_step,
                      loop_vars=[state],
                      shape_invariants=[state_shapes],
                      parallel_iterations=1))
    # tf.while_loop returns the loop_vars list; unwrap the single state dict.
    finished_state = finished_state[0]

    alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
    alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
    finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
    finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
    finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]

    # 2.0 changes tf.where behavior. Should make parameters broadcastable.
    # finished_cond is True per batch item iff at least one beam finished.
    finished_cond = tf.reduce_any(finished_flags, 1, name="finished_cond")
    seq_cond = _expand_to_same_rank(finished_cond, finished_seq)
    score_cond = _expand_to_same_rank(finished_cond, finished_scores)

    # Account for corner case where there are no finished sequences for a
    # particular batch item. In that case, return alive sequences for that batch
    # item.
    finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
    finished_scores = tf.where(
        score_cond, finished_scores, alive_log_probs)
    return finished_seq, finished_scores
def sequence_beam_search(symbols_to_logits_fn,
                         initial_ids,
                         initial_cache,
                         vocab_size,
                         beam_size,
                         alpha,
                         max_decode_length,
                         eos_id,
                         padded_decode=False,
                         dtype="float32"):
  """Search for sequence of subtoken ids with the largest probability.

  Args:
    symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
        ids -> A tensor with shape [batch_size * beam_size, index].
        index -> A scalar.
        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
      The function must return a tuple of logits and new cache:
        logits -> A tensor with shape [batch * beam_size, vocab_size].
        new cache -> A nested dictionary with the same shape/structure as the
          inputted cache.
    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
      each batch item.
    initial_cache: A dictionary, containing starting decoder variables
      information.
    vocab_size: An integer, the size of tokens.
    beam_size: An integer, the number of beams.
    alpha: A float, defining the strength of length normalization.
    max_decode_length: An integer, the maximum length to decode a sequence.
    eos_id: An integer, ID of eos token, used to determine when a sequence has
      finished.
    padded_decode: A bool, indicating if max_sequence_length padding is used
      for beam search.
    dtype: A tensorflow data type used for score computation. The default is
      tf.float32.

  Returns:
    Top decoded sequences [batch_size, beam_size, max_decode_length]
    sequence scores [batch_size, beam_size]
  """
  # With padded decoding the batch size must be statically known; otherwise
  # it is read from the runtime shape.
  if padded_decode:
    batch_size = initial_ids.shape.as_list()[0]
  else:
    batch_size = tf.shape(initial_ids)[0]
  searcher = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
                                  beam_size, alpha, max_decode_length, eos_id,
                                  padded_decode, dtype)
  return searcher.search(initial_ids, initial_cache)
def _expand_to_same_rank(tensor, target):
  """Expands a given tensor to target's rank to be broadcastable.

  Args:
    tensor: input tensor to tile. Shape: [b, d1, ..., da]
    target: target tensor. Shape: [b, d1, ..., da, ..., dn]

  Returns:
    Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with same rank of target.

  Raises:
    ValueError, if the shape rank of rank tensor/target is None.
  """
  tensor_rank = tensor.shape.rank
  target_rank = target.shape.rank
  if tensor_rank is None:
    raise ValueError("Expect rank for tensor shape, but got None.")
  if target_rank is None:
    raise ValueError("Expect rank for target shape, but got None.")

  with tf.name_scope("expand_rank"):
    # Append one size-1 axis per missing rank so the result broadcasts
    # against `target`.
    expanded = tensor
    for _ in range(target_rank - tensor_rank):
      expanded = tf.expand_dims(expanded, -1)
    return expanded
This diff is collapsed.
...@@ -43,6 +43,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer): ...@@ -43,6 +43,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
self.shared_weights = self.add_weight( self.shared_weights = self.add_weight(
"weights", "weights",
shape=[self.vocab_size, self.hidden_size], shape=[self.vocab_size, self.hidden_size],
dtype=tf.float32,
initializer=tf.random_normal_initializer( initializer=tf.random_normal_initializer(
mean=0., stddev=self.hidden_size**-0.5)) mean=0., stddev=self.hidden_size**-0.5))
super(EmbeddingSharedWeights, self).build(input_shape) super(EmbeddingSharedWeights, self).build(input_shape)
......
...@@ -23,8 +23,8 @@ from __future__ import print_function ...@@ -23,8 +23,8 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling.layers import position_embedding from official.nlp.modeling.layers import position_embedding
from official.nlp.modeling.ops import beam_search
from official.nlp.transformer import attention_layer from official.nlp.transformer import attention_layer
from official.nlp.transformer import beam_search
from official.nlp.transformer import embedding_layer from official.nlp.transformer import embedding_layer
from official.nlp.transformer import ffn_layer from official.nlp.transformer import ffn_layer
from official.nlp.transformer import metrics from official.nlp.transformer import metrics
......
...@@ -94,7 +94,7 @@ def parse_flags(flags_obj): ...@@ -94,7 +94,7 @@ def parse_flags(flags_obj):
"beta2": flags_obj.beta2, "beta2": flags_obj.beta2,
"epsilon": flags_obj.epsilon, "epsilon": flags_obj.epsilon,
"match_mlperf": flags_obj.ml_perf, "match_mlperf": flags_obj.ml_perf,
"epochs_between_evals": FLAGS.epochs_between_evals, "epochs_between_evals": flags_obj.epochs_between_evals,
"keras_use_ctl": flags_obj.keras_use_ctl, "keras_use_ctl": flags_obj.keras_use_ctl,
"hr_threshold": flags_obj.hr_threshold, "hr_threshold": flags_obj.hr_threshold,
"stream_files": flags_obj.tpu is not None, "stream_files": flags_obj.tpu is not None,
......
...@@ -25,10 +25,8 @@ import tensorflow.compat.v2 as tf ...@@ -25,10 +25,8 @@ import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import movielens
from official.recommendation import data_pipeline from official.recommendation import data_pipeline
from official.recommendation import movielens
NUM_SHARDS = 16
def create_dataset_from_tf_record_files(input_file_pattern, def create_dataset_from_tf_record_files(input_file_pattern,
...@@ -36,17 +34,15 @@ def create_dataset_from_tf_record_files(input_file_pattern, ...@@ -36,17 +34,15 @@ def create_dataset_from_tf_record_files(input_file_pattern,
batch_size, batch_size,
is_training=True): is_training=True):
"""Creates dataset from (tf)records files for training/evaluation.""" """Creates dataset from (tf)records files for training/evaluation."""
files = tf.data.Dataset.list_files(input_file_pattern, shuffle=is_training)
def make_dataset(files_dataset, shard_index):
"""Returns dataset for sharded tf record files."""
if pre_batch_size != batch_size: if pre_batch_size != batch_size:
raise ValueError("Pre-batch ({}) size is not equal to batch " raise ValueError("Pre-batch ({}) size is not equal to batch "
"size ({})".format(pre_batch_size, batch_size)) "size ({})".format(pre_batch_size, batch_size))
files_dataset = files_dataset.shard(NUM_SHARDS, shard_index)
dataset = files_dataset.interleave( files = tf.data.Dataset.list_files(input_file_pattern, shuffle=is_training)
dataset = files.interleave(
tf.data.TFRecordDataset, tf.data.TFRecordDataset,
cycle_length=16,
num_parallel_calls=tf.data.experimental.AUTOTUNE) num_parallel_calls=tf.data.experimental.AUTOTUNE)
decode_fn = functools.partial( decode_fn = functools.partial(
data_pipeline.DatasetManager.deserialize, data_pipeline.DatasetManager.deserialize,
...@@ -54,14 +50,7 @@ def create_dataset_from_tf_record_files(input_file_pattern, ...@@ -54,14 +50,7 @@ def create_dataset_from_tf_record_files(input_file_pattern,
is_training=is_training) is_training=is_training)
dataset = dataset.map( dataset = dataset.map(
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
dataset = tf.data.Dataset.range(NUM_SHARDS)
map_fn = functools.partial(make_dataset, files)
dataset = dataset.interleave(
map_fn,
cycle_length=NUM_SHARDS,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset return dataset
......
...@@ -28,7 +28,7 @@ import functools ...@@ -28,7 +28,7 @@ import functools
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.modeling.architecture import nn_ops from official.vision.detection.modeling.architecture import nn_ops
from official.vision.detection.ops import spatial_transform_ops from official.vision.detection.ops import spatial_transform_ops
...@@ -120,7 +120,7 @@ class Fpn(object): ...@@ -120,7 +120,7 @@ class Fpn(object):
'The minimum backbone level %d should be '%(min(input_levels)) + 'The minimum backbone level %d should be '%(min(input_levels)) +
'less or equal to FPN minimum level %d.:'%(self._min_level)) 'less or equal to FPN minimum level %d.:'%(self._min_level))
backbone_max_level = min(max(input_levels), self._max_level) backbone_max_level = min(max(input_levels), self._max_level)
with backend.get_graph().as_default(), tf.name_scope('fpn'): with keras_utils.maybe_enter_backend_graph(), tf.name_scope('fpn'):
# Adds lateral connections. # Adds lateral connections.
feats_lateral = {} feats_lateral = {}
for level in range(self._min_level, backbone_max_level + 1): for level in range(self._min_level, backbone_max_level + 1):
......
...@@ -22,7 +22,8 @@ import functools ...@@ -22,7 +22,8 @@ import functools
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.modeling.architecture import nn_ops from official.vision.detection.modeling.architecture import nn_ops
from official.vision.detection.ops import spatial_transform_ops from official.vision.detection.ops import spatial_transform_ops
...@@ -127,7 +128,7 @@ class RpnHead(tf.keras.layers.Layer): ...@@ -127,7 +128,7 @@ class RpnHead(tf.keras.layers.Layer):
scores_outputs = {} scores_outputs = {}
box_outputs = {} box_outputs = {}
with backend.get_graph().as_default(), tf.name_scope('rpn_head'): with keras_utils.maybe_enter_backend_graph(), tf.name_scope('rpn_head'):
for level in range(self._min_level, self._max_level + 1): for level in range(self._min_level, self._max_level + 1):
scores_output, box_output = self._shared_rpn_heads( scores_output, box_output = self._shared_rpn_heads(
features[level], self._anchors_per_location, level, is_training) features[level], self._anchors_per_location, level, is_training)
...@@ -249,7 +250,8 @@ class FastrcnnHead(tf.keras.layers.Layer): ...@@ -249,7 +250,8 @@ class FastrcnnHead(tf.keras.layers.Layer):
predictions. predictions.
""" """
with backend.get_graph().as_default(), tf.name_scope('fast_rcnn_head'): with keras_utils.maybe_enter_backend_graph(), tf.name_scope(
'fast_rcnn_head'):
      # reshape inputs before FC. # reshape inputs before FC.
_, num_rois, height, width, filters = roi_features.get_shape().as_list() _, num_rois, height, width, filters = roi_features.get_shape().as_list()
...@@ -368,7 +370,7 @@ class MaskrcnnHead(tf.keras.layers.Layer): ...@@ -368,7 +370,7 @@ class MaskrcnnHead(tf.keras.layers.Layer):
boxes is not 4. boxes is not 4.
""" """
with backend.get_graph().as_default(): with keras_utils.maybe_enter_backend_graph():
with tf.name_scope('mask_head'): with tf.name_scope('mask_head'):
_, num_rois, height, width, filters = roi_features.get_shape().as_list() _, num_rois, height, width, filters = roi_features.get_shape().as_list()
net = tf.reshape(roi_features, [-1, height, width, filters]) net = tf.reshape(roi_features, [-1, height, width, filters])
...@@ -552,7 +554,8 @@ class RetinanetHead(object): ...@@ -552,7 +554,8 @@ class RetinanetHead(object):
"""Returns outputs of RetinaNet head.""" """Returns outputs of RetinaNet head."""
class_outputs = {} class_outputs = {}
box_outputs = {} box_outputs = {}
with backend.get_graph().as_default(), tf.name_scope('retinanet_head'): with keras_utils.maybe_enter_backend_graph(), tf.name_scope(
'retinanet_head'):
for level in range(self._min_level, self._max_level + 1): for level in range(self._min_level, self._max_level + 1):
features = fpn_features[level] features = fpn_features[level]
...@@ -644,7 +647,7 @@ class ShapemaskPriorHead(object): ...@@ -644,7 +647,7 @@ class ShapemaskPriorHead(object):
detection_priors: A float Tensor of shape [batch_size * num_instances, detection_priors: A float Tensor of shape [batch_size * num_instances,
mask_size, mask_size, 1]. mask_size, mask_size, 1].
""" """
with backend.get_graph().as_default(), tf.name_scope('prior_mask'): with keras_utils.maybe_enter_backend_graph(), tf.name_scope('prior_mask'):
batch_size, num_instances, _ = boxes.get_shape().as_list() batch_size, num_instances, _ = boxes.get_shape().as_list()
outer_boxes = tf.cast(outer_boxes, tf.float32) outer_boxes = tf.cast(outer_boxes, tf.float32)
boxes = tf.cast(boxes, tf.float32) boxes = tf.cast(boxes, tf.float32)
...@@ -807,7 +810,7 @@ class ShapemaskCoarsemaskHead(object): ...@@ -807,7 +810,7 @@ class ShapemaskCoarsemaskHead(object):
mask_outputs: instance mask prediction as a float Tensor of shape mask_outputs: instance mask prediction as a float Tensor of shape
[batch_size, num_instances, mask_size, mask_size]. [batch_size, num_instances, mask_size, mask_size].
""" """
with backend.get_graph().as_default(), tf.name_scope('coarse_mask'): with keras_utils.maybe_enter_backend_graph(), tf.name_scope('coarse_mask'):
# Transform detection priors to have the same dimension as features. # Transform detection priors to have the same dimension as features.
detection_priors = tf.expand_dims(detection_priors, axis=-1) detection_priors = tf.expand_dims(detection_priors, axis=-1)
detection_priors = self._coarse_mask_fc(detection_priors) detection_priors = self._coarse_mask_fc(detection_priors)
...@@ -939,7 +942,7 @@ class ShapemaskFinemaskHead(object): ...@@ -939,7 +942,7 @@ class ShapemaskFinemaskHead(object):
""" """
# Extract the foreground mean features # Extract the foreground mean features
# with tf.variable_scope('fine_mask', reuse=tf.AUTO_REUSE): # with tf.variable_scope('fine_mask', reuse=tf.AUTO_REUSE):
with backend.get_graph().as_default(), tf.name_scope('fine_mask'): with keras_utils.maybe_enter_backend_graph(), tf.name_scope('fine_mask'):
mask_probs = tf.nn.sigmoid(mask_logits) mask_probs = tf.nn.sigmoid(mask_logits)
# Compute instance embedding for hard average. # Compute instance embedding for hard average.
binary_mask = tf.cast(tf.greater(mask_probs, 0.5), features.dtype) binary_mask = tf.cast(tf.greater(mask_probs, 0.5), features.dtype)
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util functions to integrate with Keras internals."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend
try:
from tensorflow.python.keras.engine import keras_tensor # pylint: disable=g-import-not-at-top,unused-import
except ImportError:
keras_tensor = None
class NoOpContextManager(object):
  """A context manager that does nothing on entry or exit."""

  def __enter__(self):
    # Nothing to set up; `with ... as x` binds x to None.
    return None

  def __exit__(self, *exc_info):
    # Returning None (falsy) means exceptions propagate normally.
    return None
def maybe_enter_backend_graph():
  """Returns a context manager for the Keras backend graph, or a no-op.

  When KerasTensors are enabled there is no backend graph to enter, so a
  NoOpContextManager is returned instead.
  """
  keras_tensors_active = (
      keras_tensor is not None and keras_tensor.keras_tensors_enabled())
  if keras_tensors_active:
    return NoOpContextManager()
  return backend.get_graph().as_default()
...@@ -25,7 +25,7 @@ from __future__ import print_function ...@@ -25,7 +25,7 @@ from __future__ import print_function
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.modeling.architecture import nn_ops from official.vision.detection.modeling.architecture import nn_ops
# TODO(b/140112644): Refactor the code with Keras style, i.e. build and call. # TODO(b/140112644): Refactor the code with Keras style, i.e. build and call.
...@@ -90,7 +90,7 @@ class Resnet(object): ...@@ -90,7 +90,7 @@ class Resnet(object):
The values are corresponding feature hierarchy in ResNet with shape The values are corresponding feature hierarchy in ResNet with shape
[batch_size, height_l, width_l, num_filters]. [batch_size, height_l, width_l, num_filters].
""" """
with backend.get_graph().as_default(): with keras_utils.maybe_enter_backend_graph():
with tf.name_scope('resnet%s' % self._resnet_depth): with tf.name_scope('resnet%s' % self._resnet_depth):
return self._resnet_fn(inputs, is_training) return self._resnet_fn(inputs, is_training)
......
...@@ -24,8 +24,8 @@ import math ...@@ -24,8 +24,8 @@ import math
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.modeling.architecture import nn_blocks from official.vision.detection.modeling.architecture import nn_blocks
layers = tf.keras.layers layers = tf.keras.layers
...@@ -486,7 +486,7 @@ class SpineNetBuilder(object): ...@@ -486,7 +486,7 @@ class SpineNetBuilder(object):
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
def __call__(self, inputs, is_training=None): def __call__(self, inputs, is_training=None):
with backend.get_graph().as_default(): with keras_utils.maybe_enter_backend_graph():
model = SpineNet( model = SpineNet(
input_specs=self._input_specs, input_specs=self._input_specs,
min_level=self._min_level, min_level=self._min_level,
......
...@@ -20,13 +20,13 @@ from __future__ import print_function ...@@ -20,13 +20,13 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend
from official.vision.detection.dataloader import anchor from official.vision.detection.dataloader import anchor
from official.vision.detection.dataloader import mode_keys from official.vision.detection.dataloader import mode_keys
from official.vision.detection.evaluation import factory as eval_factory from official.vision.detection.evaluation import factory as eval_factory
from official.vision.detection.modeling import base_model from official.vision.detection.modeling import base_model
from official.vision.detection.modeling import losses from official.vision.detection.modeling import losses
from official.vision.detection.modeling.architecture import factory from official.vision.detection.modeling.architecture import factory
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.ops import postprocess_ops from official.vision.detection.ops import postprocess_ops
from official.vision.detection.ops import roi_ops from official.vision.detection.ops import roi_ops
from official.vision.detection.ops import spatial_transform_ops from official.vision.detection.ops import spatial_transform_ops
...@@ -297,7 +297,7 @@ class MaskrcnnModel(base_model.Model): ...@@ -297,7 +297,7 @@ class MaskrcnnModel(base_model.Model):
def build_model(self, params, mode): def build_model(self, params, mode):
if self._keras_model is None: if self._keras_model is None:
input_layers = self.build_input_layers(self._params, mode) input_layers = self.build_input_layers(self._params, mode)
with backend.get_graph().as_default(): with keras_utils.maybe_enter_backend_graph():
outputs = self.model_outputs(input_layers, mode) outputs = self.model_outputs(input_layers, mode)
model = tf.keras.models.Model( model = tf.keras.models.Model(
......
...@@ -20,12 +20,12 @@ from __future__ import print_function ...@@ -20,12 +20,12 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend
from official.vision.detection.dataloader import mode_keys from official.vision.detection.dataloader import mode_keys
from official.vision.detection.evaluation import factory as eval_factory from official.vision.detection.evaluation import factory as eval_factory
from official.vision.detection.modeling import base_model from official.vision.detection.modeling import base_model
from official.vision.detection.modeling import losses from official.vision.detection.modeling import losses
from official.vision.detection.modeling.architecture import factory from official.vision.detection.modeling.architecture import factory
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.ops import postprocess_ops from official.vision.detection.ops import postprocess_ops
...@@ -120,7 +120,7 @@ class RetinanetModel(base_model.Model): ...@@ -120,7 +120,7 @@ class RetinanetModel(base_model.Model):
def build_model(self, params, mode=None): def build_model(self, params, mode=None):
if self._keras_model is None: if self._keras_model is None:
with backend.get_graph().as_default(): with keras_utils.maybe_enter_backend_graph():
outputs = self.model_outputs(self._input_layer, mode) outputs = self.model_outputs(self._input_layer, mode)
model = tf.keras.models.Model( model = tf.keras.models.Model(
......
...@@ -20,13 +20,13 @@ from __future__ import print_function ...@@ -20,13 +20,13 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import backend
from official.vision.detection.dataloader import anchor from official.vision.detection.dataloader import anchor
from official.vision.detection.dataloader import mode_keys from official.vision.detection.dataloader import mode_keys
from official.vision.detection.evaluation import factory as eval_factory from official.vision.detection.evaluation import factory as eval_factory
from official.vision.detection.modeling import base_model from official.vision.detection.modeling import base_model
from official.vision.detection.modeling import losses from official.vision.detection.modeling import losses
from official.vision.detection.modeling.architecture import factory from official.vision.detection.modeling.architecture import factory
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.ops import postprocess_ops from official.vision.detection.ops import postprocess_ops
from official.vision.detection.utils import box_utils from official.vision.detection.utils import box_utils
...@@ -265,7 +265,7 @@ class ShapeMaskModel(base_model.Model): ...@@ -265,7 +265,7 @@ class ShapeMaskModel(base_model.Model):
def build_model(self, params, mode): def build_model(self, params, mode):
if self._keras_model is None: if self._keras_model is None:
input_layers = self.build_input_layers(self._params, mode) input_layers = self.build_input_layers(self._params, mode)
with backend.get_graph().as_default(): with keras_utils.maybe_enter_backend_graph():
outputs = self.model_outputs(input_layers, mode) outputs = self.model_outputs(input_layers, mode)
model = tf.keras.models.Model( model = tf.keras.models.Model(
......
...@@ -255,7 +255,7 @@ def define_keras_flags( ...@@ -255,7 +255,7 @@ def define_keras_flags(
name='tpu', default='', help='TPU address to connect to.') name='tpu', default='', help='TPU address to connect to.')
flags.DEFINE_integer( flags.DEFINE_integer(
name='steps_per_loop', name='steps_per_loop',
default=500, default=None,
help='Number of steps per training loop. Only training step happens ' help='Number of steps per training loop. Only training step happens '
'inside the loop. Callbacks will not be called inside. Will be capped at ' 'inside the loop. Callbacks will not be called inside. Will be capped at '
'steps per epoch.') 'steps per epoch.')
......
...@@ -125,7 +125,7 @@ def run(flags_obj): ...@@ -125,7 +125,7 @@ def run(flags_obj):
per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations( per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
flags_obj) flags_obj)
if not flags_obj.steps_per_loop: if flags_obj.steps_per_loop is None:
steps_per_loop = per_epoch_steps steps_per_loop = per_epoch_steps
elif flags_obj.steps_per_loop > per_epoch_steps: elif flags_obj.steps_per_loop > per_epoch_steps:
steps_per_loop = per_epoch_steps steps_per_loop = per_epoch_steps
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment